Skip to content

Commit

Permalink
Support user input instance
Browse files Browse the repository at this point in the history
  • Loading branch information
gy910210 committed Feb 22, 2016
1 parent cca1d81 commit 1610041
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.apache.samoa.learners;

/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2016 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/

import org.apache.samoa.instances.Instance;

import java.io.Serializable;

public interface InputInstance extends Serializable {
Instance getInstance();
}
4 changes: 1 addition & 3 deletions samoa-api/src/main/java/org/apache/samoa/learners/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@
*/


import org.apache.samoa.instances.Instance;

import java.io.Serializable;

public interface Model extends Serializable {
double[] predict(Instance inst);
double[] predict(InputInstance inputInstance);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.apache.samoa.learners.classifiers;

/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2016 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/

import org.apache.samoa.instances.*;
import org.apache.samoa.learners.InputInstance;
import org.apache.samoa.moa.core.FastVector;

public class NominalInputInstance implements InputInstance {

private int numNominals;
private int numClasses;
private double trueClass; // index started from 0
private int[] numValsPerNominal;
private double[] data; // index started from 0

public NominalInputInstance(int numNominals, int numClasses, double trueClass,
int[] numValsPerNominal, double[] data) {
this.numNominals = numNominals;
this.numClasses = numClasses;
this.trueClass = trueClass;
this.numValsPerNominal = numValsPerNominal;
this.data = data;
}

private InstancesHeader generateHeader() {
FastVector<Attribute> attributes = new FastVector<>();

for (int i = 0; i < this.numNominals; i++) {
FastVector<String> nominalAttVals = new FastVector<>();
for (int j = 0; j < this.numValsPerNominal[i]; j++) {
nominalAttVals.addElement("value" + (j + 1));
}
attributes.addElement(new Attribute("nominal" + (i + 1),
nominalAttVals));
}

FastVector<String> classLabels = new FastVector<>();
for (int i = 0; i < this.numClasses; i++) {
classLabels.addElement("class" + (i + 1));
}
attributes.addElement(new Attribute("class", classLabels));

InstancesHeader instancesHeader = new InstancesHeader(
new Instances(null, attributes, 0));
instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1);

return instancesHeader;
}

@Override
public Instance getInstance() {
InstancesHeader header = this.generateHeader();
Instance inst = new DenseInstance(header.numAttributes());

for (int i = 0; i < this.numNominals; i++) {
inst.setValue(i, this.data[i]);
}

inst.setDataset(header);
inst.setClassValue(this.trueClass);

return inst;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.apache.samoa.learners.classifiers;

/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2016 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/

import org.apache.samoa.instances.*;
import org.apache.samoa.learners.InputInstance;
import org.apache.samoa.moa.core.FastVector;

public class NumericInputInstance implements InputInstance {
private int numNumerics;
private int numClasses;
private double trueClass; // index started from 0
private double[] data;

public NumericInputInstance(int numNumerics, int numClasses, double trueClass, double[] data) {
this.numNumerics = numNumerics;
this.numClasses = numClasses;
this.trueClass = trueClass;
this.data = data;
}

private InstancesHeader generateHeader() {
FastVector<Attribute> attributes = new FastVector<>();

for (int i = 0; i < this.numNumerics; i++) {
attributes.addElement(new Attribute("numeric" + (i + 1)));
}

FastVector<String> classLabels = new FastVector<>();
for (int i = 0; i < this.numClasses; i++) {
classLabels.addElement("class" + (i + 1));
}
attributes.addElement(new Attribute("class", classLabels));

InstancesHeader instancesHeader = new InstancesHeader(
new Instances(null, attributes, 0));
instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1);

return instancesHeader;
}

@Override
public Instance getInstance() {
InstancesHeader header = this.generateHeader();
Instance inst = new DenseInstance(header.numAttributes());

for (int i = 0; i < this.numNumerics; i++) {
inst.setValue(i, this.data[i]);
}

inst.setDataset(header);
inst.setClassValue(this.trueClass);

return inst;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Utils;
import org.apache.samoa.learners.InputInstance;
import org.apache.samoa.learners.Model;
import org.apache.samoa.moa.core.DoubleVector;

Expand Down Expand Up @@ -42,10 +43,10 @@ public EnsembleModel(ArrayList<Model> modelList, ArrayList<Double> modelWeightLi
}

@Override
public double[] predict(Instance inst) {
public double[] predict(InputInstance inputInstance) {
DoubleVector combinedVote = new DoubleVector();
for (int i = 0; i < modelList.size(); i++) {
double[] prediction = modelList.get(i).predict(inst);
double[] prediction = modelList.get(i).predict(inputInstance);
DoubleVector vote = new DoubleVector(prediction);
if (vote.sumOfValues() > 0.0) {
vote.normalize();
Expand All @@ -56,9 +57,10 @@ public double[] predict(Instance inst) {
return combinedVote.getArrayCopy();
}

public boolean evaluate(Instance inst) {
public boolean evaluate(InputInstance inputInstance) {
Instance inst = inputInstance.getInstance();
int trueClass = (int) inst.classValue();
double[] prediction = this.predict(inst);
double[] prediction = this.predict(inputInstance);
int predictedClass = Utils.maxIndex(prediction);
return trueClass == predictedClass;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
*/

import org.apache.samoa.instances.Instance;
import org.apache.samoa.learners.InputInstance;
import org.apache.samoa.learners.Model;
import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;
import org.apache.samoa.learners.classifiers.rules.common.PassiveRule;
Expand All @@ -46,7 +47,8 @@ public AMRulesModel(ActiveRule defaultRule, List<PassiveRule> ruleSet,
}

@Override
public double[] predict(Instance inst) {
public double[] predict(InputInstance inputInstance) {
Instance inst = inputInstance.getInstance();
double[] prediction;
boolean predictionCovered = false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.instances.Utils;
import org.apache.samoa.learners.InputInstance;
import org.apache.samoa.learners.Model;

public class HoeffdingTreeModel implements Model {
Expand All @@ -39,7 +40,8 @@ public HoeffdingTreeModel() {
}

@Override
public double[] predict(Instance inst) {
public double[] predict(InputInstance inputInstance) {
Instance inst = inputInstance.getInstance();
double[] prediction;
// inst.setDataset(dataset);

Expand All @@ -58,9 +60,10 @@ public double[] predict(Instance inst) {
return prediction;
}

public boolean evaluate(Instance inst) {
public boolean evaluate(InputInstance inputInstance) {
Instance inst = inputInstance.getInstance();
int trueClass = (int) inst.classValue();
double[] prediction = this.predict(inst);
double[] prediction = this.predict(inputInstance);
int predictedClass = Utils.maxIndex(prediction);
return trueClass == predictedClass;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
*/


import org.apache.samoa.instances.Instance;
import org.apache.samoa.learners.InputInstance;
import org.apache.samoa.learners.Model;
import org.apache.samoa.moa.cluster.Clustering;
import org.apache.samoa.moa.core.DataPoint;
Expand All @@ -40,8 +40,8 @@ public CluStreamModel() {
}

@Override
public double[] predict(Instance inst) {
DataPoint dataPoint = (DataPoint) inst;
public double[] predict(InputInstance inputInstance) {
DataPoint dataPoint = (DataPoint) inputInstance.getInstance();
double[] distances = new double[clustering.size()];
for (int c = 0; c < clustering.size(); c++) {
double distance = 0.0;
Expand All @@ -55,10 +55,14 @@ public double[] predict(Instance inst) {
return distances;
}

public double evaluate(ArrayList<DataPoint> points, MeasureCollection measure) {
public double evaluate(ArrayList<InputInstance> points, MeasureCollection measure) {
ArrayList<DataPoint> dataPoints = new ArrayList<>();
for (InputInstance inputInstance : points) {
dataPoints.add((DataPoint) inputInstance.getInstance());
}
double score = 0.0;
try {
measure.evaluateClusteringPerformance(clustering, null, points);
measure.evaluateClusteringPerformance(clustering, null, dataPoints);
score = measure.getMean(0);
} catch (Exception e) {
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package org.apache.samoa.learners.clusterers;

/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2016 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/

import org.apache.samoa.instances.*;
import org.apache.samoa.learners.InputInstance;
import org.apache.samoa.moa.core.DataPoint;

import java.util.ArrayList;

public class ClusterInputInstance implements InputInstance {
private int numAtts;
private int timeStamp;
private double[] data;

public ClusterInputInstance(int numAtts, int timeStamp, double[] data) {
this.numAtts = numAtts;
this.timeStamp = timeStamp;
this.data = data;
}

private InstancesHeader generateHeader() {
ArrayList<Attribute> attributes = new ArrayList<>();

for (int i = 0; i < this.numAtts; i++) {
attributes.add(new Attribute("att" + (i + 1)));
}

// attributes.add(new Attribute("class", null));

InstancesHeader instancesHeader = new InstancesHeader(
new Instances(null, attributes, 0));
instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1);

return instancesHeader;
}

@Override
public Instance getInstance() {
Instance inst = new DenseInstance(1.0, this.data);
inst.setDataset(generateHeader());
return new DataPoint(inst, this.timeStamp);
}
}
3 changes: 3 additions & 0 deletions samoa-gearpump/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ gearpump {
"org.apache.samoa.learners.classifiers.trees.HoeffdingTreeModel" = ""
"org.apache.samoa.learners.classifiers.trees.ActiveLearningNode" = ""
"[Lorg.apache.samoa.learners.classifiers.trees.AttributeBatchContentEvent;" = ""
"com.github.javacliparser.IntOption" = ""
"com.github.javacliparser.FloatOption" = ""
"org.apache.samoa.learners.InstanceContent" = ""
"java.util.ArrayList" = ""
"java.util.LinkedList" = ""
"java.util.HashMap" = ""
Expand Down

0 comments on commit 1610041

Please sign in to comment.