-
Notifications
You must be signed in to change notification settings - Fork 0
/
KMeans.java
96 lines (87 loc) · 2.86 KB
/
KMeans.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Queue;
import java.util.Set;
/**
 * Groups {@code DocumentObject}s into a fixed number of clusters with a k-means style
 * assignment loop over binary bag-of-words vectors.
 *
 * <p>Every document is mapped to a 0/1 vector with one dimension per word of the vocabulary
 * passed to the constructor (1 when the document contains the word). Documents are repeatedly
 * assigned to the cluster whose center is nearest (Manhattan distance) until a full pass changes
 * no assignment. A document's current assignment is stored as its classifier string: the
 * cluster index written in decimal.
 */
public class KMeans {
    /** Safety cap on the number of assignment passes performed by {@link #execute()}. */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;

    private final Cluster[] clusters;
    /** Vocabulary: {@code reference[i]} is the word represented by dimension {@code i}. */
    private final String[] reference;
    private final Collection<DocumentObject> docs;

    /**
     * Creates the clusterer.
     *
     * @param numberOfClusters number of clusters to create; each is constructed with the vocabulary size
     * @param wordsSet vocabulary; its iteration order fixes the dimension order of every point
     * @param docs documents to cluster
     */
    public KMeans(int numberOfClusters, Set<String> wordsSet, Collection<DocumentObject> docs) {
        clusters = new Cluster[numberOfClusters];
        for (int i = 0; i < numberOfClusters; i++) {
            clusters[i] = new Cluster(wordsSet.size());
        }
        this.reference = new String[wordsSet.size()];
        this.getWordsReference(wordsSet);
        this.docs = docs;
    }

    /**
     * Runs the assignment loop until convergence, capped at {@value #DEFAULT_MAX_ITERATIONS} passes.
     */
    public void execute() {
        execute(DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Repeatedly assigns every document to its nearest cluster until a full pass changes no
     * assignment, or until {@code maxIterations} passes have been made.
     *
     * <p>When a document moves, it is removed from its previous cluster (if its classifier
     * names a valid cluster index), added to the new cluster, and its classifier is updated.
     * NOTE(review): centers are not recomputed here; presumably {@code Cluster.add}/{@code remove}
     * update the center — confirm against the Cluster implementation.
     *
     * @param maxIterations maximum number of passes over the documents; must be positive
     * @throws IllegalArgumentException if {@code maxIterations} is not positive
     */
    public void execute(int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be positive: " + maxIterations);
        }
        if (clusters.length == 0) {
            return; // nothing to assign documents to
        }
        boolean converged = false;
        for (int pass = 0; pass < maxIterations && !converged; pass++) {
            System.out.println("NOT CONVIRGED: Classifying points and recalculating centers...");
            converged = true;
            // Iterate the collection afresh on every pass; a single shared iterator would be
            // exhausted after the first pass and make the loop stop prematurely.
            for (DocumentObject doc : docs) {
                ArrayList<Integer> point = getPoint(doc.getWords());
                int chosenCluster = nearestCluster(point);
                if (chosenCluster < 0) {
                    continue; // no center produced a comparable distance (e.g. NaN)
                }
                String newLabel = String.valueOf(chosenCluster);
                String oldLabel = doc.getClassifier();
                if (!newLabel.equals(oldLabel)) {
                    int previous = parseClusterIndex(oldLabel);
                    if (previous >= 0) {
                        clusters[previous].remove(doc.getID());
                    }
                    clusters[chosenCluster].add(point, doc.getID());
                    doc.changeClassifier(newLabel);
                    converged = false;
                }
            }
        }
    }

    /**
     * Finds the cluster whose center is closest to {@code point} (Manhattan distance).
     *
     * @return the cluster index, or -1 if no distance was smaller than {@link Double#MAX_VALUE}
     */
    private int nearestCluster(ArrayList<Integer> point) {
        int best = -1;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < clusters.length; i++) {
            double dist = calculateDist(point, clusters[i].getCenter(), false);
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /**
     * Interprets a classifier string as a cluster index.
     *
     * @return the index, or -1 if the label is null, not an integer, or out of range
     */
    private int parseClusterIndex(String label) {
        if (label == null) {
            return -1;
        }
        try {
            int index = Integer.parseInt(label);
            return (index >= 0 && index < clusters.length) ? index : -1;
        } catch (NumberFormatException notAnIndex) {
            return -1; // label is not a cluster index (e.g. an initial non-numeric classifier)
        }
    }

    /** Copies the vocabulary into {@code reference} in the set's iteration order. */
    private void getWordsReference(Set<String> words) {
        int i = 0;
        for (String word : words) {
            this.reference[i++] = word;
        }
    }

    /**
     * Builds the binary point for a document: dimension {@code i} is 1 if the document
     * contains {@code reference[i]}, else 0.
     */
    private ArrayList<Integer> getPoint(Queue<String> documentWords) {
        // Hash the document's words once so each vocabulary lookup is O(1) instead of a queue scan.
        Set<String> present = new HashSet<>(documentWords);
        ArrayList<Integer> point = new ArrayList<>(this.reference.length);
        for (String word : this.reference) {
            point.add(present.contains(word) ? 1 : 0);
        }
        return point;
    }

    /**
     * Distance between a binary point and a center.
     *
     * @param euclidian true for Euclidean distance, false for Manhattan (L1) distance
     */
    private static double calculateDist(ArrayList<Integer> point, ArrayList<Double> point2, boolean euclidian) {
        // Accumulate in a double: centers hold fractional values, and an int accumulator
        // would silently truncate each term via compound-assignment narrowing.
        double sum = 0.0;
        for (int i = 0; i < point.size(); i++) {
            double difference = point.get(i) - point2.get(i);
            sum += euclidian ? difference * difference : Math.abs(difference);
        }
        return euclidian ? Math.sqrt(sum) : sum;
    }

    /** Returns the internal cluster array (not a copy). */
    public Cluster[] getClusters() {
        return this.clusters;
    }
}