Skip to content

Commit 1d4e06e

Browse files
g-romatremblap
andauthored
Feature/skmeans (#99)
* add pca whitening * actually add pca whitening * add spherical kmeans * actually add spherical kmeans * SKMeans fixes, change KMeans getDistances to transform * adding RT query class * <fit>transform<point> -> <fit>encode<point> Co-authored-by: tremblap <[email protected]>
1 parent 483503b commit 1d4e06e

File tree

6 files changed

+507
-5
lines changed

6 files changed

+507
-5
lines changed

FlucomaClients.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ add_client(DataSetQuery clients/nrt/DataSetQueryClient.hpp CLASS NRTThreadedData
144144
add_client(LabelSet clients/nrt/LabelSetClient.hpp CLASS NRTThreadedLabelSetClient GROUP MANIPULATION)
145145
add_client(KDTree clients/nrt/KDTreeClient.hpp CLASS NRTThreadedKDTreeClient GROUP MANIPULATION)
146146
add_client(KMeans clients/nrt/KMeansClient.hpp CLASS NRTThreadedKMeansClient GROUP MANIPULATION)
147+
add_client(SKMeans clients/nrt/SKMeansClient.hpp CLASS NRTThreadedSKMeansClient GROUP MANIPULATION)
147148
add_client(KNNClassifier clients/nrt/KNNClassifierClient.hpp CLASS NRTThreadedKNNClassifierClient GROUP MANIPULATION)
148149
add_client(KNNRegressor clients/nrt/KNNRegressorClient.hpp CLASS NRTThreadedKNNRegressorClient GROUP MANIPULATION)
149150
add_client(Normalize clients/nrt/NormalizeClient.hpp CLASS NRTThreadedNormalizeClient GROUP MANIPULATION)

include/algorithms/public/KMeans.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class KMeans
109109
out <<= _impl::asFluid(mAssignments);
110110
}
111111

112-
void getDistances(RealMatrixView data, RealMatrixView out) const
112+
void transform(RealMatrixView data, RealMatrixView out) const
113113
{
114114
Eigen::ArrayXXd points = _impl::asEigen<Eigen::Array>(data);
115115
Eigen::ArrayXXd D = fluid::algorithm::DistanceMatrix(points, 2);
@@ -118,8 +118,8 @@ class KMeans
118118
out <<= _impl::asFluid(D);
119119
}
120120

121-
private:
122-
double distance(Eigen::ArrayXd v1, Eigen::ArrayXd v2) const
121+
protected:
122+
double distance(const Eigen::ArrayXd& v1, const Eigen::ArrayXd& v2) const
123123
{
124124
return (v1 - v2).matrix().norm();
125125
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
Part of the Fluid Corpus Manipulation Project (http://www.flucoma.org/)
3+
Copyright 2017-2019 University of Huddersfield.
4+
Licensed under the BSD-3 License.
5+
See license.md file in the project root for full license information.
6+
This project has received funding from the European Research Council (ERC)
7+
under the European Union’s Horizon 2020 research and innovation programme
8+
(grant agreement No 725899).
9+
*/
10+
11+
#pragma once
12+
13+
#include "../public/KMeans.hpp"
14+
#include "../util/FluidEigenMappings.hpp"
15+
#include "../../data/FluidDataSet.hpp"
16+
#include "../../data/FluidIndex.hpp"
17+
#include "../../data/FluidTensor.hpp"
18+
#include "../../data/TensorTypes.hpp"
19+
#include <Eigen/Core>
20+
#include <queue>
21+
#include <string>
22+
23+
namespace fluid {
24+
namespace algorithm {
25+
26+
class SKMeans : public KMeans
27+
{
28+
29+
public:
30+
void train(const FluidDataSet<std::string, double, 1>& dataset, index k,
31+
index maxIter)
32+
{
33+
using namespace Eigen;
34+
using namespace _impl;
35+
assert(!mTrained || (dataset.pointSize() == mDims && mK == k));
36+
MatrixXd dataPoints = asEigen<Matrix>(dataset.getData());
37+
MatrixXd dataPointsT = dataPoints.transpose();
38+
if (mTrained) { mAssignments = assignClusters(dataPointsT);}
39+
else
40+
{
41+
mK = k;
42+
mDims = dataset.pointSize();
43+
initMeans(dataPoints);
44+
}
45+
46+
while (maxIter-- > 0)
47+
{
48+
mEmbedding = mMeans.matrix() * dataPointsT;
49+
auto assignments = assignClusters(mEmbedding);
50+
if (!changed(assignments)) { break; }
51+
else
52+
mAssignments = assignments;
53+
updateEmbedding();
54+
computeMeans(dataPoints);
55+
}
56+
mTrained = true;
57+
}
58+
59+
60+
void encode(RealMatrixView data, RealMatrixView out,
61+
double alpha = 0.25) const
62+
{
63+
using namespace Eigen;
64+
MatrixXd points = _impl::asEigen<Matrix>(data).transpose();
65+
MatrixXd embedding = (mMeans.matrix() * points).array() - alpha;
66+
embedding = (embedding.array() > 0).select(embedding, 0).transpose();
67+
out <<= _impl::asFluid(embedding);
68+
}
69+
70+
private:
71+
72+
void initMeans(Eigen::MatrixXd& dataPoints)
73+
{
74+
using namespace Eigen;
75+
mMeans = ArrayXXd::Zero(mK, mDims);
76+
mAssignments =
77+
((0.5 + (0.5 * ArrayXd::Random(dataPoints.rows()))) * (mK - 1))
78+
.round()
79+
.cast<int>();
80+
mEmbedding = MatrixXd::Zero(mK, dataPoints.rows());
81+
for (index i = 0; i < dataPoints.rows(); i++)
82+
mEmbedding(mAssignments(i), i) = 1;
83+
computeMeans(dataPoints);
84+
}
85+
86+
void updateEmbedding()
87+
{
88+
for (index i = 0; i < mAssignments.cols(); i++)
89+
{
90+
double val = mEmbedding(mAssignments(i), i);
91+
mEmbedding.col(i).setZero();
92+
mEmbedding(mAssignments(i), i) = val;
93+
}
94+
}
95+
96+
97+
Eigen::VectorXi assignClusters(Eigen::MatrixXd& embedding) const
98+
{
99+
Eigen::VectorXi assignments = Eigen::VectorXi::Zero(embedding.cols());
100+
for (index i = 0; i < embedding.cols(); i++)
101+
{
102+
Eigen::VectorXd::Index maxIndex;
103+
embedding.col(i).maxCoeff(&maxIndex);
104+
assignments(i) = static_cast<int>(maxIndex);
105+
}
106+
return assignments;
107+
}
108+
109+
110+
void computeMeans(Eigen::MatrixXd& dataPoints)
111+
{
112+
mMeans = mEmbedding * dataPoints;
113+
mMeans.matrix().rowwise().normalize();
114+
}
115+
116+
117+
private:
118+
Eigen::MatrixXd mEmbedding;
119+
};
120+
} // namespace algorithm
121+
} // namespace fluid

include/clients/nrt/KMeansClient.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class KMeansClient : public FluidBaseClient,
146146

147147
StringVectorView ids = srcDataSet.getIds();
148148
RealMatrix output(srcDataSet.size(), mAlgorithm.size());
149-
mAlgorithm.getDistances(srcDataSet.getData(), output);
149+
mAlgorithm.transform(srcDataSet.getData(), output);
150150
FluidDataSet<string, double, 1> result(ids, output);
151151
destPtr->setDataSet(result);
152152
return OK();
@@ -224,7 +224,7 @@ class KMeansClient : public FluidBaseClient,
224224
RealMatrix dest(1, mAlgorithm.size());
225225
src.row(0) <<=
226226
BufferAdaptor::ReadAccess(in.get()).samps(0, mAlgorithm.dims(), 0);
227-
mAlgorithm.getDistances(src, dest);
227+
mAlgorithm.transform(src, dest);
228228
outBuf.allFrames()(Slice(0, 1), Slice(0, mAlgorithm.size())) <<= dest;
229229
return OK();
230230
}

0 commit comments

Comments
 (0)