Skip to content

Commit 114cb0c

Browse files
authored
add dataset version of inverse PCA (#125)
* add dataset version of inverse PCA * PCA: Add whitening to batch inverse transform
1 parent 74c3d05 commit 114cb0c

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

include/algorithms/public/PCA.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,29 @@ class PCA
113113
return variance / total;
114114
}
115115

116+
void inverseProcess(RealMatrixView in, RealMatrixView out, bool whiten = false) const
117+
{
118+
using namespace Eigen;
119+
120+
if (in.cols() > dims()) return;
121+
if (out.cols() < in.cols()) return;
122+
123+
if (!whiten)
124+
_impl::asEigen<Matrix>(out) =
125+
(_impl::asEigen<Matrix>(in) * mBases.transpose()).rowwise() +
126+
mMean.transpose();
127+
128+
else
129+
{
130+
_impl::asEigen<Matrix>(out) =
131+
(_impl::asEigen<Matrix>(in) *
132+
(mExplainedVariance.sqrt().matrix().asDiagonal() *
133+
mBases.transpose()))
134+
.rowwise() +
135+
mMean.transpose();
136+
}
137+
}
138+
116139
bool initialized() const { return mInitialized; }
117140

118141
void getBases(RealMatrixView out) const { out <<= _impl::asFluid(mBases); }

include/clients/nrt/PCAClient.hpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,35 @@ class PCAClient : public FluidBaseClient,
111111
return result;
112112
}
113113

114+
MessageResult<void> inverseTransform(InputDataSetClientRef sourceClient,
115+
DataSetClientRef destClient) const
116+
{
117+
118+
auto srcPtr = sourceClient.get().lock();
119+
auto destPtr = destClient.get().lock();
120+
121+
if (srcPtr && destPtr)
122+
{
123+
auto srcDataSet = srcPtr->getDataSet();
124+
if (srcDataSet.size() == 0) return Error<void>(EmptyDataSet);
125+
if (!mAlgorithm.initialized()) return Error<void>(NoDataFitted);
126+
StringVector ids{srcDataSet.getIds()};
127+
RealMatrix paddedInput(srcPtr->size(), mAlgorithm.dims());
128+
auto inputData = srcDataSet.getData();
129+
paddedInput(Slice(0, inputData.rows()), Slice(0, inputData.cols())) <<=
130+
inputData;
131+
RealMatrix output(srcDataSet.size(), mAlgorithm.dims());
132+
mAlgorithm.inverseProcess(paddedInput, output,get<kWhiten>() == 1);
133+
FluidDataSet<string, double, 1> result(ids, output);
134+
destPtr->setDataSet(result);
135+
return {};
136+
}
137+
else
138+
{
139+
return Error<void>(NoDataSet);
140+
}
141+
}
142+
114143
MessageResult<void> transformPoint(InputBufferPtr in, BufferPtr out) const
115144
{
116145
index k = get<kNumDimensions>();
@@ -150,7 +179,7 @@ class PCAClient : public FluidBaseClient,
150179
Result resizeResult = outBuf.resize(mAlgorithm.dims(), 1, outBuf.sampleRate());
151180

152181
mAlgorithm.inverseProcessFrame(src, dst, get<kWhiten>());
153-
outBuf.samps(0,mAlgorithm.dims(),0)<< = dst;
182+
outBuf.samps(0,mAlgorithm.dims(),0) <<= dst;
154183
return OK();
155184
}
156185

@@ -160,6 +189,7 @@ class PCAClient : public FluidBaseClient,
160189
makeMessage("fit", &PCAClient::fit),
161190
makeMessage("transform", &PCAClient::transform),
162191
makeMessage("fitTransform", &PCAClient::fitTransform),
192+
makeMessage("inverseTransform",&PCAClient::inverseTransform),
163193
makeMessage("transformPoint", &PCAClient::transformPoint),
164194
makeMessage("inverseTransformPoint", &PCAClient::inverseTransformPoint),
165195
makeMessage("cols", &PCAClient::dims),

0 commit comments

Comments
 (0)