Skip to content

Commit 06e6fb4

Browse files
authored
Enhance/const correct(er) dataobjects (#115)
* Const correct arguments for data object messages; actually const object refs for RT usage * MLP const update * SharedClient const updates * workflow: Disable parallel test runner and turn up verbosity * remove std::cout access from segfaulting test (just in case)
1 parent 5a87c90 commit 06e6fb4

22 files changed

+151
-134
lines changed

include/algorithms/public/MLP.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class MLP
100100
}
101101

102102
void processFrame(RealVectorView in, RealVectorView out, index startLayer,
103-
index endLayer)
103+
index endLayer) const
104104
{
105105
using namespace _impl;
106106
using namespace Eigen;
@@ -113,13 +113,13 @@ class MLP
113113
out <<= asFluid(tmpOut);
114114
}
115115

116-
void forward(Eigen::Ref<ArrayXXd> in, Eigen::Ref<ArrayXXd> out)
116+
void forward(Eigen::Ref<ArrayXXd> in, Eigen::Ref<ArrayXXd> out) const
117117
{
118118
forward(in, out, 0, asSigned(mLayers.size()));
119119
}
120120

121121
void forward(Eigen::Ref<ArrayXXd> in, Eigen::Ref<ArrayXXd> out,
122-
index startLayer, index endLayer)
122+
index startLayer, index endLayer) const
123123
{
124124
if (startLayer >= asSigned(mLayers.size()) ||
125125
endLayer > asSigned(mLayers.size()))
@@ -137,7 +137,7 @@ class MLP
137137
out = output;
138138
}
139139

140-
void backward(Eigen::Ref<ArrayXXd> out)
140+
void backward(Eigen::Ref<ArrayXXd> out)
141141
{
142142
index nRows = out.rows();
143143
ArrayXXd chain =

include/algorithms/public/UMAP.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class UMAP
133133
return out;
134134
}
135135

136-
DataSet transform(DataSet& in, index maxIter = 200, double learningRate = 1.0)
136+
DataSet transform(DataSet& in, index maxIter = 200, double learningRate = 1.0) const
137137
{
138138
if (!mInitialized) return DataSet();
139139
SparseMatrixXd knnGraph(in.size(), mEmbedding.rows());
@@ -158,7 +158,7 @@ class UMAP
158158
}
159159

160160

