diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/experiments/RunPredictionServer.java | 183 | ||||
-rw-r--r-- | src/log4j.properties | 17 | ||||
-rw-r--r-- | src/predictor/PredictorFactory.java | 98 | ||||
-rw-r--r-- | src/predictor/PredictorInterface.java | 18 |
4 files changed, 316 insertions, 0 deletions
diff --git a/src/experiments/RunPredictionServer.java b/src/experiments/RunPredictionServer.java
new file mode 100644
index 0000000..6612a86
--- /dev/null
+++ b/src/experiments/RunPredictionServer.java
@@ -0,0 +1,187 @@
+package experiments;
+
+import java.io.*;
+import java.text.DateFormat;
+import java.text.SimpleDateFormat;
+import java.util.*;
+
+import weka.filters.Filter;
+import model.Model;
+import org.apache.log4j.Logger;
+import weka.core.Instances;
+import weka.core.OptionHandler;
+
+/**
+ * Command-line entry point: loads a training set, pushes it through the Weka
+ * filters listed in the configuration file, then benchmarks and cross-validates
+ * a fixed list of predictors.
+ *
+ * <p>Configuration file layout: the first line is the path of the training
+ * data; every following non-blank line holds a filter alias (one of the keys
+ * registered in {@code setClassName()}) followed by that filter's options,
+ * separated by whitespace.
+ *
+ * Created by hailiu on 2016/2/22.
+ */
+public class RunPredictionServer {
+    protected Logger logger = Logger.getLogger(RunPredictionServer.class);
+
+    /** Predictor short names handed to the model, in registration order. */
+    private static final String[] PREDICTOR_SHORT_NAMES = {
+        "ZEROR", "PART", "ONER", "JRIP", "IBK", "NBM", "RF",
+        "LWL", "NBC", "BN", "REPTREE", "DT", "J48", "SMO",
+        "MP", "SL", "LOG", "SGD", "VP", "SVM", "KSTAR"
+    };
+
+    final String configFile = "data/config.txt";//input file path
+    String[] option = new String[50];
+    String filter;
+    /** One entry per configured filter: alias first, then its options. */
+    ArrayList<ArrayList<String>> filterList;
+    /** Path of the training data, taken from the first line of the config file. */
+    String DataFile;
+    protected static DateFormat dateFormat = new SimpleDateFormat("yyyyMMdd_HHmmss");
+    protected static Date date = new Date();
+    protected static String resultFilename = dateFormat.format(date)+"_results";
+    /** Filter alias -> fully qualified Weka filter class name. */
+    Map<String, String> map;
+
+    public RunPredictionServer() {
+        this.filterList = new ArrayList<>();
+        setClassName();
+    }
+
+    /**
+     * Runs the whole pipeline: read the configuration, load the data, register
+     * the predictors, apply the configured filters, then benchmark,
+     * cross-validate and save the results.
+     *
+     * @param args ignored
+     */
+    public static void main(String[] args)
+    {
+        RunPredictionServer runServer = new RunPredictionServer();
+        //read config file
+        runServer.loadConfigFile();
+        if (runServer.DataFile == null) {
+            // Nothing to train on: the file was missing, unreadable or empty.
+            runServer.logger.error("No training data path found in " + runServer.configFile);
+            return;
+        }
+
+        Model tempModel = new Model();
+        tempModel.loadTrainingData(runServer.DataFile);
+        for (String shortName : PREDICTOR_SHORT_NAMES) {
+            tempModel.addPredictor(shortName);
+        }
+
+        Instances tempInstances = tempModel.getTrainingInstances();
+        for (ArrayList<String> filterOption : runServer.filterList) {
+            if (filterOption.isEmpty()) {
+                continue;
+            }
+            String tFilter = filterOption.get(0);
+            String[] option = filterOption.subList(1, filterOption.size()).toArray(new String[0]);
+            tempInstances = runServer.addFilter(tempInstances, tFilter, option);
+            if (tempInstances == null) {
+                // Carrying on would feed null into every later step.
+                runServer.logger.error("Filter '" + tFilter + "' could not be applied; aborting");
+                return;
+            }
+        }
+
+        //tempInstances.setClassIndex(tempInstances.numAttributes()-2);
+        tempModel.setPreprocessedInstances(tempInstances);
+        tempModel.savePreprocessedInstances("preprocessed.arff");
+
+        try {
+            tempModel.benchmark(2, resultFilename);
+        } catch (Exception e) {
+            runServer.logger.warn("Benchmark failed", e);
+        }
+
+        tempModel.crossValidatePredictors(10);
+        try {
+            tempModel.saveResults(resultFilename);
+        } catch (Exception e) {
+            runServer.logger.warn("Could not save results to " + resultFilename, e);
+        }
+    }
+
+    /**
+     * Reads {@code configFile}: the first line becomes {@code DataFile}, and each
+     * following non-blank line is split on whitespace and appended to
+     * {@code filterList}. I/O failures are logged and leave whatever was read
+     * so far in place.
+     */
+    private void loadConfigFile() {
+        logger.debug("Reading configure from " + configFile);
+        // try-with-resources so the reader is closed even when a read fails
+        try (BufferedReader in = new BufferedReader(new FileReader(configFile))) {
+            DataFile = in.readLine();
+            System.out.println(DataFile);
+
+            String tempStr;
+            while ((tempStr = in.readLine()) != null) {
+                String trimmed = tempStr.trim();
+                if (trimmed.isEmpty()) {
+                    // A blank line carries no filter alias.
+                    continue;
+                }
+                ArrayList<String> filterOption = new ArrayList<>(Arrays.asList(trimmed.split("\\s+")));
+                this.filterList.add(filterOption);
+            }
+        } catch (IOException e) {
+            logger.warn("Could not read " + configFile, e);
+        }
+    }
+
+    /**
+     * Instantiates the filter registered under {@code filter}, configures it with
+     * {@code option} and runs {@code tempInstances} through it.
+     *
+     * @param tempInstances data to filter
+     * @param filter        filter alias; matched exactly first, then lower-cased
+     * @param option        options passed to the filter's {@code setOptions}
+     * @return the filtered data, or {@code null} if the alias is unknown or the
+     *         filter could not be created, configured or applied
+     */
+    private Instances addFilter(Instances tempInstances,String filter, String[] option)
+    {
+        if (filter == null) {
+            logger.warn("No filter name given");
+            return null;
+        }
+        String className = this.map.get(filter);
+        if (className == null) {
+            className = this.map.get(filter.toLowerCase(Locale.ROOT));
+        }
+        if (className == null) {
+            logger.warn("Unknown filter '" + filter + "'; known filters: " + this.map.keySet());
+            return null;
+        }
+
+        try {
+            Object tempFilter = Class.forName(className).newInstance();
+            ((OptionHandler) tempFilter).setOptions(option);
+            ((Filter) tempFilter).setInputFormat(tempInstances);
+            return Filter.useFilter(tempInstances, (Filter) tempFilter);
+        } catch (Exception e) {
+            // Reflection, option parsing and filtering all fail the same way for the caller.
+            logger.warn("Could not apply filter '" + filter + "' (" + className + ")", e);
+            return null;
+        }
+    }
+
+    /** Registers the filter aliases accepted in the configuration file. */
+    private void setClassName()
+    {
+        map = new HashMap<>();
+        map.put("addexpression","weka.filters.unsupervised.attribute.AddExpression");
+        map.put("remove","weka.filters.unsupervised.attribute.Remove");
+        map.put("classassigner","weka.filters.unsupervised.attribute.ClassAssigner");
+        map.put("numerictonominal","weka.filters.unsupervised.attribute.NumericToNominal");
+        map.put("mathexpression","weka.filters.unsupervised.attribute.MathExpression");
+    }
+}
+
diff --git a/src/log4j.properties b/src/log4j.properties
new file mode 100644
index 0000000..797baca
--- /dev/null
+++ b/src/log4j.properties
@@ -0,0 +1,17 @@
+log4j.appender.consoleAppender = org.apache.log4j.ConsoleAppender
+log4j.appender.rollingFile = org.apache.log4j.RollingFileAppender
+log4j.appender.rollingFile.File = T1.log
+log4j.appender.rollingFile.MaxFileSize = 10MB
+log4j.appender.rollingFile.MaxBackupIndex = 5
+
+log4j.appender.consoleAppender.layout = org.apache.log4j.PatternLayout
+log4j.appender.consoleAppender.layout.ConversionPattern = %d %-6r [%t] %-5p %-90l - %m%n
+log4j.appender.rollingFile.layout = org.apache.log4j.PatternLayout
+log4j.appender.rollingFile.layout.ConversionPattern = %d %-6r [%t] %-5p %-90l - %m%n
+
+#log4j.rootLogger = TRACE, consoleAppender, rollingFile
+log4j.rootLogger = DEBUG, consoleAppender, rollingFile
+#log4j.rootLogger = INFO, consoleAppender, rollingFile
+#log4j.rootLogger = WARN, consoleAppender, rollingFile
+#log4j.rootLogger = TRACE, rollingFile
+
diff --git a/src/predictor/PredictorFactory.java b/src/predictor/PredictorFactory.java
new file mode 100644
index 0000000..af21b10
--- /dev/null
+++ b/src/predictor/PredictorFactory.java
@@ -0,0 +1,123 @@
+package predictor;
+
+import org.apache.log4j.*;
+
+import weka.classifiers.bayes.*;
+import weka.classifiers.trees.*;
+import weka.classifiers.rules.*;
+import weka.classifiers.functions.*;
+import weka.classifiers.functions.LibSVM;
+import weka.classifiers.lazy.*;
+
+import predictor.PredictorInterface;
+
+/**
+ * Factory that wraps a freshly constructed Weka classifier in a
+ * {@code ClassifierAdapter} for a requested {@link PredictionTechnique}.
+ */
+public class PredictorFactory {
+    public static Logger logger = Logger.getLogger(PredictorFactory.class);
+
+    /** Techniques known to the factory, each with a short code and a display name. */
+    public static enum PredictionTechnique {
+        NAIVE_BAYES ("NBC", "Naive Bayes Classifier"),
+        BAYES_NET ("BN", "Bayesian Network"),
+        M5P ("M5P", "M5P Decision Tree"),
+        J48 ("J48", "C4.5 Decision Tree"),
+        DT ("DT", "Decision Table"),
+        ZEROR ("ZEROR", "ZeroR"),
+        REPTREE ("REPTREE", "REPTree"),
+        SMO ("SMO", "Sequential Minimal Optimization"),
+        RBFN ("RBFN", "RBF Network"),
+        MP ("MP", "Multilayer Perceptron"),
+        SLR ("SLR", "Simple Linear Regression"),
+        SL ("SL", "Simple Logistic"),
+        SVM ("SVM", "Support Vector Machine"),
+        LOG ("LOG", "Logistic"),
+        SGD ("SGD", "Stochastic Gradient Descent"),
+        VP ("VP", "VotedPerceptron"),
+        SMOR ("SMOR", "Sequential Minimal Optimization Regression"),
+        KSTAR ("KSTAR", "KStar"),
+        LWL ("LWL", "Locally weighted learning"),
+        RF ("RF", "Random Forest"),
+        NBM ("NBM", "Naive Bayes Multinomial"),
+        IBK ("IBK", "Instance-based Learning"),
+        JRIP ("JRIP", "JRip"),
+        M5R ("M5R", "M5Rules"),
+        ONER ("ONER", "OneR"),
+        PART ("PART", "PART"),
+        ;
+
+        private final String shortName;
+        private final String name;
+
+        PredictionTechnique(String shortName, String name) {
+            this.shortName = shortName;
+            this.name = name;
+        }
+
+        /** @return the short code given as the first constructor argument */
+        public String getShortName() {
+            return this.shortName;
+        }
+
+        /** @return the display name given as the second constructor argument */
+        public String getName() {
+            return this.name;
+        }
+    }
+
+    /**
+     * Creates a predictor for the given technique.
+     *
+     * <p>{@code RBFN} and {@code SGD} have no classifier wired up here; for those a
+     * warning is logged and a Naive Bayes classifier is returned under the
+     * requested technique's display name.
+     *
+     * @param pTechnique technique to instantiate; must not be {@code null}
+     * @return an adapter around a new classifier instance
+     * @throws NullPointerException if {@code pTechnique} is {@code null}
+     */
+    public static PredictorInterface createPredictor(PredictionTechnique pTechnique) {
+        if (pTechnique == null) {
+            throw new NullPointerException("pTechnique must not be null");
+        }
+        String name = pTechnique.getName();
+        logger.debug("Creating predictor: " + name);
+
+        switch(pTechnique) {
+            case NAIVE_BAYES: return new ClassifierAdapter(new NaiveBayes(),name);
+            case BAYES_NET: return new ClassifierAdapter(new BayesNet(),name);
+            case M5P: return new ClassifierAdapter(new M5P(), name);
+            case J48: return new ClassifierAdapter(new J48(), name);
+            case DT: return new ClassifierAdapter(new DecisionTable(), name);
+            case ZEROR: return new ClassifierAdapter(new ZeroR(), name);
+            case REPTREE: return new ClassifierAdapter(new REPTree(), name);
+            case SMO: return new ClassifierAdapter(new SMO(), name);
+            //case RBFN: return new ClassifierAdapter(new RBFNetwork(), name);
+            case MP: return new ClassifierAdapter(new MultilayerPerceptron(), name);
+            case SLR: return new ClassifierAdapter(new SimpleLinearRegression(), name);
+            case SL: return new ClassifierAdapter(new SimpleLogistic(), name);
+            case SVM: return new ClassifierAdapter(new LibSVM(), name);
+            case LOG: return new ClassifierAdapter(new Logistic(), name);
+            //case SGD: return new ClassifierAdapter(new SGD(), name);
+            case VP: return new ClassifierAdapter(new VotedPerceptron(), name);
+            case SMOR: return new ClassifierAdapter(new SMOreg(), name);
+            case KSTAR: return new ClassifierAdapter(new KStar(), name);
+            case LWL: return new ClassifierAdapter(new LWL(), name);
+            case RF: return new ClassifierAdapter(new RandomForest(), name);
+            case NBM: return new ClassifierAdapter(new NaiveBayesMultinomial(), name);
+            case IBK: return new ClassifierAdapter(new IBk(), name);
+            case JRIP: return new ClassifierAdapter(new JRip(), name);
+            case M5R: return new ClassifierAdapter(new M5Rules(), name);
+            case ONER: return new ClassifierAdapter(new OneR(), name);
+            case PART: return new ClassifierAdapter(new PART(), name);
+            case RBFN:
+            case SGD:
+            default:
+                logger.warn("No classifier available for " + name + "; falling back to Naive Bayes");
+                return new ClassifierAdapter(new NaiveBayes(),name);
+        }
+    }
+}
+
diff --git a/src/predictor/PredictorInterface.java b/src/predictor/PredictorInterface.java
new file mode 100644
index 0000000..03b9052
--- /dev/null
+++ b/src/predictor/PredictorInterface.java
@@ -0,0 +1,42 @@
+package predictor;
+
+import java.util.Random;
+
+import weka.core.*;
+
+/**
+ * Type returned by {@code PredictorFactory.createPredictor}.
+ */
+public interface PredictorInterface {
+    /** @return the name of this predictor */
+    public String getName();
+
+    /**
+     * Cross-validates this predictor on {@code instances}.
+     *
+     * @param instances data set to evaluate on
+     * @param numFold   number of folds
+     * @param rand      source of randomness (presumably for shuffling; confirm in implementations)
+     * @throws Exception if the evaluation fails
+     */
+    public void crossValidate(Instances instances, int numFold, Random rand) throws Exception;
+
+    //public String getEvaluationSummaryString();
+    //public String getEvaluationMatrixString() throws Exception;
+
+    /** @return evaluation predictions (NOTE(review): presumably from the last evaluation; confirm) */
+    public FastVector getEvaluationPredictions();
+
+    /** Trains this predictor on {@code instances}. */
+    public void train(Instances instances) throws Exception;
+
+    /** Predicts for one instance (NOTE(review): meaning of the int result is set by implementations). */
+    public int predict(Instance instance) throws Exception;
+
+    /** Predicts for a data set (NOTE(review): meaning of the int result is set by implementations). */
+    public int predict(Instances instances) throws Exception;
+
+    /** @return the evaluation results as text */
+    public String getEvaluationResults() throws Exception;
+}
+
|