Skip to content

Commit

Permalink
Support data instance class for user usage, and add test cases for model
Browse files Browse the repository at this point in the history
serialization
  • Loading branch information
gy910210 committed Feb 26, 2016
1 parent cca1d81 commit afbcc6f
Show file tree
Hide file tree
Showing 15 changed files with 872 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.apache.samoa.learners;

/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2016 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/

import java.io.Serializable;

/**
 * Marker interface for user-facing data instances.
 *
 * <p>It declares no methods of its own; it only requires implementations to be
 * {@link Serializable}.
 */
public interface DataInstance extends Serializable {

}
139 changes: 139 additions & 0 deletions samoa-api/src/main/java/org/apache/samoa/learners/InstanceUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package org.apache.samoa.learners;

/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2016 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/

import org.apache.samoa.instances.*;
import org.apache.samoa.learners.classifiers.NominalDataInstance;
import org.apache.samoa.learners.classifiers.NumericDataInstance;
import org.apache.samoa.learners.clusterers.ClusterDataInstance;
import org.apache.samoa.moa.core.DataPoint;
import org.apache.samoa.moa.core.FastVector;

import java.util.ArrayList;

/**
 * Static helpers that turn user-facing {@link DataInstance} objects into SAMOA
 * {@link Instance} objects.
 *
 * <p>Every conversion synthesizes a fresh {@link InstancesHeader} from the counts carried by the
 * data instance and attaches it to the resulting instance.
 */
public class InstanceUtils {

  /**
   * Creates the attribute named {@code "class"} whose labels are {@code "class1"} up to
   * {@code "class" + numClasses}.
   *
   * @param numClasses number of class labels to generate
   * @return the generated class attribute
   */
  private static Attribute buildClassAttribute(int numClasses) {
    FastVector<String> classLabels = new FastVector<>();
    for (int i = 0; i < numClasses; i++) {
      classLabels.addElement("class" + (i + 1));
    }
    return new Attribute("class", classLabels);
  }

  /**
   * Wraps {@code dataset} in an {@link InstancesHeader} and marks its last attribute as the class
   * attribute.
   *
   * @param dataset the attribute definitions to wrap
   * @return the header with its class index set to {@code numAttributes() - 1}
   */
  private static InstancesHeader headerWithLastAttributeAsClass(Instances dataset) {
    InstancesHeader header = new InstancesHeader(dataset);
    header.setClassIndex(header.numAttributes() - 1);
    return header;
  }

  /**
   * Builds a header with attributes {@code "numeric1"}.. followed by the class attribute.
   */
  private static InstancesHeader getNumericInstanceHeader(NumericDataInstance numericDataInstance) {
    FastVector<Attribute> attributes = new FastVector<>();

    int numNumerics = numericDataInstance.getNumNumerics();
    for (int i = 0; i < numNumerics; i++) {
      attributes.addElement(new Attribute("numeric" + (i + 1)));
    }
    attributes.addElement(buildClassAttribute(numericDataInstance.getNumClasses()));

    return headerWithLastAttributeAsClass(new Instances(null, attributes, 0));
  }

  /**
   * Converts a numeric data instance: copies the first {@code getNumNumerics()} entries of its data
   * array into a dense instance, attaches the generated header and sets the class value.
   */
  private static Instance convertNumericInstance(NumericDataInstance numericDataInstance) {
    InstancesHeader header = getNumericInstanceHeader(numericDataInstance);
    Instance inst = new DenseInstance(header.numAttributes());

    double[] data = numericDataInstance.getData();
    int numNumerics = numericDataInstance.getNumNumerics();
    for (int i = 0; i < numNumerics; i++) {
      inst.setValue(i, data[i]);
    }

    // The class value is set only after the header is attached, as in the original ordering.
    inst.setDataset(header);
    inst.setClassValue(numericDataInstance.getTrueClass());

    return inst;
  }

  /**
   * Builds a header with attributes {@code "nominal1"}.., each having values {@code "value1"}..
   * according to {@code getNumValsPerNominal()}, followed by the class attribute.
   */
  private static InstancesHeader getNominalInstanceHeader(NominalDataInstance nominalDataInstance) {
    FastVector<Attribute> attributes = new FastVector<>();

    int numNominals = nominalDataInstance.getNumNominals();
    int[] numValsPerNominal = nominalDataInstance.getNumValsPerNominal();
    for (int i = 0; i < numNominals; i++) {
      FastVector<String> nominalAttVals = new FastVector<>();
      for (int j = 0; j < numValsPerNominal[i]; j++) {
        nominalAttVals.addElement("value" + (j + 1));
      }
      attributes.addElement(new Attribute("nominal" + (i + 1), nominalAttVals));
    }
    attributes.addElement(buildClassAttribute(nominalDataInstance.getNumClasses()));

    return headerWithLastAttributeAsClass(new Instances(null, attributes, 0));
  }

