diff options
author | Hai Liu <hai.liu@huawei.com> | 2016-03-18 15:03:47 +0800 |
---|---|---|
committer | Hai Liu <hai.liu@huawei.com> | 2016-03-18 15:03:47 +0800 |
commit | 4034199deccff41cb6661812d4c4aa7c523d78b4 (patch) | |
tree | 3d047da70e1ee6a67e05cf62d3dd58602dad1d72 | |
parent | a4669a5eff3b0866f9f15346a5a0d4d81af409d7 (diff) |
Add data model for predictor
JIRA: PREDICTION-36
Change-Id: I871ce7e3696e7154ee4adb5c83cdb45c02a97d7b
Signed-off-by: Hai Liu <hai.liu@huawei.com>
-rw-r--r-- | src/model/Model.java | 325 | ||||
-rw-r--r-- | src/model/ModelInterface.java | 32 |
2 files changed, 357 insertions, 0 deletions
diff --git a/src/model/Model.java b/src/model/Model.java new file mode 100644 index 0000000..ca9897b --- /dev/null +++ b/src/model/Model.java @@ -0,0 +1,325 @@ +package model; + +import java.io.*; +import java.util.*; +import java.text.DateFormat; +import java.text.SimpleDateFormat; + +import org.apache.log4j.*; + +import org.apache.commons.math3.stat.descriptive.moment.*; +import org.apache.commons.math3.distribution.NormalDistribution; + +import weka.core.*; +import weka.classifiers.evaluation.*; +import weka.core.converters.*; + +import input.ARFFReader; +import predictor.*; + + +public class Model implements ModelInterface { + protected Logger logger = Logger.getLogger(Model.class); + + protected String datapath; + protected Instances trainingInstances; + protected Instances preprocessedInstances; + protected ArrayList<PredictorInterface> predictors; + + protected static DateFormat dateFormat = new SimpleDateFormat("yyyyMMdd_HHmmss"); + protected static Date date = new Date(); + protected static String resultFilename = dateFormat.format(date); + + public Model(){ + this.datapath = ""; + this.trainingInstances = null; + this.preprocessedInstances = null; + this.predictors = new ArrayList<PredictorInterface>(); + } + + public Model (ModelInterface model) { + this.datapath = model.getDatapath(); + this.trainingInstances = new Instances(model.getTrainingInstances()); + this.preprocessedInstances = new Instances(model.getPreprocessedInstances()); + this.predictors = new ArrayList<PredictorInterface>(model.getPredictors()); + } + + @Override + public Instances getPreprocessedInstances() { + return this.preprocessedInstances; + } + + @Override + public Instances getTrainingInstances() { + return this.trainingInstances; + } + + + @Override + public String getDatapath() { + return this.datapath; + } + + + @Override + public ArrayList<PredictorInterface> getPredictors() { + return this.predictors; + } + + + @Override + public void loadTrainingData(String path) { + logger.debug("Reading training data from " + path); + this.datapath = path; + try { + trainingInstances = ARFFReader.read(path); + preprocessedInstances = trainingInstances; + logger.debug("Training data is read"); + logger.trace(trainingInstances); + } catch (Exception e) { + logger.warn(e.toString()); + } + } + + @Override + public void loadRawLog(String path) { + logger.debug("Reading logfile from " + path); + try { + BufferedReader in = new BufferedReader(new FileReader(path)); + String line = in.readLine(); + System.out.println(line); + } catch (Exception e) { + logger.warn(e.toString()); + } + } + + + + @Override + public void setPreprocessedInstances(Instances instances){ + this.preprocessedInstances=new Instances(instances); + } + + @Override + public void savePreprocessedInstances(String path) { + ArffSaver saver = new ArffSaver(); + saver.setInstances(this.preprocessedInstances); + try { + saver.setFile(new File(path)); + saver.writeBatch(); + } catch (Exception e) { + logger.error("Cannot save preprocessed instances to file " + path); + } + } + + @Override + public void addPredictor(String shortName) { + for (PredictorFactory.PredictionTechnique pTechnique : PredictorFactory.PredictionTechnique.values()) { + if (pTechnique.getShortName().equals(shortName)) { + predictors.add(PredictorFactory.createPredictor(pTechnique)); + logger.info("Added predictor: " + pTechnique.getName()); + return; + } + } + logger.warn("Added predictor: None"); + logger.warn(shortName + " is not in the list."); + } + + @Override + public void selectTrainingMethod() { + + } + + @Override + public void trainPredictors() throws Exception{ + if (preprocessedInstances == null) { + throw new Exception("No training data"); + } + if (predictors.size() == 0) { + throw new Exception("No predictors selected"); + } + for (PredictorInterface p : predictors) { + try { + logger.info("Training " + p.getName()); + p.train(preprocessedInstances); + logger.debug(p.toString()); + } catch (Exception e) { + logger.error(e.toString()); + } + } + } + + @Override + public void crossValidatePredictors(int numFold) { + long seed = 1; + this.crossValidatePredictors(numFold, seed); + } + + @Override + public void crossValidatePredictors(int numFold, long seed) { + for (PredictorInterface p : predictors) { + try { + logger.debug(numFold + "-fold cross-validating: " + p.getName()); + Random rand = new Random(seed); + p.crossValidate(preprocessedInstances, numFold, rand); + + ///* + ThresholdCurve tc = new ThresholdCurve(); + int classIndex = 1; + Instances result = tc.getCurve(p.getEvaluationPredictions(), classIndex); + + + // Save ROC + + BufferedWriter br = new BufferedWriter(new FileWriter(resultFilename + "_" + p.getName().replace(" ", "_") + "_ROC.arff")); + br.write(result.toString()); + br.close(); + } catch (Exception e) { + logger.error(e.toString()); + } + } + } + + @Override + public void benchmark(int rounds, String filename) throws Exception { + ArrayList<ArrayList<Double>> trainingTime = new ArrayList<ArrayList<Double>>(this.predictors.size()); + ArrayList<ArrayList<Double>> predictionTime = new ArrayList<ArrayList<Double>>(this.predictors.size()); + + Runtime runtime = Runtime.getRuntime(); + + for (int i=0; i<this.predictors.size(); i++) { + trainingTime.add(new ArrayList<Double>(rounds)); + predictionTime.add(new ArrayList<Double>(rounds)); + } + + // Benchmark - using preprocessed instances + long startTime; + long endTime; + double elapsedTime; + for (int pIndex=0; pIndex<this.predictors.size(); pIndex++) { + for (int rIndex=0; rIndex<rounds; rIndex++) { + logger.debug("Benchmarking " + this.predictors.get(pIndex).getName() + " round " + rIndex); + + // Training time + startTime = System.currentTimeMillis(); + this.predictors.get(pIndex).train(this.preprocessedInstances); + endTime = System.currentTimeMillis(); + elapsedTime = ((double)(endTime - startTime))/1000; + trainingTime.get(pIndex).add(elapsedTime); + logger.debug("Training time = " + elapsedTime + " seconds"); + + // Prediction time + startTime = System.currentTimeMillis(); + this.predictors.get(pIndex).predict(this.preprocessedInstances); + endTime = System.currentTimeMillis(); + elapsedTime = ((double)(endTime - startTime))/1000; + predictionTime.get(pIndex).add(elapsedTime); + logger.debug("Prediction time = " + elapsedTime + " seconds"); + } + } + + // Save time results to file + BufferedWriter br = new BufferedWriter(new FileWriter(filename+"_benchmark")); + // Write header + String header = ""; + for (int pIndex=0; pIndex<this.predictors.size(); pIndex++) { + header += "\"" + this.predictors.get(pIndex).getName() + " Training\" "; + header += "\"" + this.predictors.get(pIndex).getName() + " Prediction\" "; + } + header += "\n"; + br.write(header); + + for (int rIndex=0; rIndex<rounds; rIndex++) { + String line = ""; + for (int pIndex=0; pIndex<this.predictors.size(); pIndex++) { + line += trainingTime.get(pIndex).get(rIndex).toString() + " " + predictionTime.get(pIndex).get(rIndex).toString() + " "; + } + line += "\n"; + br.write(line); + } + br.close(); + + // TODO: move to another function + // refactor + // Calculate and save summary results + BufferedWriter brSummary = new BufferedWriter(new FileWriter(filename+"_benchmark_time_summary")); + //header = "Algorithms tMean tError pMean pError\n"; + //brSummary.write(header); + for (int pIndex=0; pIndex<this.predictors.size(); pIndex++) { + String line = "\"" + this.predictors.get(pIndex).getName() + "\" "; + // Calculate mean + double [] training = new double[rounds]; + // Convert ArrayList to array for Mean.evaluate + for (int rIndex=0; rIndex<rounds; rIndex++) { + //System.out.println("pIndex = " + pIndex + " rIndex = " + rIndex); + //System.out.println("array size = " + training.length); + training[rIndex] = trainingTime.get(pIndex).get(rIndex); + } + double meanTraining = new Mean().evaluate(training); + double varTraining = new Variance().evaluate(training); + double stdTraining = Math.sqrt(varTraining); + double errorTraining = new NormalDistribution().inverseCumulativeProbability(0.975)*(stdTraining/Math.sqrt(rounds)); + double lowerCITraining = meanTraining - errorTraining; + double upperCITraining = meanTraining + errorTraining; + line += meanTraining + " " + errorTraining + " "; + //brSummary.write(line); + + //line = "\"" + this.predictors.get(pIndex).getName() + " Prediction\" "; + // Calculate mean + double [] prediction = new double[rounds]; + // Convert ArrayList to array for Mean.evaluate + for (int rIndex=0; rIndex<rounds; rIndex++) { + prediction[rIndex] = predictionTime.get(pIndex).get(rIndex); + } + double meanPrediction = new Mean().evaluate(prediction); + double varPrediction = new Variance().evaluate(prediction); + double stdPrediction = Math.sqrt(varPrediction); + double errorPrediction = new NormalDistribution().inverseCumulativeProbability(0.975)*(stdPrediction/Math.sqrt(rounds)); + double lowerCIPrediction = meanPrediction - errorPrediction; + double upperCIPrediction = meanPrediction + errorPrediction; + line += meanPrediction + " " + errorPrediction + "\n"; + brSummary.write(line); + } + brSummary.close(); + } + + + @Override + public String getPredictorNames() { + String str = ""; + for (PredictorInterface p : this.predictors) { + str += p.getName() + "\n"; + } + return str; + } + + @Override + public String toString() { + String str = ""; + str += "Training data:\n" + this.datapath + "\n"; + str += "Training data summary:\n" + this.getTrainingInstances().toSummaryString() + "\n"; + str += "Preprocessed data summary:\n" + this.getPreprocessedInstances().toSummaryString() + "\n"; + str += "Predictors:\n" + this.getPredictorNames() + "\n"; + return str; + } + + @Override + public void saveSettings(String filename) throws Exception { + BufferedWriter br = new BufferedWriter(new FileWriter(filename,true)); + br.write(this.toString()); + br.close(); + } + + @Override + public void saveResults(String filename) throws Exception { + BufferedWriter br = new BufferedWriter(new FileWriter(filename,true)); + + for (PredictorInterface p : this.predictors) { + logger.info(p.getEvaluationResults()); + br.write(p.getEvaluationResults()); + } + + br.close(); + } +} + diff --git a/src/model/ModelInterface.java b/src/model/ModelInterface.java new file mode 100644 index 0000000..a4d1b7b --- /dev/null +++ b/src/model/ModelInterface.java @@ -0,0 +1,32 @@ +package model; + +import java.util.ArrayList; + +import weka.core.Instances; + +import predictor.*; + + +public interface ModelInterface { + public String getDatapath(); + public ArrayList<PredictorInterface> getPredictors(); + + public void loadTrainingData(String path); + public void loadRawLog(String path); + public Instances getTrainingInstances(); + public void setPreprocessedInstances(Instances instances); + public Instances getPreprocessedInstances(); + public void savePreprocessedInstances(String path); + public void addPredictor(String shortName); + public void crossValidatePredictors(int numFold); + public void crossValidatePredictors(int numFold, long seed); + public void selectTrainingMethod(); + public void trainPredictors() throws Exception; + public void benchmark(int rounds, String filename) throws Exception; + + public String getPredictorNames(); + public String toString(); + public void saveSettings(String filename) throws Exception; + public void saveResults(String filename) throws Exception; +} + |