161-
void transformPoint(RealVectorView in, RealVectorView out)
161+
void transformPoint(RealVectorView in, RealVectorView out) const
162162
{
163163
if (!mInitialized) return;
164164
SparseMatrixXd knnGraph(1, mEmbedding.rows());
@@ -185,7 +185,7 @@ class UMAP
185185

186186
private:
187187
template <typename F>
188-
void traverseGraph(const SparseMatrixXd& graph, F func)
188+
void traverseGraph(const SparseMatrixXd& graph, F func) const
189189
{
190190
for (index i = 0; i < graph.outerSize(); i++)
191191
{
@@ -204,7 +204,7 @@ class UMAP
204204
}
205205

206206
ArrayXd findSigma(index k, Ref<ArrayXXd> dists, index maxIter = 64,
207-
double tolerance = 1e-5)
207+
double tolerance = 1e-5) const
208208
{
209209
using namespace std;
210210
double target = log2(k);
@@ -242,7 +242,7 @@ class UMAP
242242
}
243243

244244
void computeHighDimProb(const Ref<ArrayXXd>& dists, const Ref<ArrayXd>& sigma,
245-
SparseMatrixXd& graph)
245+
SparseMatrixXd& graph) const
246246
{
247247
traverseGraph(graph, [&](auto it) {
248248
it.valueRef() =
@@ -263,7 +263,7 @@ class UMAP
263263
}
264264

265265
void makeGraph(const DataSet& in, index k, SparseMatrixXd& graph,
266-
Ref<ArrayXXd> dists, bool discardFirst)
266+
Ref<ArrayXXd> dists, bool discardFirst) const
267267
{
268268
graph.reserve(in.size() * k);
269269
auto data = in.getData();
@@ -298,7 +298,7 @@ class UMAP
298298
}
299299

300300
void getGraphIndices(const SparseMatrixXd& graph, Ref<ArrayXi> rowIndices,
301-
Ref<ArrayXi> colIndices)
301+
Ref<ArrayXi> colIndices) const
302302
{
303303
index p = 0;
304304
traverseGraph(graph, [&](auto it) {
@@ -309,7 +309,7 @@ class UMAP
309309
}
310310

311311
void computeEpochsPerSample(const SparseMatrixXd& graph,
312-
Ref<ArrayXd> epochsPerSample)
312+
Ref<ArrayXd> epochsPerSample) const
313313
{
314314
index p = 0;
315315
double maxVal = graph.coeffs().maxCoeff();
@@ -321,7 +321,7 @@ class UMAP
321321
void optimizeLayout(Ref<ArrayXXd> embedding, Ref<ArrayXXd> reference,
322322
Ref<ArrayXi> embIndices, Ref<ArrayXi> refIndices,
323323
Ref<ArrayXd> epochsPerSample, bool updateReference,
324-
double learningRate, index maxIter, double gamma = 1.0)
324+
double learningRate, index maxIter, double gamma = 1.0) const
325325
{
326326
using namespace std;
327327
double alpha = learningRate;
@@ -385,7 +385,7 @@ class UMAP
385385
}
386386

387387
ArrayXXd initTransformEmbedding(const SparseMatrixXd& graph,
388-
Ref<ArrayXXd> reference, index N)
388+
Ref<const ArrayXXd> reference, index N) const
389389
{
390390
ArrayXXd embedding = ArrayXXd::Zero(N, reference.cols());
391391
traverseGraph(graph, [&](auto it) {
@@ -394,7 +394,7 @@ class UMAP
394394
return embedding;
395395
}
396396

397-
void normalizeRows(const SparseMatrixXd& graph)
397+
void normalizeRows(const SparseMatrixXd& graph) const
398398
{
399399
ArrayXd sums = ArrayXd::Zero(graph.innerSize());
400400
traverseGraph(graph, [&](auto it) { sums(it.row()) += it.value(); });
@@ -406,7 +406,7 @@ class UMAP
406406
KDTree mTree;
407407
index mK;
408408
VectorXd mAB;
409-
ArrayXXd mEmbedding;
409+
mutable ArrayXXd mEmbedding;
410410
bool mInitialized{false};
411411
};
412412
}// namespace algorithm

include/algorithms/util/NNLayer.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class NNLayer
6868

6969
index outputSize() const { return mWeights.cols(); }
7070

71-
void forward(Eigen::Ref<MatrixXd> in, Eigen::Ref<MatrixXd> out)
71+
void forward(Eigen::Ref<MatrixXd> in, Eigen::Ref<MatrixXd> out) const
7272
{
7373
mInput = in;
7474
MatrixXd WT = mWeights.transpose();
@@ -114,8 +114,8 @@ class NNLayer
114114
MatrixXd mPrevWeightsUpdate;
115115
VectorXd mPrevBiasesUpdate;
116116

117-
MatrixXd mInput;
118-
MatrixXd mOutput;
117+
mutable MatrixXd mInput;
118+
mutable MatrixXd mOutput;
119119
};
120120
} // namespace algorithm
121121
} // namespace fluid

include/clients/common/SharedClientUtils.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ class SharedClientRef
2727

2828
SharedClientRef() {}
2929
SharedClientRef(const char* name) : mName{name} {}
30-
WeakPointer get() { return {SharedType::lookup(mName)}; }
30+
WeakPointer get() const { return {SharedType::lookup(mName)}; }
3131
void set(const char* name) { mName = std::string(name); }
32-
const char* name() { return mName.c_str(); }
32+
const char* name() const { return mName.c_str(); }
3333

3434
// Supporting machinery for making new parameter types
3535