  /**
   * Converts a nominal data instance: copies the first {@code getNumNominals()} entries of its data
   * array into a dense instance, attaches the generated header and sets the class value.
   */
  private static Instance convertNominalInstance(NominalDataInstance nominalDataInstance) {
    InstancesHeader header = getNominalInstanceHeader(nominalDataInstance);
    Instance inst = new DenseInstance(header.numAttributes());

    double[] data = nominalDataInstance.getData();
    int numNominals = nominalDataInstance.getNumNominals();
    for (int i = 0; i < numNominals; i++) {
      inst.setValue(i, data[i]);
    }

    // The class value is set only after the header is attached, as in the original ordering.
    inst.setDataset(header);
    inst.setClassValue(nominalDataInstance.getTrueClass());

    return inst;
  }

  /**
   * Builds a header with attributes {@code "att1"}.. and no dedicated class attribute.
   */
  private static InstancesHeader getClusterInstanceHeader(ClusterDataInstance clusterDataInstance) {
    ArrayList<Attribute> attributes = new ArrayList<>();

    int numAtts = clusterDataInstance.getNumAtts();
    for (int i = 0; i < numAtts; i++) {
      attributes.add(new Attribute("att" + (i + 1)));
    }

    // NOTE(review): no separate "class" attribute is added here, so the class index below ends up
    // pointing at the last "att" attribute (or -1 when there are none) - confirm this is intended.
    return headerWithLastAttributeAsClass(new Instances(null, attributes, 0));
  }

  /**
   * Converts a cluster data instance into a {@link DataPoint} of weight 1.0 carrying the instance's
   * timestamp.
   */
  private static Instance convertClusterInstance(ClusterDataInstance clusterDataInstance) {
    Instance inst = new DenseInstance(1.0, clusterDataInstance.getData());
    inst.setDataset(getClusterInstanceHeader(clusterDataInstance));
    return new DataPoint(inst, clusterDataInstance.getTimeStamp());
  }

  /**
   * Converts a user-facing data instance into a SAMOA {@link Instance}.
   *
   * @param dataInstance a {@link NumericDataInstance}, {@link NominalDataInstance} or
   *        {@link ClusterDataInstance}
   * @return the converted instance with a generated header attached
   * @throws IllegalArgumentException if {@code dataInstance} is {@code null} or of any other type
   */
  public static Instance convertToSamoaInstance(DataInstance dataInstance) {
    if (dataInstance instanceof NumericDataInstance) {
      return convertNumericInstance((NumericDataInstance) dataInstance);
    }
    if (dataInstance instanceof NominalDataInstance) {
      return convertNominalInstance((NominalDataInstance) dataInstance);
    }
    if (dataInstance instanceof ClusterDataInstance) {
      return convertClusterInstance((ClusterDataInstance) dataInstance);
    }

    // A bad argument is a caller error, not a JVM-level failure, so it must not surface as Error.
    String actualType = dataInstance == null ? "null" : dataInstance.getClass().getName();
    throw new IllegalArgumentException("Invalid input class! Got: " + actualType);
  }
}
4 changes: 1 addition & 3 deletions samoa-api/src/main/java/org/apache/samoa/learners/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@
*/


import org.apache.samoa.instances.Instance;

import java.io.Serializable;

/**
 * A serializable model that produces a prediction array for a single input.
 */
public interface Model extends Serializable {
// NOTE(review): this listing looks like a rendered diff in which the Instance overload below was
// replaced by the DataInstance overload - confirm which declarations exist in the actual file.
double[] predict(Instance inst);
/**
 * Computes the prediction for the given data instance.
 *
 * @param dataInstance the input to predict for
 * @return the prediction values; presumably one score per class, since a visible caller takes
 *         the index of the largest entry as the predicted class - verify per implementation
 */
double[] predict(DataInstance dataInstance);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.apache.samoa.learners.classifiers;

/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2016 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/

import org.apache.samoa.learners.DataInstance;

/**
 * A labelled data instance whose attributes are all nominal.
 *
 * <p>Instances are read-only after construction: there are no setters and all fields are final.
 * The arrays passed to the constructor are stored and returned by reference (no copies are made),
 * so callers must not modify them while the instance is in use.
 */
public class NominalDataInstance implements DataInstance {

  private final int numNominals;
  private final int numClasses;
  private final double trueClass; // zero-based class index
  private final int[] numValsPerNominal;
  private final double[] data; // zero-based value index per nominal attribute

  /**
   * Creates a nominal data instance.
   *
   * @param numNominals number of nominal attributes
   * @param numClasses number of class labels
   * @param trueClass zero-based index of the true class
   * @param numValsPerNominal number of distinct values of each nominal attribute; stored by
   *        reference
   * @param data zero-based value index of each nominal attribute; stored by reference
   */
  public NominalDataInstance(int numNominals, int numClasses, double trueClass,
      int[] numValsPerNominal, double[] data) {
    this.numNominals = numNominals;
    this.numClasses = numClasses;
    this.trueClass = trueClass;
    this.numValsPerNominal = numValsPerNominal;
    this.data = data;
  }

  /** Returns the number of nominal attributes. */
  public int getNumNominals() {
    return numNominals;
  }

