diff --git a/include/algorithms/public/PCA.hpp b/include/algorithms/public/PCA.hpp index 4b7f05c31..28362f551 100644 --- a/include/algorithms/public/PCA.hpp +++ b/include/algorithms/public/PCA.hpp @@ -113,6 +113,29 @@ class PCA return variance / total; } + void inverseProcess(RealMatrixView in, RealMatrixView out, bool whiten = false) const + { + using namespace Eigen; + + if (in.cols() > dims()) return; + if (out.cols() < in.cols()) return; + + if (!whiten) + _impl::asEigen(out) = + (_impl::asEigen(in) * mBases.transpose()).rowwise() + + mMean.transpose(); + + else + { + _impl::asEigen(out) = + (_impl::asEigen(in) * + (mExplainedVariance.sqrt().matrix().asDiagonal() * + mBases.transpose())) + .rowwise() + + mMean.transpose(); + } + } + bool initialized() const { return mInitialized; } void getBases(RealMatrixView out) const { out <<= _impl::asFluid(mBases); } diff --git a/include/clients/nrt/PCAClient.hpp b/include/clients/nrt/PCAClient.hpp index b5ac1dd4b..9ce8ea5ce 100644 --- a/include/clients/nrt/PCAClient.hpp +++ b/include/clients/nrt/PCAClient.hpp @@ -111,6 +111,35 @@ class PCAClient : public FluidBaseClient, return result; } + MessageResult inverseTransform(InputDataSetClientRef sourceClient, + DataSetClientRef destClient) const + { + + auto srcPtr = sourceClient.get().lock(); + auto destPtr = destClient.get().lock(); + + if (srcPtr && destPtr) + { + auto srcDataSet = srcPtr->getDataSet(); + if (srcDataSet.size() == 0) return Error(EmptyDataSet); + if (!mAlgorithm.initialized()) return Error(NoDataFitted); + StringVector ids{srcDataSet.getIds()}; + RealMatrix paddedInput(srcPtr->size(), mAlgorithm.dims()); + auto inputData = srcDataSet.getData(); + paddedInput(Slice(0, inputData.rows()), Slice(0, inputData.cols())) <<= + inputData; + RealMatrix output(srcDataSet.size(), mAlgorithm.dims()); + mAlgorithm.inverseProcess(paddedInput, output,get() == 1); + FluidDataSet result(ids, output); + destPtr->setDataSet(result); + return {}; + } + else + { + return Error(NoDataSet); + } + } + MessageResult transformPoint(InputBufferPtr in, BufferPtr out) const { index k = get(); @@ -150,7 +179,7 @@ class PCAClient : public FluidBaseClient, Result resizeResult = outBuf.resize(mAlgorithm.dims(), 1, outBuf.sampleRate()); mAlgorithm.inverseProcessFrame(src, dst, get()); - outBuf.samps(0,mAlgorithm.dims(),0)<< = dst; + outBuf.samps(0,mAlgorithm.dims(),0) <<= dst; return OK(); } @@ -160,6 +189,7 @@ class PCAClient : public FluidBaseClient, makeMessage("fit", &PCAClient::fit), makeMessage("transform", &PCAClient::transform), makeMessage("fitTransform", &PCAClient::fitTransform), + makeMessage("inverseTransform",&PCAClient::inverseTransform), makeMessage("transformPoint", &PCAClient::transformPoint), makeMessage("inverseTransformPoint", &PCAClient::inverseTransformPoint), makeMessage("cols", &PCAClient::dims),