Skip to content

Commit 3791e5c

Browse files
U-SIOUX\NovikovaU-SIOUX\Novikova
U-SIOUX\Novikova
authored and
U-SIOUX\Novikova
committed
#540: 'predict' method for X-Means algorithm.
1 parent 74833ea commit 3791e5c

File tree

5 files changed

+99
-37
lines changed

5 files changed

+99
-37
lines changed

pyclustering/cluster/kmeans.py

-24
Original file line numberDiff line numberDiff line change
@@ -441,30 +441,6 @@ def predict(self, points):
441441
@return (list) List of closest clusters for each point. Each cluster is denoted by index. Return empty
442442
collection if 'process()' method was not called.
443443
444-
An example how to calculate (or predict) the closest cluster to specified points.
445-
@code
446-
from pyclustering.cluster.kmeans import kmeans
447-
from pyclustering.samples.definitions import SIMPLE_SAMPLES
448-
from pyclustering.utils import read_sample
449-
450-
# Load list of points for cluster analysis.
451-
sample = read_sample(SIMPLE_SAMPLES.SAMPLE_SIMPLE3)
452-
453-
# Initial centers for sample 'Simple3'.
454-
initial_centers = [[0.2, 0.1], [4.0, 1.0], [2.0, 2.0], [2.3, 3.9]]
455-
456-
# Create instance of K-Means algorithm with prepared centers.
457-
kmeans_instance = kmeans(sample, initial_centers)
458-
459-
# Run cluster analysis.
460-
kmeans_instance.process()
461-
462-
# Calculate the closest cluster to following two points.
463-
points = [[0.25, 0.2], [2.5, 4.0]]
464-
closest_clusters = kmeans_instance.predict(points)
465-
print(closest_clusters)
466-
@endcode
467-
468444
"""
469445

470446
nppoints = numpy.array(points)

pyclustering/cluster/tests/integration/it_xmeans.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,18 @@
2424
"""
2525

2626

27-
import unittest;
27+
import unittest
2828

2929
# Generate images without having a window appear.
30-
import matplotlib;
31-
matplotlib.use('Agg');
30+
import matplotlib
31+
matplotlib.use('Agg')
3232

33-
from pyclustering.cluster.tests.xmeans_templates import XmeansTestTemplates;
34-
from pyclustering.cluster.xmeans import xmeans, splitting_type;
33+
from pyclustering.cluster.tests.xmeans_templates import XmeansTestTemplates
34+
from pyclustering.cluster.xmeans import xmeans, splitting_type
3535

36-
from pyclustering.samples.definitions import SIMPLE_SAMPLES, FCPS_SAMPLES;
36+
from pyclustering.samples.definitions import SIMPLE_SAMPLES, FCPS_SAMPLES
3737

38-
from pyclustering.core.tests import remove_library;
38+
from pyclustering.core.tests import remove_library
3939

4040

4141
class XmeansIntegrationTest(unittest.TestCase):
@@ -184,11 +184,22 @@ def testKmax05Amount20Offset02Initial05(self):
184184
def testKmax05Amount01Offset01Initial04(self):
185185
XmeansTestTemplates.templateMaxAllocatedClusters(True, 1, 1000, 1, 4, 5);
186186

187+
def testPredictOnePoint(self):
188+
centers = [[0.2, 0.1], [4.0, 1.0], [2.0, 2.0], [2.3, 3.9]]
189+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[0.3, 0.2]], 4, [0], True)
190+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[4.1, 1.1]], 4, [1], True)
191+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[2.1, 1.9]], 4, [2], True)
192+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[2.1, 4.1]], 4, [3], True)
193+
194+
def testPredictTwoPoints(self):
195+
centers = [[0.2, 0.1], [4.0, 1.0], [2.0, 2.0], [2.3, 3.9]]
196+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[0.3, 0.2], [2.1, 1.9]], 4, [0, 2], True)
197+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[2.1, 4.1], [2.1, 1.9]], 4, [3, 2], True)
187198

188199
@remove_library
189200
def testProcessingWhenLibraryCoreCorrupted(self):
190201
XmeansTestTemplates.templateLengthProcessData(SIMPLE_SAMPLES.SAMPLE_SIMPLE1, [[3.7, 5.5], [6.7, 7.5]], [5, 5], splitting_type.BAYESIAN_INFORMATION_CRITERION, 20, True);
191202

192203

193204
if __name__ == "__main__":
194-
unittest.main();
205+
unittest.main()

pyclustering/cluster/tests/unit/ut_xmeans.py

+12
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,18 @@ def testKmax05Amount20Offset02Initial05(self):
184184
def testKmax05Amount01Offset01Initial04(self):
185185
XmeansTestTemplates.templateMaxAllocatedClusters(False, 1, 1000, 1, 4, 5)
186186

187+
def testPredictOnePoint(self):
188+
centers = [[0.2, 0.1], [4.0, 1.0], [2.0, 2.0], [2.3, 3.9]]
189+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[0.3, 0.2]], 4, [0], False)
190+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[4.1, 1.1]], 4, [1], False)
191+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[2.1, 1.9]], 4, [2], False)
192+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[2.1, 4.1]], 4, [3], False)
193+
194+
def testPredictTwoPoints(self):
195+
centers = [[0.2, 0.1], [4.0, 1.0], [2.0, 2.0], [2.3, 3.9]]
196+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[0.3, 0.2], [2.1, 1.9]], 4, [0, 2], False)
197+
XmeansTestTemplates.templatePredict(SIMPLE_SAMPLES.SAMPLE_SIMPLE3, centers, [[2.1, 4.1], [2.1, 1.9]], 4, [3, 2], False)
198+
187199

188200
if __name__ == "__main__":
189201
unittest.main()

