@@ -133,7 +133,7 @@ class UMAP
133133 return out;
134134 }
135135
136- DataSet transform (DataSet& in, index maxIter = 200 , double learningRate = 1.0 )
136+ DataSet transform (DataSet& in, index maxIter = 200 , double learningRate = 1.0 ) const
137137 {
138138 if (!mInitialized ) return DataSet ();
139139 SparseMatrixXd knnGraph (in.size (), mEmbedding .rows ());
@@ -158,7 +158,7 @@ class UMAP
158158 }
159159
160160
161- void transformPoint (RealVectorView in, RealVectorView out)
161+ void transformPoint (RealVectorView in, RealVectorView out) const
162162 {
163163 if (!mInitialized ) return ;
164164 SparseMatrixXd knnGraph (1 , mEmbedding .rows ());
@@ -185,7 +185,7 @@ class UMAP
185185
186186private:
187187 template <typename F>
188- void traverseGraph (const SparseMatrixXd& graph, F func)
188+ void traverseGraph (const SparseMatrixXd& graph, F func) const
189189 {
190190 for (index i = 0 ; i < graph.outerSize (); i++)
191191 {
@@ -204,7 +204,7 @@ class UMAP
204204 }
205205
206206 ArrayXd findSigma (index k, Ref<ArrayXXd> dists, index maxIter = 64 ,
207- double tolerance = 1e-5 )
207+ double tolerance = 1e-5 ) const
208208 {
209209 using namespace std ;
210210 double target = log2 (k);
@@ -242,7 +242,7 @@ class UMAP
242242 }
243243
244244 void computeHighDimProb (const Ref<ArrayXXd>& dists, const Ref<ArrayXd>& sigma,
245- SparseMatrixXd& graph)
245+ SparseMatrixXd& graph) const
246246 {
247247 traverseGraph (graph, [&](auto it) {
248248 it.valueRef () =
@@ -263,7 +263,7 @@ class UMAP
263263 }
264264
265265 void makeGraph (const DataSet& in, index k, SparseMatrixXd& graph,
266- Ref<ArrayXXd> dists, bool discardFirst)
266+ Ref<ArrayXXd> dists, bool discardFirst) const
267267 {
268268 graph.reserve (in.size () * k);
269269 auto data = in.getData ();
@@ -298,7 +298,7 @@ class UMAP
298298 }
299299
300300 void getGraphIndices (const SparseMatrixXd& graph, Ref<ArrayXi> rowIndices,
301- Ref<ArrayXi> colIndices)
301+ Ref<ArrayXi> colIndices) const
302302 {
303303 index p = 0 ;
304304 traverseGraph (graph, [&](auto it) {
@@ -309,7 +309,7 @@ class UMAP
309309 }
310310
311311 void computeEpochsPerSample (const SparseMatrixXd& graph,
312- Ref<ArrayXd> epochsPerSample)
312+ Ref<ArrayXd> epochsPerSample) const
313313 {
314314 index p = 0 ;
315315 double maxVal = graph.coeffs ().maxCoeff ();
@@ -321,7 +321,7 @@ class UMAP
321321 void optimizeLayout (Ref<ArrayXXd> embedding, Ref<ArrayXXd> reference,
322322 Ref<ArrayXi> embIndices, Ref<ArrayXi> refIndices,
323323 Ref<ArrayXd> epochsPerSample, bool updateReference,
324- double learningRate, index maxIter, double gamma = 1.0 )
324+ double learningRate, index maxIter, double gamma = 1.0 ) const
325325 {
326326 using namespace std ;
327327 double alpha = learningRate;
@@ -385,7 +385,7 @@ class UMAP
385385 }
386386
387387 ArrayXXd initTransformEmbedding (const SparseMatrixXd& graph,
388- Ref<ArrayXXd> reference, index N)
388+ Ref<const ArrayXXd> reference, index N) const
389389 {
390390 ArrayXXd embedding = ArrayXXd::Zero (N, reference.cols ());
391391 traverseGraph (graph, [&](auto it) {
@@ -394,7 +394,7 @@ class UMAP
394394 return embedding;
395395 }
396396
397- void normalizeRows (const SparseMatrixXd& graph)
397+ void normalizeRows (const SparseMatrixXd& graph) const
398398 {
399399 ArrayXd sums = ArrayXd::Zero (graph.innerSize ());
400400 traverseGraph (graph, [&](auto it) { sums (it.row ()) += it.value (); });
@@ -406,7 +406,7 @@ class UMAP
406406 KDTree mTree ;
407407 index mK ;
408408 VectorXd mAB ;
409- ArrayXXd mEmbedding ;
409+ mutable ArrayXXd mEmbedding ;
410410 bool mInitialized {false };
411411};
412412}// namespace algorithm
0 commit comments