Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions host/host_tensor/include/host_conv_wrw.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#pragma once
#include "host_tensor.hpp"

template <typename TOut,
typename TIn,
typename TWei,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void host_direct_convolution_backward_weights(
const Tensor<TOut>& out,
const Tensor<TIn>& in,
Tensor<TWei>& wei,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in wei out or out in wei?

Copy link
Copy Markdown
Contributor Author

@zjing14 zjing14 Aug 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(in, out, wei). I should change the order to out, in, wei.

const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads&,
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{
using namespace ck;

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
double v = 0;
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
{
for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
{
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
{
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{
v += static_cast<const double>(in(n, c, hi, wi)) *
static_cast<const double>(out(n, k, ho, wo));
}
}
}
}
wei(k, c, y, x) = v;
};

auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
double v = 0;
for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
{
for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
{
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
{
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
wi < in.mDesc.GetLengths()[2])
{
v += static_cast<const double>(in(n, hi, wi, c)) *
static_cast<const double>(out(n, ho, wo, k));
}
}
}
}
wei(k, y, x, c) = v;
};

if(layout == ConvTensorLayout::NCHW)
{
make_ParallelTensorFunctor(f_kcyx,
wei.mDesc.GetLengths()[0],
wei.mDesc.GetLengths()[1],
wei.mDesc.GetLengths()[2],
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
}
else if(layout == ConvTensorLayout::NHWC)
{
make_ParallelTensorFunctor(f_kyxc,
wei.mDesc.GetLengths()[0],
wei.mDesc.GetLengths()[1],
wei.mDesc.GetLengths()[2],
wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("wrong! not supported layout");
}
}