Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature conv3d #53

Merged
merged 18 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 277 additions & 1 deletion crates/luminal_nn/src/convolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,194 @@ impl<
.reshape::<R3<CH_OUT, DIMX_OUT, DIMY_OUT>>()
}
}
pub struct Conv3D<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const KERNELZ: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const STRIDEZ: usize,
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
const DIMX_TIMES_KERNELX: usize,
const DIMX_TIMES_DIMY_DIMZ_OUT: usize
> {
pub weight: GraphTensor<R5<CH_OUT, CH_IN, KERNELX, KERNELY, KERNELZ>>,
}

impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const KERNELZ: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const STRIDEZ: usize,
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
const DIMX_TIMES_KERNELX: usize,
const DIMX_TIMES_DIMY_DIMZ_OUT: usize
> InitModule
for Conv3D<
CH_IN,
CH_OUT,
KERNELX,
KERNELY,
KERNELZ,
STRIDEX,
STRIDEY,
STRIDEZ,
DILATIONX,
DILATIONY,
DILATIONZ,
DIMX_TIMES_KERNELX,
DIMX_TIMES_DIMY_DIMZ_OUT,
>
{
fn initialize(cx: &mut Graph) -> Self {
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
Self {
weight: cx.named_tensor("Weight").set(
(0..(CH_IN * CH_OUT * KERNELX * KERNELY * KERNELZ))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
}
}
}

impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const KERNELZ: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const STRIDEZ: usize,
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
const DIMX_TIMES_KERNELX: usize,
const DIMX_TIMES_DIMY_DIMZ_OUT: usize,
> SerializeModule
for Conv3D<
CH_IN,
CH_OUT,
KERNELX,
KERNELY,
KERNELZ,
STRIDEX,
STRIDEY,
STRIDEZ,
DILATIONX,
DILATIONY,
DILATIONZ,
DIMX_TIMES_KERNELX,
DIMX_TIMES_DIMY_DIMZ_OUT,
>
{
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
}
}

impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNELX: usize,
const KERNELY: usize,
const KERNELZ: usize,
const STRIDEX: usize,
const STRIDEY: usize,
const STRIDEZ: usize,
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
const DIMX_TIMES_KERNELX: usize,
const DIMX_TIMES_DIMY_DIMZ_OUT: usize,
>
Conv3D<
CH_IN,
CH_OUT,
KERNELX,
KERNELY,
KERNELZ,
STRIDEX,
STRIDEY,
STRIDEZ,
DILATIONX,
DILATIONY,
DILATIONZ,
DIMX_TIMES_KERNELX,
DIMX_TIMES_DIMY_DIMZ_OUT,
>
{
pub fn forward<
const DIMX_IN: usize,
const DIMY_IN: usize,
const DIMZ_IN: usize,
const DIMX_OUT: usize,
const DIMY_OUT: usize,
const DIMZ_OUT: usize,
>(
&self,
input: GraphTensor<R4<CH_IN, DIMX_IN, DIMY_IN, DIMZ_IN>>,
) -> GraphTensor<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>> {
let input_pooled = input
.pool_last_dim::<R5<CH_IN, DIMX_IN, DIMY_OUT, DIMZ_OUT, KERNELY>>(
KERNELY.into(),
STRIDEY.into(),
DILATIONY
)
.permute::<_, Axes5<0, 2, 3, 4, 1>>()
.pool_last_dim::<R6<CH_IN, DIMY_OUT, DIMZ_OUT, KERNELY, DIMX_OUT, KERNELX>>(
KERNELX.into(),
STRIDEX.into(),
DILATIONX
)
.dyn_reshape::<(Const<CH_IN>, Dyn<'-'>)>(vec![
CH_IN.into(),
DIMZ_OUT.into(),
KERNELY.into(),
(DIMX_OUT * KERNELX).into(),
DIMY_IN.into()
]);

let last_pool = input_pooled
.pool_last_dim::<R6<CH_IN, DIMZ_OUT, KERNELY, DIMX_TIMES_KERNELX, DIMY_IN, KERNELZ>>(
KERNELZ.into(),
STRIDEZ.into(),
DILATIONZ
)
.permute::<_, Axes6<0, 2, 5, 3, 1, 4>>();

let reshaped = last_pool
.dyn_reshape::<(_, Dyn<'-'>)>(vec![
(CH_IN * KERNELX * KERNELY * KERNELZ).into(),
DIMX_TIMES_DIMY_DIMZ_OUT.into(),
]);

self.weight
.dyn_reshape::<(Const<CH_OUT>, Dyn<'-'>)>(vec![
CH_OUT.into(),
(CH_IN * KERNELX * KERNELY * KERNELZ).into(),
])
.matmul(reshaped)
.reshape::<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>>()
}
}


