summary | refs | log | tree | commit | diff | stats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/experiments/RunPredictionServer.java 183
-rw-r--r-- src/log4j.properties 17
-rw-r--r-- src/predictor/PredictorFactory.java 98
-rw-r--r-- src/predictor/PredictorInterface.java 18
4 files changed, 316 insertions, 0 deletions
diff --git a/src/experiments/RunPredictionServer.java b/src/experiments/RunPredictionServer.java
new file mode 100644
index 0000000..6612a86
--- /dev/null
+++ b/src/experiments/RunPredictionServer.java
@@ -0,0 +1,183 @@
+package experiments;
+
+import java.io.*;
+import java.text.DateFormat;
+import java.text.SimpleDateFormat;
+import java.util.*;
+
+import weka.filters.Filter;
+import model.Model;
+import org.apache.log4j.Logger;
+import weka.core.Instances;
+import weka.core.OptionHandler;
+
+/**
+ * Command-line driver for a prediction experiment: reads {@code data/config.txt},
+ * loads the training data it names, applies the configured Weka filters in order,
+ * then benchmarks, cross-validates and saves results for a fixed list of predictors.
+ *
+ * <p>Config file layout: line 1 is the training-data path; every later non-blank
+ * line is {@code <filter-key> [option ...]} with whitespace-separated tokens. The
+ * accepted filter keys are the ones registered by {@code setClassName()}.
+ *
+ * <p>Created by hailiu on 2016/2/22.
+ */
+public class RunPredictionServer {
+    protected Logger logger = Logger.getLogger(RunPredictionServer.class);
+
+    /** Short names handed to {@code Model.addPredictor}, in registration order. */
+    private static final String[] PREDICTOR_SHORT_NAMES = {
+        "ZEROR", "PART", "ONER", "JRIP", "IBK", "NBM", "RF",
+        "LWL", "NBC", "BN", "REPTREE", "DT", "J48", "SMO",
+        "MP", "SL", "LOG", "SGD", "VP", "SVM", "KSTAR",
+    };
+
+    final String configFile = "data/config.txt";//input file path
+    String[] option = new String[50];
+    String filter;
+    /** One entry per configured filter: element 0 is the filter key, the rest are its options. */
+    ArrayList<ArrayList<String>> filterList;
+    /** Training-data path taken from the first line of the config file (null if the file is empty/unreadable). */
+    String DataFile;
+    protected static DateFormat dateFormat = new SimpleDateFormat("yyyyMMdd_HHmmss");
+    protected static Date date = new Date();
+    protected static String resultFilename = dateFormat.format(date) + "_results";
+    /** Lower-case filter key -> fully qualified Weka filter class name. */
+    Map<String, String> map;
+
+    public RunPredictionServer() {
+        this.filterList = new ArrayList<>();
+        setClassName();
+    }
+
+    /**
+     * Runs the whole experiment once. Command-line arguments are ignored; all
+     * input comes from the config file.
+     *
+     * @param args unused
+     */
+    public static void main(String[] args) {
+        RunPredictionServer runServer = new RunPredictionServer();
+        runServer.loadConfigFile();
+
+        Model tempModel = new Model();
+        tempModel.loadTrainingData(runServer.DataFile);
+        for (String shortName : PREDICTOR_SHORT_NAMES) {
+            tempModel.addPredictor(shortName);
+        }
+
+        // Chain the filters: each one consumes the output of the previous one.
+        Instances tempInstances = tempModel.getTrainingInstances();
+        for (ArrayList<String> filterOption : runServer.filterList) {
+            if (filterOption.isEmpty()) {
+                continue; // nothing to apply; also avoids a negative array size below
+            }
+            String tFilter = filterOption.get(0);
+            String[] option = filterOption.subList(1, filterOption.size()).toArray(new String[0]);
+            tempInstances = runServer.addFilter(tempInstances, tFilter, option);
+        }
+
+        tempModel.setPreprocessedInstances(tempInstances);
+        tempModel.savePreprocessedInstances("preprocessed.arff");
+
+        try {
+            tempModel.benchmark(2, resultFilename);
+        } catch (Exception e) {
+            // Best effort: a failed benchmark must not prevent cross-validation.
+            runServer.logger.warn("Benchmark failed: " + e, e);
+        }
+
+        tempModel.crossValidatePredictors(10);
+        try {
+            tempModel.saveResults(resultFilename);
+        } catch (Exception e) {
+            runServer.logger.warn("Saving results to " + resultFilename + " failed: " + e, e);
+        }
+    }
+
+    /**
+     * Reads {@link #configFile}: the first line becomes {@link #DataFile}, every
+     * following non-blank line is tokenised on whitespace and appended to
+     * {@link #filterList}. I/O failures are logged and leave the fields as they are.
+     */
+    private void loadConfigFile() {
+        logger.debug("Reading configure from " + configFile);
+        // try-with-resources so the reader is closed even when a read fails.
+        try (BufferedReader in = new BufferedReader(new FileReader(configFile))) {
+            DataFile = in.readLine();
+            System.out.println(DataFile);
+
+            String line;
+            while ((line = in.readLine()) != null) {
+                String trimmed = line.trim();
+                if (trimmed.isEmpty()) {
+                    // Blank lines used to turn into a filter with an empty (or missing) name.
+                    continue;
+                }
+                // "\\s+" keeps repeated separators from producing empty option tokens.
+                ArrayList<String> filterOption = new ArrayList<>(Arrays.asList(trimmed.split("\\s+")));
+                this.filterList.add(filterOption);
+            }
+        } catch (IOException e) {
+            logger.warn("Could not read config file " + configFile + ": " + e, e);
+        }
+    }
+
+    /**
+     * Instantiates the Weka filter registered under {@code filter}, configures it
+     * with {@code option} and applies it to {@code tempInstances}.
+     *
+     * @param tempInstances data to filter; {@code null} (an earlier filter failed) yields {@code null}
+     * @param filter        filter key as written in the config file; matched case-insensitively
+     * @param option        Weka option tokens for the filter (may be empty)
+     * @return the filtered data, or {@code null} if the key is unknown or any step fails
+     *         (the cause is logged at WARN level)
+     */
+    private Instances addFilter(Instances tempInstances, String filter, String[] option) {
+        if (tempInstances == null) {
+            logger.warn("No instances to filter; skipping filter '" + filter + "'");
+            return null;
+        }
+
+        // Keys are registered in lower case, so normalise before the lookup.
+        String className = (filter == null) ? null : this.map.get(filter.toLowerCase(Locale.ROOT));
+        if (className == null) {
+            logger.warn("Unknown filter '" + filter + "'; known filters: " + this.map.keySet());
+            return null;
+        }
+
+        Filter tempFilter;
+        try {
+            tempFilter = Class.forName(className).asSubclass(Filter.class)
+                    .getDeclaredConstructor().newInstance();
+        } catch (ReflectiveOperationException | ClassCastException e) {
+            logger.warn("Could not instantiate filter " + className + ": " + e, e);
+            return null;
+        }
+
+        try {
+            // Options must be set before the input format is fixed.
+            ((OptionHandler) tempFilter).setOptions(option);
+            tempFilter.setInputFormat(tempInstances);
+            return Filter.useFilter(tempInstances, tempFilter);
+        } catch (Exception e) {
+            logger.warn("Applying filter " + className + " failed: " + e, e);
+            return null;
+        }
+    }
+
+    /** Registers the supported filter keys and the Weka classes they stand for. */
+    private void setClassName() {
+        map = new HashMap<>();
+        map.put("addexpression", "weka.filters.unsupervised.attribute.AddExpression");
+        map.put("remove", "weka.filters.unsupervised.attribute.Remove");
+        map.put("classassigner", "weka.filters.unsupervised.attribute.ClassAssigner");
+        map.put("numerictonominal", "weka.filters.unsupervised.attribute.NumericToNominal");
+        map.put("mathexpression", "weka.filters.unsupervised.attribute.MathExpression");
+    }
+}
+
diff --git a/src/log4j.properties b/src/log4j.properties
new file mode 100644
index 0000000..797baca
--- /dev/null
+++ b/src/log4j.properties
@@ -0,0 +1,17 @@
+# Two appenders are defined: one writing to the console and one writing to a
+# size-capped rolling log file.
+log4j.appender.consoleAppender = org.apache.log4j.ConsoleAppender
+log4j.appender.rollingFile = org.apache.log4j.RollingFileAppender
+# Relative path, so the log file is created in the JVM's working directory.
+log4j.appender.rollingFile.File = T1.log
+# Roll over when the file reaches 10MB and keep at most 5 rolled-over backups.
+log4j.appender.rollingFile.MaxFileSize = 10MB
+log4j.appender.rollingFile.MaxBackupIndex = 5
+
+# Both appenders share one line format:
+#   %d     timestamp
+#   %-6r   elapsed milliseconds, left-justified to 6 columns
+#   [%t]   thread name
+#   %-5p   level, left-justified to 5 columns
+#   %-90l  caller location, left-justified to 90 columns
+#          (NOTE: log4j documents location lookup as very slow)
+#   %m%n   message followed by the platform line separator
+log4j.appender.consoleAppender.layout = org.apache.log4j.PatternLayout
+log4j.appender.consoleAppender.layout.ConversionPattern = %d %-6r [%t] %-5p %-90l - %m%n
+log4j.appender.rollingFile.layout = org.apache.log4j.PatternLayout
+log4j.appender.rollingFile.layout.ConversionPattern = %d %-6r [%t] %-5p %-90l - %m%n
+
+# Root logger: "<threshold>, <appender>[, <appender>...]". Exactly one line should
+# be active; the commented-out lines are alternative thresholds/appender sets.
+#log4j.rootLogger = TRACE, consoleAppender, rollingFile
+log4j.rootLogger = DEBUG, consoleAppender, rollingFile
+#log4j.rootLogger = INFO, consoleAppender, rollingFile
+#log4j.rootLogger = WARN, consoleAppender, rollingFile
+#log4j.rootLogger = TRACE, rollingFile
+
diff --git a/src/predictor/PredictorFactory.java b/src/predictor/PredictorFactory.java
new file mode 100644
index 0000000..af21b10
--- /dev/null
+++ b/src/predictor/PredictorFactory.java
@@ -0,0 +1,98 @@
+package predictor;
+
+import org.apache.log4j.*;
+
+import weka.classifiers.bayes.*;
+import weka.classifiers.trees.*;
+import weka.classifiers.rules.*;
+import weka.classifiers.functions.*;
+import weka.classifiers.functions.LibSVM;
+import weka.classifiers.lazy.*;
+
+import predictor.PredictorInterface;
+
+/**
+ * Factory that maps a {@link PredictionTechnique} to a ready-to-use
+ * {@link PredictorInterface}, wrapping the matching Weka classifier in a
+ * {@code ClassifierAdapter}.
+ */
+public class PredictorFactory {
+    public static Logger logger = Logger.getLogger(PredictorFactory.class);
+
+    /**
+     * Techniques known to the factory. Each constant carries a short name and a
+     * human-readable display name.
+     */
+    public static enum PredictionTechnique {
+        NAIVE_BAYES ("NBC", "Naive Bayes Classifier"),
+        BAYES_NET ("BN", "Bayesian Network"),
+        M5P ("M5P", "M5P Decision Tree"),
+        J48 ("J48", "C4.5 Decision Tree"),
+        DT ("DT", "Decision Table"),
+        ZEROR ("ZEROR", "ZeroR"),
+        REPTREE ("REPTREE", "REPTree"),
+        SMO ("SMO", "Sequential Minimal Optimization"),
+        RBFN ("RBFN", "RBF Network"),
+        MP ("MP", "Multilayer Perceptron"),
+        SLR ("SLR", "Simple Linear Regression"),
+        SL ("SL", "Simple Logistic"),
+        SVM ("SVM", "Support Vector Machine"),
+        LOG ("LOG", "Logistic"),
+        SGD ("SGD", "Stochastic Gradient Descent"),
+        VP ("VP", "VotedPerceptron"),
+        SMOR ("SMOR", "Sequential Minimal Optimization Regression"),
+        KSTAR ("KSTAR", "KStar"),
+        LWL ("LWL", "Locally weighted learning"),
+        RF ("RF", "Random Forest"),
+        NBM ("NBM", "Naive Bayes Multinomial"),
+        IBK ("IBK", "Instance-based Learning"),
+        JRIP ("JRIP", "JRip"),
+        M5R ("M5R", "M5Rules"),
+        ONER ("ONER", "OneR"),
+        PART ("PART", "PART"),
+        ;
+
+        private final String shortName;
+        private final String name;
+
+        PredictionTechnique(String shortName, String name) {
+            this.shortName = shortName;
+            this.name = name;
+        }
+
+        /** @return the abbreviated identifier, e.g. {@code "NBC"} */
+        public String getShortName() {
+            return this.shortName;
+        }
+
+        /** @return the display name, e.g. {@code "Naive Bayes Classifier"} */
+        public String getName() {
+            return this.name;
+        }
+    }
+
+    /**
+     * Creates a predictor for the given technique. The predictor is always
+     * labelled with the technique's display name.
+     *
+     * <p>{@code RBFN} and {@code SGD} have no classifier wired up; for those the
+     * factory falls back to Naive Bayes and logs a warning, because the returned
+     * predictor still carries the requested technique's name.
+     *
+     * @param pTechnique technique to instantiate; must not be {@code null}
+     * @return a new predictor wrapping a freshly constructed Weka classifier
+     */
+    public static PredictorInterface createPredictor(PredictionTechnique pTechnique) {
+        String name = pTechnique.getName();
+        logger.debug("Creating predictor: " + name);
+
+        switch (pTechnique) {
+            case NAIVE_BAYES: return new ClassifierAdapter(new NaiveBayes(), name);
+            case BAYES_NET: return new ClassifierAdapter(new BayesNet(), name);
+            case M5P: return new ClassifierAdapter(new M5P(), name);
+            case J48: return new ClassifierAdapter(new J48(), name);
+            case DT: return new ClassifierAdapter(new DecisionTable(), name);
+            case ZEROR: return new ClassifierAdapter(new ZeroR(), name);
+            case REPTREE: return new ClassifierAdapter(new REPTree(), name);
+            case SMO: return new ClassifierAdapter(new SMO(), name);
+            case MP: return new ClassifierAdapter(new MultilayerPerceptron(), name);
+            case SLR: return new ClassifierAdapter(new SimpleLinearRegression(), name);
+            case SL: return new ClassifierAdapter(new SimpleLogistic(), name);
+            case SVM: return new ClassifierAdapter(new LibSVM(), name);
+            case LOG: return new ClassifierAdapter(new Logistic(), name);
+            case VP: return new ClassifierAdapter(new VotedPerceptron(), name);
+            case SMOR: return new ClassifierAdapter(new SMOreg(), name);
+            case KSTAR: return new ClassifierAdapter(new KStar(), name);
+            case LWL: return new ClassifierAdapter(new LWL(), name);
+            case RF: return new ClassifierAdapter(new RandomForest(), name);
+            case NBM: return new ClassifierAdapter(new NaiveBayesMultinomial(), name);
+            case IBK: return new ClassifierAdapter(new IBk(), name);
+            case JRIP: return new ClassifierAdapter(new JRip(), name);
+            case M5R: return new ClassifierAdapter(new M5Rules(), name);
+            case ONER: return new ClassifierAdapter(new OneR(), name);
+            case PART: return new ClassifierAdapter(new PART(), name);
+            // RBFNetwork and SGD were disabled in the original source
+            // (presumably unavailable in the bundled Weka build -- TODO confirm).
+            case RBFN:
+            case SGD:
+            default:
+                // Make the substitution visible: results reported under this name
+                // actually come from Naive Bayes.
+                logger.warn("No classifier available for " + pTechnique + " (" + name
+                        + "); falling back to Naive Bayes");
+                return new ClassifierAdapter(new NaiveBayes(), name);
+        }
+    }
+}
+
diff --git a/src/predictor/PredictorInterface.java b/src/predictor/PredictorInterface.java
new file mode 100644
index 0000000..03b9052
--- /dev/null
+++ b/src/predictor/PredictorInterface.java
@@ -0,0 +1,18 @@
+package predictor;
+
+import java.util.Random;
+
+import weka.core.*;
+
+/**
+ * Contract shared by all predictors: a named model that can be trained,
+ * queried for predictions and cross-validated on Weka data sets.
+ *
+ * <p>Two accessors, {@code getEvaluationSummaryString()} and
+ * {@code getEvaluationMatrixString()}, are disabled in the original source and
+ * are therefore not part of this contract.
+ */
+public interface PredictorInterface {
+
+    /** @return the predictor's name */
+    String getName();
+
+    /**
+     * Trains the predictor.
+     *
+     * @param instances the data to train on
+     * @throws Exception if training fails
+     */
+    void train(Instances instances) throws Exception;
+
+    /**
+     * Predicts for a single instance.
+     *
+     * @param instance the instance to predict for
+     * @return an integer result (presumably a class index -- TODO confirm against implementations)
+     * @throws Exception if prediction fails
+     */
+    int predict(Instance instance) throws Exception;
+
+    /**
+     * Predicts for a whole data set.
+     *
+     * @param instances the instances to predict for
+     * @return a single integer result (meaning not visible here -- TODO confirm against implementations)
+     * @throws Exception if prediction fails
+     */
+    int predict(Instances instances) throws Exception;
+
+    /**
+     * Cross-validates the predictor.
+     *
+     * @param instances the data to validate on
+     * @param numFold   number of folds
+     * @param rand      random source supplied by the caller
+     * @throws Exception if the evaluation fails
+     */
+    void crossValidate(Instances instances, int numFold, Random rand) throws Exception;
+
+    /** @return the evaluation's predictions (presumably those of the latest cross-validation -- TODO confirm) */
+    FastVector getEvaluationPredictions();
+
+    /**
+     * @return the evaluation results as text
+     * @throws Exception if the results cannot be produced
+     */
+    String getEvaluationResults() throws Exception;
+}
+