include/clients/nrt/ClientInputChecks.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class InBufferCheck : public ClientInputCheck
3131
{
3232
public:
3333
InBufferCheck(index size) : mInputSize(size){};
34-
bool checkInputs(BufferAdaptor* inputPtr)
34+
bool checkInputs(const BufferAdaptor* inputPtr)
3535
{
3636
if (!inputPtr)
3737
{
@@ -61,7 +61,7 @@ class InOutBuffersCheck : public InBufferCheck
6161

6262
public:
6363
using InBufferCheck::InBufferCheck;
64-
bool checkInputs(BufferAdaptor* inputPtr, BufferAdaptor* outputPtr)
64+
bool checkInputs(const BufferAdaptor* inputPtr, BufferAdaptor* outputPtr)
6565
{
6666
if (!InBufferCheck::checkInputs(inputPtr)) { return false; }
6767
if (!outputPtr)

include/clients/nrt/DataClient.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ class DataClient
2727
public:
2828
using string = std::string;
2929

30-
MessageResult<index> size() { return mAlgorithm.size(); }
30+
MessageResult<index> size() const { return mAlgorithm.size(); }
3131

32-
MessageResult<index> dims() { return mAlgorithm.dims(); }
32+
MessageResult<index> dims() const { return mAlgorithm.dims(); }
3333

3434
MessageResult<void> clear()
3535
{
@@ -80,8 +80,8 @@ class DataClient
8080
}
8181
}
8282

83-
bool initialized() { return mAlgorithm.initialized(); }
84-
T& algorithm() { return mAlgorithm; }
83+
bool initialized() const { return mAlgorithm.initialized(); }
84+
T const& algorithm() const { return mAlgorithm; }
8585
protected:
8686
T mAlgorithm;
8787
};

include/clients/nrt/DataSetClient.hpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class DataSetClient : public FluidBaseClient,
3535
public:
3636
using string = std::string;
3737
using BufferPtr = std::shared_ptr<BufferAdaptor>;
38+
using InputBufferPtr = std::shared_ptr<const BufferAdaptor>;
3839
using DataSet = FluidDataSet<string, double, 1>;
3940
using LabelSet = FluidDataSet<string, string, 1>;
4041

@@ -61,11 +62,11 @@ class DataSetClient : public FluidBaseClient,
6162

6263
DataSetClient(ParamSetViewType& p) : mParams(p) {}
6364

64-
MessageResult<void> addPoint(string id, BufferPtr data)
65+
MessageResult<void> addPoint(string id, InputBufferPtr data)
6566
{
6667
DataSet& dataset = mAlgorithm;
6768
if (!data) return Error(NoBuffer);
68-
BufferAdaptor::Access buf(data.get());
69+
BufferAdaptor::ReadAccess buf(data.get());
6970
if (!buf.exists()) return Error(InvalidBuffer);
7071
if (buf.numFrames() == 0) return Error(EmptyBuffer);
7172
if (dataset.size() == 0)
@@ -101,23 +102,23 @@ class DataSetClient : public FluidBaseClient,
101102
}
102103
}
103104

104-
MessageResult<void> updatePoint(string id, BufferPtr data)
105+
MessageResult<void> updatePoint(string id, InputBufferPtr data)
105106
{
106107
if (!data) return Error(NoBuffer);
107-
BufferAdaptor::Access buf(data.get());
108+
BufferAdaptor::ReadAccess buf(data.get());
108109
if (!buf.exists()) return Error(InvalidBuffer);
109110
if (buf.numFrames() < mAlgorithm.dims()) return Error(WrongPointSize);
110111
RealVector point(mAlgorithm.dims());
111112
point <<= buf.samps(0, mAlgorithm.dims(), 0);
112113
return mAlgorithm.update(id, point) ? OK() : Error(PointNotFound);
113114
}
114115

115-
MessageResult<void> setPoint(string id, BufferPtr data)
116+
MessageResult<void> setPoint(string id, InputBufferPtr data)
116117
{
117118
if (!data) return Error(NoBuffer);
118119

119120
{ // restrict buffer lock to this scope in case addPoint is called
120-
BufferAdaptor::Access buf(data.get());
121+
BufferAdaptor::ReadAccess buf(data.get());
121122
if (!buf.exists()) return Error(InvalidBuffer);
122123
if (buf.numFrames() < mAlgorithm.dims()) return Error(WrongPointSize);
123124
RealVector point(mAlgorithm.dims());
@@ -133,7 +134,7 @@ class DataSetClient : public FluidBaseClient,
133134
return mAlgorithm.remove(id) ? OK() : Error(PointNotFound);
134135
}
135136

136-
MessageResult<void> merge(SharedClientRef<DataSetClient> datasetClient,
137+
MessageResult<void> merge(SharedClientRef<const DataSetClient> datasetClient,
137138
bool overwrite)
138139
{
139140
auto datasetClientPtr = datasetClient.get().lock();
@@ -154,11 +155,11 @@ class DataSetClient : public FluidBaseClient,
154155
}
155156

156157
MessageResult<void>
157-
fromBuffer(BufferPtr data, bool transpose,
158-
SharedClientRef<labelset::LabelSetClient> labels)
158+
fromBuffer(InputBufferPtr data, bool transpose,
159+
SharedClientRef<const labelset::LabelSetClient> labels)
159160
{
160161
if (!data) return Error(NoBuffer);
161-
BufferAdaptor::Access buf(data.get());
162+
BufferAdaptor::ReadAccess buf(data.get());
162163
if (!buf.exists()) return Error(InvalidBuffer);
163164
auto bufView = transpose ? buf.allFrames() : buf.allFrames().transpose();
164165
if (auto labelsPtr = labels.get().lock())
@@ -256,6 +257,8 @@ class DataSetClient : public FluidBaseClient,
256257
} // namespace dataset
257258

258259
using DataSetClientRef = SharedClientRef<dataset::DataSetClient>;
260+
using InputDataSetClientRef = SharedClientRef<const dataset::DataSetClient>;
261+
259262
using NRTThreadedDataSetClient =
260263
NRTThreadingAdaptor<typename DataSetClientRef::SharedType>;
261264

include/clients/nrt/DataSetQueryClient.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class DataSetQueryClient : public FluidBaseClient, OfflineIn, OfflineOut
100100
}
101101

102102

103-
MessageResult<void> transform(DataSetClientRef sourceClient,
103+
MessageResult<void> transform(InputDataSetClientRef sourceClient,
104104
DataSetClientRef destClient)
105105
{
106106
if (mAlgorithm.numColumns() <= 0) return Error("No columns");
@@ -118,8 +118,8 @@ class DataSetQueryClient : public FluidBaseClient, OfflineIn, OfflineOut
118118
return OK();
119119
}
120120

121-
MessageResult<void> transformJoin(DataSetClientRef source1Client,
122-
DataSetClientRef source2Client,
121+
MessageResult<void> transformJoin(InputDataSetClientRef source1Client,
122+
InputDataSetClientRef source2Client,
123123
DataSetClientRef destClient)
124124
{
125125
auto src1Ptr = source1Client.get().lock();

include/clients/nrt/GridClient.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class GridClient : public FluidBaseClient, OfflineIn, OfflineOut, ModelObject
5656

5757
GridClient(ParamSetViewType& p) : mParams(p) {}
5858

59-
MessageResult<void> fitTransform(DataSetClientRef sourceClient,
59+
MessageResult<void> fitTransform(InputDataSetClientRef sourceClient,
6060
DataSetClientRef destClient)
6161
{
6262
auto srcPtr = sourceClient.get().lock();

include/clients/nrt/KDTreeClient.hpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class KDTreeClient : public FluidBaseClient,
3535
public:
3636
using string = std::string;
3737
using BufferPtr = std::shared_ptr<BufferAdaptor>;
38+
using InputBufferPtr = std::shared_ptr<const BufferAdaptor>;
3839
using StringVector = FluidTensor<string, 1>;
3940
using ParamDescType = decltype(KDTreeParams);
4041

@@ -63,7 +64,7 @@ class KDTreeClient : public FluidBaseClient,
6364
return {};
6465
}
6566

66-
MessageResult<void> fit(DataSetClientRef datasetClient)
67+
MessageResult<void> fit(InputDataSetClientRef datasetClient)
6768
{
6869
mDataSetClient = datasetClient;
6970
auto datasetClientPtr = mDataSetClient.get().lock();
@@ -74,7 +75,7 @@ class KDTreeClient : public FluidBaseClient,
7475
return OK();
7576
}
7677

77-
MessageResult<StringVector> kNearest(BufferPtr data) const
78+
MessageResult<StringVector> kNearest(InputBufferPtr data) const
7879
{
7980
index k = get<kNumNeighbors>();
8081
if (k > mAlgorithm.size()) return Error<StringVector>(SmallDataSet);
@@ -92,7 +93,7 @@ class KDTreeClient : public FluidBaseClient,
9293
return result;
9394
}
9495

95-
MessageResult<RealVector> kNearestDist(BufferPtr data) const
96+
MessageResult<RealVector> kNearestDist(InputBufferPtr data) const
9697
{
9798
// TODO: refactor with kNearest
9899
index k = get<kNumNeighbors>();
@@ -126,22 +127,22 @@ class KDTreeClient : public FluidBaseClient,
126127
makeMessage("read", &KDTreeClient::read));
127128
}
128129

129-
DataSetClientRef getDataSet() { return mDataSetClient; }
130+
InputDataSetClientRef getDataSet() const { return mDataSetClient; }
130131

131-
const algorithm::KDTree& algorithm() { return mAlgorithm; }
132+
const algorithm::KDTree& algorithm() const { return mAlgorithm; }
132133

133134
private:
134-
DataSetClientRef mDataSetClient;
135+
InputDataSetClientRef mDataSetClient;
135136
};
136137

137-
using KDTreeRef = SharedClientRef<KDTreeClient>;
138+
using KDTreeRef = SharedClientRef<const KDTreeClient>;
138139

139140
constexpr auto KDTreeQueryParams = defineParameters(
140141
KDTreeRef::makeParam("tree", "KDTree"),
141142
LongParam("numNeighbours", "Number of Nearest Neighbours", 1),
142143
FloatParam("radius", "Maximum distance", 0, Min(0)),
143-
DataSetClientRef::makeParam("dataSet", "DataSet Name"),
144-
BufferParam("inputPointBuffer", "Input Point Buffer"),
144+
InputDataSetClientRef::makeParam("dataSet", "DataSet Name"),
145+
InputBufferParam("inputPointBuffer", "Input Point Buffer"),
145146
BufferParam("predictionBuffer", "Prediction Buffer"));
146147

147148
class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut
@@ -238,7 +239,7 @@ class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut
238239

239240
private:
240241
RealVector mRTBuffer;
241-
DataSetClientRef mDataSetClient;
242+
InputDataSetClientRef mDataSetClient;
242243
};
243244

244245
} // namespace kdtree

0 commit comments

Comments
 (0)