pyclustering/cluster/tests/xmeans_templates.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
2424
"""
2525

26+
import numpy
2627
import random
2728

2829
from pyclustering.cluster.xmeans import xmeans, splitting_type
2930
from pyclustering.cluster.center_initializer import random_center_initializer
3031

31-
from pyclustering.utils import read_sample
32+
from pyclustering.utils import read_sample, distance_metric, type_metric
3233

3334
from pyclustering.tests.assertion import assertion
3435

@@ -63,6 +64,21 @@ def templateLengthProcessData(input_sample, start_centers, expected_cluster_leng
6364
assert obtained_cluster_sizes == expected_cluster_length;
6465

6566

67+
@staticmethod
68+
def templatePredict(path_to_file, initial_centers, points, expected_amount, expected_closest_clusters, ccore, **kwargs):
69+
sample = read_sample(path_to_file)
70+
71+
kmax = kwargs.get('kmax', 20)
72+
73+
xmeans_instance = xmeans(sample, initial_centers, kmax, 0.025, splitting_type.BAYESIAN_INFORMATION_CRITERION, ccore)
74+
xmeans_instance.process()
75+
76+
closest_clusters = xmeans_instance.predict(points)
77+
assertion.eq(expected_amount, len(xmeans_instance.get_clusters()))
78+
assertion.eq(len(expected_closest_clusters), len(closest_clusters))
79+
assertion.true(numpy.array_equal(numpy.array(expected_closest_clusters), closest_clusters))
80+
81+
6682
@staticmethod
6783
def templateClusterAllocationOneDimensionData(ccore_flag):
6884
input_data = [ [0.0] for _ in range(10) ] + [ [5.0] for _ in range(10) ] + [ [10.0] for _ in range(10) ] + [ [15.0] for _ in range(10) ]

pyclustering/cluster/xmeans.py

+51-4
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
import pyclustering.core.xmeans_wrapper as wrapper
4242

43-
from pyclustering.utils import euclidean_distance_square, euclidean_distance
43+
from pyclustering.utils import euclidean_distance_square, euclidean_distance, distance_metric, type_metric
4444

4545

4646
class splitting_type(IntEnum):
@@ -123,7 +123,7 @@ class xmeans:
123123
124124
"""
125125

126-
def __init__(self, data, initial_centers = None, kmax = 20, tolerance = 0.025, criterion = splitting_type.BAYESIAN_INFORMATION_CRITERION, ccore = True):
126+
def __init__(self, data, initial_centers=None, kmax=20, tolerance=0.025, criterion=splitting_type.BAYESIAN_INFORMATION_CRITERION, ccore=True):
127127
"""!
128128
@brief Constructor of clustering algorithm X-Means.
129129
@@ -143,7 +143,7 @@ def __init__(self, data, initial_centers = None, kmax = 20, tolerance = 0.025, c
143143
if initial_centers is not None:
144144
self.__centers = initial_centers[:]
145145
else:
146-
self.__centers = [ [random.random() for _ in range(len(data[0])) ] ]
146+
self.__centers = [[random.random() for _ in range(len(data[0]))]]
147147

148148
self.__kmax = kmax
149149
self.__tolerance = tolerance
@@ -165,7 +165,7 @@ def process(self):
165165
166166
"""
167167

168-
if (self.__ccore is True):
168+
if self.__ccore is True:
169169
self.__clusters, self.__centers = wrapper.xmeans(self.__pointer_data, self.__centers, self.__kmax, self.__tolerance, self.__criterion)
170170

171171
else:
@@ -185,6 +185,53 @@ def process(self):
185185
self.__clusters, self.__centers = self.__improve_parameters(self.__centers)
186186

187187

188+
def predict(self, points):
189+
"""!
190+
@brief Calculates the closest cluster to each point.
191+
192+
@param[in] points (array_like): Points for which closest clusters are calculated.
193+
194+
@return (list) List of closest clusters for each point. Each cluster is denoted by index. Return empty
195+
collection if 'process()' method was not called.
196+
197+
An example how to calculate (or predict) the closest cluster to specified points.
198+
@code
199+
from pyclustering.cluster.xmeans import xmeans
200+
from pyclustering.samples.definitions import SIMPLE_SAMPLES
201+
from pyclustering.utils import read_sample
202+
203+
# Load list of points for cluster analysis.
204+
sample = read_sample(SIMPLE_SAMPLES.SAMPLE_SIMPLE3)
205+
206+
# Initial centers for sample 'Simple3'.
207+
initial_centers = [[0.2, 0.1], [4.0, 1.0], [2.0, 2.0], [2.3, 3.9]]
208+
209+
# Create instance of X-Means algorithm with prepared centers.
210+
xmeans_instance = xmeans(sample, initial_centers)
211+
212+
# Run cluster analysis.
213+
xmeans_instance.process()
214+
215+
# Calculate the closest cluster to following two points.
216+
points = [[0.25, 0.2], [2.5, 4.0]]
217+
closest_clusters = xmeans_instance.predict(points)
218+
print(closest_clusters)
219+
@endcode
220+
221+
"""
222+
nppoints = numpy.array(points)
223+
if len(self.__clusters) == 0:
224+
return []
225+
226+
metric = distance_metric(type_metric.EUCLIDEAN_SQUARE, numpy_usage=True)
227+
228+
differences = numpy.zeros((len(nppoints), len(self.__centers)))
229+
for index_point in range(len(nppoints)):
230+
differences[index_point] = metric(nppoints[index_point], self.__centers)
231+
232+
return numpy.argmin(differences, axis=1)
233+
234+
188235
def get_clusters(self):
189236
"""!
190237
@brief Returns list of allocated clusters, each cluster contains indexes of objects in list of data.

0 commit comments

Comments
 (0)