From 16100414d793a0757e976c86be2ee3025085d517 Mon Sep 17 00:00:00 2001 From: pangolulu Date: Fri, 19 Feb 2016 14:14:07 +0800 Subject: [PATCH] Support user input instance --- .../apache/samoa/learners/InputInstance.java | 29 +++++++ .../java/org/apache/samoa/learners/Model.java | 4 +- .../classifiers/NominalInputInstance.java | 83 +++++++++++++++++++ .../classifiers/NumericInputInstance.java | 74 +++++++++++++++++ .../classifiers/ensemble/EnsembleModel.java | 10 ++- .../classifiers/rules/AMRulesModel.java | 4 +- .../classifiers/trees/HoeffdingTreeModel.java | 9 +- .../learners/clusterers/CluStreamModel.java | 14 ++-- .../clusterers/ClusterInputInstance.java | 62 ++++++++++++++ .../src/main/resources/reference.conf | 3 + 10 files changed, 276 insertions(+), 16 deletions(-) create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/InputInstance.java create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalInputInstance.java create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericInputInstance.java create mode 100644 samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterInputInstance.java diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/InputInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/InputInstance.java new file mode 100644 index 00000000..c62280e4 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/InputInstance.java @@ -0,0 +1,29 @@ +package org.apache.samoa.learners; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2016 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import org.apache.samoa.instances.Instance;
+
+import java.io.Serializable;
+/** A serializable, user-supplied input that can be turned into an {@link Instance}. */
+public interface InputInstance extends Serializable {
+  Instance getInstance(); // returns this input as an Instance
+}
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/Model.java b/samoa-api/src/main/java/org/apache/samoa/learners/Model.java
index f955c041..0e9dbd4f 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/Model.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/Model.java
@@ -21,10 +21,8 @@
  */
 
 
-import org.apache.samoa.instances.Instance;
-
 import java.io.Serializable;
 
 public interface Model extends Serializable {
-  double[] predict(Instance inst);
+  double[] predict(InputInstance inputInstance); // takes the InputInstance wrapper instead of a raw Instance
 }
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalInputInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalInputInstance.java
new file mode 100644
index 00000000..339fa488
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NominalInputInstance.java
@@ -0,0 +1,83 @@
+package org.apache.samoa.learners.classifiers;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2016 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import org.apache.samoa.instances.*;
+import org.apache.samoa.learners.InputInstance;
+import org.apache.samoa.moa.core.FastVector;
+/** User-supplied input of nominal attributes plus a labelled class; built into an {@link Instance} on demand. */
+public class NominalInputInstance implements InputInstance {
+
+  private final int numNominals; // number of nominal attributes
+  private final int numClasses; // number of class labels
+  private final double trueClass; // zero-based index of the true class
+  private final int[] numValsPerNominal; // number of values declared for each nominal attribute
+  private final double[] data; // one value per nominal attribute; indices start from 0
+  /** Stores the arguments as given; both arrays are kept by reference, not copied. */
+  public NominalInputInstance(int numNominals, int numClasses, double trueClass,
+      int[] numValsPerNominal, double[] data) {
+    this.numNominals = numNominals;
+    this.numClasses = numClasses;
+    this.trueClass = trueClass;
+    this.numValsPerNominal = numValsPerNominal;
+    this.data = data;
+  }
+  /** Builds the header: attributes "nominal1".. (values "value1"..) followed by "class" (labels "class1"..). */
+  private InstancesHeader generateHeader() {
+    FastVector<Attribute> attributes = new FastVector<>();
+
+    for (int i = 0; i < this.numNominals; i++) {
+      FastVector<String> nominalAttVals = new FastVector<>();
+      for (int j = 0; j < this.numValsPerNominal[i]; j++) {
+        nominalAttVals.addElement("value" + (j + 1));
+      }
+      attributes.addElement(new Attribute("nominal" + (i + 1),
+          nominalAttVals));
+    }
+
+    FastVector<String> classLabels = new FastVector<>();
+    for (int i = 0; i < this.numClasses; i++) {
+      classLabels.addElement("class" + (i + 1));
+    }
+    attributes.addElement(new Attribute("class", classLabels));
+
+    InstancesHeader instancesHeader = new InstancesHeader(
+        new Instances(null, attributes, 0));
+    instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1); // class attribute is the last one
+
+    return instancesHeader;
+  }
+  /** Builds a fresh header and dense instance on every call, then sets the class value to {@code trueClass}. */
+  @Override
+  public Instance getInstance() {
+    InstancesHeader header = this.generateHeader();
+    Instance inst = new DenseInstance(header.numAttributes());
+
+    for (int i = 0; i < this.numNominals; i++) {
+      inst.setValue(i, this.data[i]); // NOTE(review): data.length >= numNominals is never checked
+    }
+
+    inst.setDataset(header);
+    inst.setClassValue(this.trueClass);
+
+    return inst;
+  }
+}
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericInputInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericInputInstance.java
new file mode 100644
index 00000000..0b2cd405
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NumericInputInstance.java
@@ -0,0 +1,74 @@
+package org.apache.samoa.learners.classifiers;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2016 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import org.apache.samoa.instances.*;
+import org.apache.samoa.learners.InputInstance;
+import org.apache.samoa.moa.core.FastVector;
+/** User-supplied input of attributes "numeric1".. plus a labelled class; built into an {@link Instance} on demand. */
+public class NumericInputInstance implements InputInstance {
+  private final int numNumerics; // number of non-class attributes
+  private final int numClasses; // number of class labels
+  private final double trueClass; // zero-based index of the true class
+  private final double[] data; // one value per non-class attribute, kept by reference
+  /** Stores the arguments as given; {@code data} is not copied. */
+  public NumericInputInstance(int numNumerics, int numClasses, double trueClass, double[] data) {
+    this.numNumerics = numNumerics;
+    this.numClasses = numClasses;
+    this.trueClass = trueClass;
+    this.data = data;
+  }
+  /** Builds the header: attributes "numeric1".. followed by "class" (labels "class1".."classN"). */
+  private InstancesHeader generateHeader() {
+    FastVector<Attribute> attributes = new FastVector<>();
+
+    for (int i = 0; i < this.numNumerics; i++) {
+      attributes.addElement(new Attribute("numeric" + (i + 1)));
+    }
+
+    FastVector<String> classLabels = new FastVector<>();
+    for (int i = 0; i < this.numClasses; i++) {
+      classLabels.addElement("class" + (i + 1));
+    }
+    attributes.addElement(new Attribute("class", classLabels));
+
+    InstancesHeader instancesHeader = new InstancesHeader(
+        new Instances(null, attributes, 0));
+    instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1); // class attribute is the last one
+
+    return instancesHeader;
+  }
+  /** Builds a fresh header and dense instance on every call, then sets the class value to {@code trueClass}. */
+  @Override
+  public Instance getInstance() {
+    InstancesHeader header = this.generateHeader();
+    Instance inst = new DenseInstance(header.numAttributes());
+
+    for (int i = 0; i < this.numNumerics; i++) {
+      inst.setValue(i, this.data[i]); // NOTE(review): data.length >= numNumerics is never checked
+    }
+
+    inst.setDataset(header);
+    inst.setClassValue(this.trueClass);
+
+    return inst;
+  }
+}
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java
index 00dd126a..475b3086 100644
--- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/EnsembleModel.java
@@ -2,6 +2,7 @@
 
 import org.apache.samoa.instances.Instance;
 import