  /** Returns the number of class labels. */
  public int getNumClasses() {
    return numClasses;
  }

  /** Returns the zero-based index of the true class. */
  public double getTrueClass() {
    return trueClass;
  }

  /** Returns the per-attribute value counts; this is the internal array, not a copy. */
  public int[] getNumValsPerNominal() {
    return numValsPerNominal;
  }

  /** Returns the attribute values; this is the internal array, not a copy. */
  public double[] getData() {
    return data;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.apache.samoa.learners.classifiers;

/*
* #%L
* SAMOA
* %%
* Copyright (C) 2014 - 2016 Apache Software Foundation
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/

import org.apache.samoa.learners.DataInstance;

/**
 * A labelled data instance whose attributes are all numeric.
 *
 * <p>Instances are read-only after construction: there are no setters and all fields are final.
 * The data array passed to the constructor is stored and returned by reference (no copy is made),
 * so callers must not modify it while the instance is in use.
 */
public class NumericDataInstance implements DataInstance {

  private final int numNumerics;
  private final int numClasses;
  private final double trueClass; // zero-based class index
  private final double[] data;

  /**
   * Creates a numeric data instance.
   *
   * @param numNumerics number of numeric attributes
   * @param numClasses number of class labels
   * @param trueClass zero-based index of the true class
   * @param data value of each numeric attribute; stored by reference
   */
  public NumericDataInstance(int numNumerics, int numClasses, double trueClass, double[] data) {
    this.numNumerics = numNumerics;
    this.numClasses = numClasses;
    this.trueClass = trueClass;
    this.data = data;
  }

  /** Returns the number of numeric attributes. */
  public int getNumNumerics() {
    return numNumerics;
  }

  /** Returns the number of class labels. */
  public int getNumClasses() {
    return numClasses;
  }

  /** Returns the zero-based index of the true class. */
  public double getTrueClass() {
    return trueClass;
  }

  /** Returns the attribute values; this is the internal array, not a copy. */
  public double[] getData() {
    return data;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Utils;
import org.apache.samoa.learners.DataInstance;
import org.apache.samoa.learners.InstanceUtils;
import org.apache.samoa.learners.Model;
import org.apache.samoa.moa.core.DoubleVector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;

/*
* #%L
Expand All @@ -33,19 +33,16 @@ public class EnsembleModel implements Model {
private ArrayList<Model> modelList;
private ArrayList<Double> modelWeightList;

public EnsembleModel() {
}

/**
 * Creates an ensemble from the given member models and weight list.
 * Both lists are stored by reference; no copies are made.
 *
 * @param modelList the member models
 * @param modelWeightList the weight list stored alongside the models
 */
public EnsembleModel(ArrayList<Model> modelList, ArrayList<Double> modelWeightList) {
this.modelList = modelList;
this.modelWeightList = modelWeightList;
}

@Override
public double[] predict(Instance inst) {
public double[] predict(DataInstance dataInstance) {
DoubleVector combinedVote = new DoubleVector();
for (int i = 0; i < modelList.size(); i++) {
double[] prediction = modelList.get(i).predict(inst);
double[] prediction = modelList.get(i).predict(dataInstance);
DoubleVector vote = new DoubleVector(prediction);
if (vote.sumOfValues() > 0.0) {
vote.normalize();
Expand All @@ -56,9 +53,13 @@ public double[] predict(Instance inst) {
return combinedVote.getArrayCopy();
}

public boolean evaluate(Instance inst) {
/*
Predict the class of an input data instance, and evaluate if it is the true class.
*/
public boolean evaluate(DataInstance dataInstance) {
Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance);
int trueClass = (int) inst.classValue();
double[] prediction = this.predict(inst);
double[] prediction = this.predict(dataInstance);
int predictedClass = Utils.maxIndex(prediction);
return trueClass == predictedClass;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
*/

import org.apache.samoa.instances.Instance;
import org.apache.samoa.learners.DataInstance;
import org.apache.samoa.learners.InstanceUtils;
import org.apache.samoa.learners.Model;
import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;
import org.apache.samoa.learners.classifiers.rules.common.PassiveRule;
Expand All @@ -34,9 +36,6 @@ public class AMRulesModel implements Model {
private ErrorWeightedVote errorWeightedVote;
private boolean unorderedRules;

public AMRulesModel() {
}

public AMRulesModel(ActiveRule defaultRule, List<PassiveRule> ruleSet,
ErrorWeightedVote errorWeightedVote, boolean unorderedRules) {
this.defaultRule = defaultRule;
Expand All @@ -46,7 +45,8 @@ public AMRulesModel(ActiveRule defaultRule, List<PassiveRule> ruleSet,
}

@Override
public double[] predict(Instance inst) {
public double[] predict(DataInstance dataInstance) {
Instance inst = InstanceUtils.convertToSamoaInstance(dataInstance);
double[] prediction;
boolean predictionCovered = false;

Expand Down
Loading

0 comments on commit afbcc6f

Please sign in to comment.