#[cfg(test)]
mod tests {
use super::{Conv1D, Conv2D};
use super::{Conv1D, Conv2D, Conv3D};
use luminal::{prelude::*, tests::assert_close};

#[test]
Expand Down Expand Up @@ -334,4 +518,96 @@ mod tests {

assert_close(&out1.data(), &exp_out1.data())
}

#[test]
fn test_conv3d() {
let mut cx = Graph::new();

const CH_IN: usize = 5;
const CH_OUT: usize = 2;
const KERNELX: usize = 2;
const KERNELY: usize = 2;
const KERNELZ: usize = 2;
const STRIDEX: usize = 2;
const STRIDEY: usize = 2;
const STRIDEZ: usize = 2;
const DILATIONX: usize = 0;
const DILATIONY: usize = 0;
const DILATIONZ: usize = 0;
const DIMX_IN: usize = 2;
const DIMY_IN: usize = 3;
const DIMZ_IN: usize = 5;
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
const DIMZ_OUT: usize = ((DIMZ_IN - (DILATIONZ + 1) * (KERNELZ - 1) - 1) / STRIDEZ) + 1;
const DIMX_TIMES_KERNELX: usize = DIMX_OUT * KERNELX;
const DIMX_TIMES_DIMY_DIMZ_OUT:usize = DIMX_OUT * DIMY_OUT * DIMZ_OUT;

let inp1 = cx.tensor::<R4<CH_IN, DIMX_IN, DIMY_IN, DIMZ_IN>>();
inp1.set(vec![
// Example input data (5 channels, 2x3x5 volume)
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8.,
8., 5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9.,
4., 7., 1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3.,
7., 9., 5., 8., 5., 9., 0., 9., 5., 6., 8., 9., 5., 4., 1., 9., 7., 2., 2., 7., 9., 3.,
1., 2., 8., 4., 0., 8., 0., 5., 6., 7., 7., 4., 3., 4., 6., 8., 3., 7., 8., 8., 7., 1.,
5., 1., 8., 0., 1., 1., 7., 3., 2., 1., 0., 4., 5., 4., 3., 2., 5., 4., 2., 4., 1., 9.,
4., 1., 9., 7., 7., 1., 2., 6., 3., 4., 1., 1., 6., 6., 8., 2., 7., 7.,
]);

let exp_out1 = cx.tensor::<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>>();
exp_out1.set(vec![
// Example expected output data (2 channels, 1x1x2 volume)
90.6935, 98.7138, 98.8273, 102.6553,
]);

exp_out1.retrieve();

let model: Conv3D<
CH_IN,
CH_OUT,
KERNELX,
KERNELY,
KERNELZ,
STRIDEX,
STRIDEY,
STRIDEZ,
DILATIONX,
DILATIONY,
DILATIONZ,
DIMX_TIMES_KERNELX,
DIMX_TIMES_DIMY_DIMZ_OUT,
> = Conv3D::initialize(&mut cx);
let weights = vec![
4.273e-01, 1.388e-01, 3.546e-01, 2.403e-01,
5.572e-01, 2.788e-01, 6.718e-01, 6.935e-01,
8.410e-01, 1.297e-01, 7.073e-01, 3.455e-01,
4.166e-01, 9.513e-01, 4.682e-01, 4.546e-02,
5.061e-01, 4.117e-01, 1.667e-01, 5.557e-02,
6.092e-01, 9.675e-01, 7.083e-01, 7.946e-01,
3.518e-01, 4.697e-01, 6.052e-01, 6.832e-01,
2.312e-02, 6.932e-01, 6.135e-01, 9.216e-01,
8.011e-01, 1.971e-01, 7.086e-01, 2.394e-01,
3.663e-01, 6.619e-01, 4.211e-01, 1.852e-01,
8.635e-01, 1.311e-01, 4.206e-01, 5.413e-01,
7.938e-01, 9.604e-01, 7.966e-01, 7.400e-01,
3.212e-01, 4.644e-01, 3.224e-01, 1.123e-01,
4.000e-01, 7.678e-01, 7.545e-01, 9.423e-01,
5.605e-02, 2.675e-02, 5.022e-02, 8.632e-01,
9.305e-01, 9.836e-01, 1.635e-01, 2.379e-01,
9.291e-01, 4.029e-01, 6.675e-01, 4.912e-01,
8.904e-01, 6.938e-01, 9.581e-01, 1.720e-01,
7.835e-01, 4.658e-04, 2.818e-01, 5.373e-01,
3.437e-01, 1.254e-01, 6.868e-02, 7.546e-01,
];
model.weight.set(weights);

let out1 = model
.forward::<DIMX_IN, DIMY_IN, DIMZ_IN, DIMX_OUT, DIMY_OUT, DIMZ_OUT>(inp1)
.retrieve();

cx.execute();

assert_close(&out1.data(), &exp_out1.data());
}
}
10 changes: 10 additions & 0 deletions src/shape/permute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ macro_rules! impl_permute {
{
}
};
($Ax0:tt, $Ax1:tt, $Ax2:tt, $Ax3:tt, $Ax4:tt, $Ax5:tt) => {
impl<D1, D2, D3, D4, D5, D6>
PermuteShapeTo<
(d!($Ax0), d!($Ax1), d!($Ax2), d!($Ax3), d!($Ax4), d!($Ax5)),
Axes6<$Ax0, $Ax1, $Ax2, $Ax3, $Ax4, $Ax5>,
> for (D1, D2, D3, D4, D5, D6)
{
}
};
}

/// Expand out all the possible permutations for 2-4d
Expand Down Expand Up @@ -125,3 +134,4 @@ permutations!([0, 1]);
permutations!([0, 1, 2]);
permutations!([0, 1, 2, 3]);
permutations!([0, 1, 2, 3, 4]);
permutations!([0, 1, 2, 3, 4, 5]);
Loading