org.apache.samoa.instances.Utils; +import org.apache.samoa.learners.InputInstance; import org.apache.samoa.learners.Model; import org.apache.samoa.moa.core.DoubleVector; @@ -42,10 +43,10 @@ public EnsembleModel(ArrayList modelList, ArrayList modelWeightLi } @Override - public double[] predict(Instance inst) { + public double[] predict(InputInstance inputInstance) { DoubleVector combinedVote = new DoubleVector(); for (int i = 0; i < modelList.size(); i++) { - double[] prediction = modelList.get(i).predict(inst); + double[] prediction = modelList.get(i).predict(inputInstance); DoubleVector vote = new DoubleVector(prediction); if (vote.sumOfValues() > 0.0) { vote.normalize(); @@ -56,9 +57,10 @@ public double[] predict(Instance inst) { return combinedVote.getArrayCopy(); } - public boolean evaluate(Instance inst) { + public boolean evaluate(InputInstance inputInstance) { + Instance inst = inputInstance.getInstance(); int trueClass = (int) inst.classValue(); - double[] prediction = this.predict(inst); + double[] prediction = this.predict(inputInstance); int predictedClass = Utils.maxIndex(prediction); return trueClass == predictedClass; } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java index 5565ea7d..6b92407d 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesModel.java @@ -21,6 +21,7 @@ */ import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.InputInstance; import org.apache.samoa.learners.Model; import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; import org.apache.samoa.learners.classifiers.rules.common.PassiveRule; @@ -46,7 +47,8 @@ public AMRulesModel(ActiveRule defaultRule, List ruleSet, } @Override - public double[] predict(Instance inst) { + public double[] 
predict(InputInstance inputInstance) { + Instance inst = inputInstance.getInstance(); double[] prediction; boolean predictionCovered = false; diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java index 4b856795..7999a2fe 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/HoeffdingTreeModel.java @@ -24,6 +24,7 @@ import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Instances; import org.apache.samoa.instances.Utils; +import org.apache.samoa.learners.InputInstance; import org.apache.samoa.learners.Model; public class HoeffdingTreeModel implements Model { @@ -39,7 +40,8 @@ public HoeffdingTreeModel() { } @Override - public double[] predict(Instance inst) { + public double[] predict(InputInstance inputInstance) { + Instance inst = inputInstance.getInstance(); double[] prediction; // inst.setDataset(dataset); @@ -58,9 +60,10 @@ public double[] predict(Instance inst) { return prediction; } - public boolean evaluate(Instance inst) { + public boolean evaluate(InputInstance inputInstance) { + Instance inst = inputInstance.getInstance(); int trueClass = (int) inst.classValue(); - double[] prediction = this.predict(inst); + double[] prediction = this.predict(inputInstance); int predictedClass = Utils.maxIndex(prediction); return trueClass == predictedClass; } diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java index dd153cd6..d24ca26e 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/CluStreamModel.java @@ -21,7 +21,7 @@ */ -import 
org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.InputInstance; import org.apache.samoa.learners.Model; import org.apache.samoa.moa.cluster.Clustering; import org.apache.samoa.moa.core.DataPoint; @@ -40,8 +40,8 @@ public CluStreamModel() { } @Override - public double[] predict(Instance inst) { - DataPoint dataPoint = (DataPoint) inst; + public double[] predict(InputInstance inputInstance) { + DataPoint dataPoint = (DataPoint) inputInstance.getInstance(); double[] distances = new double[clustering.size()]; for (int c = 0; c < clustering.size(); c++) { double distance = 0.0; @@ -55,10 +55,14 @@ public double[] predict(Instance inst) { return distances; } - public double evaluate(ArrayList points, MeasureCollection measure) { + public double evaluate(ArrayList points, MeasureCollection measure) { + ArrayList dataPoints = new ArrayList<>(); + for (InputInstance inputInstance : points) { + dataPoints.add((DataPoint) inputInstance.getInstance()); + } double score = 0.0; try { - measure.evaluateClusteringPerformance(clustering, null, points); + measure.evaluateClusteringPerformance(clustering, null, dataPoints); score = measure.getMean(0); } catch (Exception e) { e.printStackTrace(); diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterInputInstance.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterInputInstance.java new file mode 100644 index 00000000..813097ea --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClusterInputInstance.java @@ -0,0 +1,62 @@ +package org.apache.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2016 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import org.apache.samoa.instances.*;
+import org.apache.samoa.learners.InputInstance;
+import org.apache.samoa.moa.core.DataPoint;
+
+import java.util.ArrayList;
+/** User-supplied clustering input; {@link #getInstance()} returns it as a time-stamped {@link DataPoint}. */
+public class ClusterInputInstance implements InputInstance {
+  private final int numAtts; // number of attributes declared in the generated header
+  private final int timeStamp; // passed to the DataPoint constructor
+  private final double[] data; // attribute values, kept by reference
+  /** Stores the arguments as given; {@code data} is not copied. */
+  public ClusterInputInstance(int numAtts, int timeStamp, double[] data) {
+    this.numAtts = numAtts;
+    this.timeStamp = timeStamp;
+    this.data = data;
+  }
+  /** Builds a header of {@code numAtts} attributes named "att1".."attN". */
+  private InstancesHeader generateHeader() {
+    ArrayList<Attribute> attributes = new ArrayList<>();
+
+    for (int i = 0; i < this.numAtts; i++) {
+      attributes.add(new Attribute("att" + (i + 1)));
+    }
+
+    // NOTE(review): no "class" attribute is appended (that call was disabled), yet the last attribute is still marked as the class below; confirm intent.
+
+    InstancesHeader instancesHeader = new InstancesHeader(
+        new Instances(null, attributes, 0));
+    instancesHeader.setClassIndex(instancesHeader.numAttributes() - 1);
+
+    return instancesHeader;
+  }
+  /** Wraps {@code data} in a DenseInstance, attaches a fresh header and returns it as a DataPoint. */
+  @Override
+  public Instance getInstance() {
+    Instance inst = new DenseInstance(1.0, this.data); // NOTE(review): data.length is not checked against numAtts
+    inst.setDataset(generateHeader());
+    return new DataPoint(inst, this.timeStamp);
+  }
+}
diff --git a/samoa-gearpump/src/main/resources/reference.conf b/samoa-gearpump/src/main/resources/reference.conf
index a6e4d8e5..fbb3f3b3 100644
--- a/samoa-gearpump/src/main/resources/reference.conf
+++ b/samoa-gearpump/src/main/resources/reference.conf
@@ -57,6 +57,9 @@ gearpump {
       "org.apache.samoa.learners.classifiers.trees.HoeffdingTreeModel" = ""
       "org.apache.samoa.learners.classifiers.trees.ActiveLearningNode" = ""
"[Lorg.apache.samoa.learners.classifiers.trees.AttributeBatchContentEvent;" = "" + "com.github.javacliparser.IntOption" = "" + "com.github.javacliparser.FloatOption" = "" + "org.apache.samoa.learners.InstanceContent" = "" "java.util.ArrayList" = "" "java.util.LinkedList" = "" "java.util.HashMap" = ""