From 4034199deccff41cb6661812d4c4aa7c523d78b4 Mon Sep 17 00:00:00 2001 From: Hai Liu Date: Fri, 18 Mar 2016 15:03:47 +0800 Subject: Add data model for predictor JIRA: PREDICTION-36 Change-Id: I871ce7e3696e7154ee4adb5c83cdb45c02a97d7b Signed-off-by: Hai Liu --- src/model/Model.java | 325 ++++++++++++++++++++++++++++++++++++++++++ src/model/ModelInterface.java | 32 +++++ 2 files changed, 357 insertions(+) create mode 100644 src/model/Model.java create mode 100644 src/model/ModelInterface.java (limited to 'src') diff --git a/src/model/Model.java b/src/model/Model.java new file mode 100644 index 0000000..ca9897b --- /dev/null +++ b/src/model/Model.java @@ -0,0 +1,325 @@ +package model; + +import java.io.*; +import java.util.*; +import java.text.DateFormat; +import java.text.SimpleDateFormat; + +import org.apache.log4j.*; + +import org.apache.commons.math3.stat.descriptive.moment.*; +import org.apache.commons.math3.distribution.NormalDistribution; + +import weka.core.*; +import weka.classifiers.evaluation.*; +import weka.core.converters.*; + +import input.ARFFReader; +import predictor.*; + + /* Data model: keeps the training data path, the loaded and the preprocessed Instances, and the list of predictors. */ +public class Model implements ModelInterface { + protected Logger logger = Logger.getLogger(Model.class); + + protected String datapath; + protected Instances trainingInstances; + protected Instances preprocessedInstances; + protected ArrayList predictors; + /* Static timestamp string, formatted once from the Date created when these static fields are initialised; used further down as the file name prefix when ROC curves are written. */ + protected static DateFormat dateFormat = new SimpleDateFormat("yyyyMMdd_HHmmss"); + protected static Date date = new Date(); + protected static String resultFilename = dateFormat.format(date); + /* Creates an empty model: empty data path, no instances, empty predictor list. */ + public Model(){ + this.datapath = ""; + this.trainingInstances = null; + this.preprocessedInstances = null; + this.predictors = new ArrayList(); + } + /* Copy constructor. Both Instances sets are passed through the Instances copy constructor and the predictors are placed in a new ArrayList that holds the same predictor objects. NOTE(review): presumably fails when the source model has not loaded any data, because its getters then return null - confirm. */ + public Model (ModelInterface model) { + this.datapath = model.getDatapath(); + this.trainingInstances = new Instances(model.getTrainingInstances()); + this.preprocessedInstances = new Instances(model.getPreprocessedInstances()); + this.predictors = new 
ArrayList(model.getPredictors()); + } + /* Returns the internal preprocessed Instances reference, not a copy. */ + @Override + public Instances getPreprocessedInstances() { + return this.preprocessedInstances; + } + /* Returns the internal training Instances reference, not a copy. */ + @Override + public Instances getTrainingInstances() { + return this.trainingInstances; + } + + /* Returns the stored training data path. */ + @Override + public String getDatapath() { + return this.datapath; + } + + /* Returns the internal predictor list itself, so changes made by the caller are visible to this model. */ + @Override + public ArrayList getPredictors() { + return this.predictors; + } + + /* Reads training data from path through ARFFReader. The path is stored before reading starts. On success preprocessedInstances refers to the same object as trainingInstances (no copy is made). Any exception is logged at warn level and not rethrown. */ + @Override + public void loadTrainingData(String path) { + logger.debug("Reading training data from " + path); + this.datapath = path; + try { + trainingInstances = ARFFReader.read(path); + preprocessedInstances = trainingInstances; + logger.debug("Training data is read"); + logger.trace(trainingInstances); + } catch (Exception e) { + logger.warn(e.toString()); + } + } + /* Opens the file at path, reads its first line and prints it to standard output. The reader is not closed in this method. Any exception is logged at warn level and not rethrown. */ + @Override + public void loadRawLog(String path) { + logger.debug("Reading logfile from " + path); + try { + BufferedReader in = new BufferedReader(new FileReader(path)); + String line = in.readLine(); + System.out.println(line); + } catch (Exception e) { + logger.warn(e.toString()); + } + } + + + /* Stores a copy of the given instances as the preprocessed set. */ + @Override + public void setPreprocessedInstances(Instances instances){ + this.preprocessedInstances=new Instances(instances); + } + /* Writes the preprocessed instances to path with ArffSaver. On failure only the path is logged at error level; the exception itself is neither logged nor rethrown. */ + @Override + public void savePreprocessedInstances(String path) { + ArffSaver saver = new ArffSaver(); + saver.setInstances(this.preprocessedInstances); + try { + saver.setFile(new File(path)); + saver.writeBatch(); + } catch (Exception e) { + logger.error("Cannot save preprocessed instances to file " + path); + } + } + /* Appends a predictor created by PredictorFactory for the first PredictionTechnique whose short name equals shortName. If none matches, two warnings are logged and the list is left unchanged. */ + @Override + public void addPredictor(String shortName) { + for (PredictorFactory.PredictionTechnique pTechnique : PredictorFactory.PredictionTechnique.values()) { + if (pTechnique.getShortName().equals(shortName)) { + predictors.add(PredictorFactory.createPredictor(pTechnique)); + logger.info("Added predictor: " + pTechnique.getName()); + return; + } + } + logger.warn("Added predictor: None"); + logger.warn(shortName + " is not in the list."); + } + /* selectTrainingMethod is a no-op: its body is empty. */ + @Override 
+ public void selectTrainingMethod() { + + } + /* Trains every predictor on the preprocessed instances. Throws Exception when there is no preprocessed data or no predictor. A failure while training one predictor is logged at error level and the loop continues with the next predictor. */ + @Override + public void trainPredictors() throws Exception{ + if (preprocessedInstances == null) { + throw new Exception("No training data"); + } + if (predictors.size() == 0) { + throw new Exception("No predictors selected"); + } + for (PredictorInterface p : predictors) { + try { + logger.info("Training " + p.getName()); + p.train(preprocessedInstances); + logger.debug(p.toString()); + } catch (Exception e) { + logger.error(e.toString()); + } + } + } + /* Cross-validates all predictors with a fixed seed of 1. */ + @Override + public void crossValidatePredictors(int numFold) { + long seed = 1; + this.crossValidatePredictors(numFold, seed); + } + /* For each predictor: cross-validates on the preprocessed instances with a fresh Random(seed), builds a ThresholdCurve for class index 1 from the evaluation predictions, and writes it to a file whose name is resultFilename, an underscore, the predictor name with spaces replaced by underscores, and the suffix _ROC.arff. The writer is closed only when the write succeeds. Exceptions are logged at error level and the loop moves on to the next predictor. */ + @Override + public void crossValidatePredictors(int numFold, long seed) { + for (PredictorInterface p : predictors) { + try { + logger.debug(numFold + "-fold cross-validating: " + p.getName()); + Random rand = new Random(seed); + p.crossValidate(preprocessedInstances, numFold, rand); + + ///* + ThresholdCurve tc = new ThresholdCurve(); + int classIndex = 1; + Instances result = tc.getCurve(p.getEvaluationPredictions(), classIndex); + + + // Save ROC + + BufferedWriter br = new BufferedWriter(new FileWriter(resultFilename + "_" + p.getName().replace(" ", "_") + "_ROC.arff")); + br.write(result.toString()); + br.close(); + } catch (Exception e) { + logger.error(e.toString()); + } + } + } + /* NOTE(review): from here on the text appears damaged. Generic type arguments are missing and the benchmark body breaks off inside its loops, running straight into ModelInterface declarations. Do not rely on this region; recover it from the original commit. */ + @Override + public void benchmark(int rounds, String filename) throws Exception { + ArrayList> trainingTime = new ArrayList>(this.predictors.size()); + ArrayList> predictionTime = new ArrayList>(this.predictors.size()); + + Runtime runtime = Runtime.getRuntime(); + + for (int i=0; i(rounds)); + predictionTime.add(new ArrayList(rounds)); + } + + // Benchmark - using preprocessed instances + long startTime; + long endTime; + double elapsedTime; + for (int pIndex=0; pIndex getPredictors(); + /* NOTE(review): the declarations that follow presumably belong to ModelInterface, whose diff header and opening lines are not present in this text - confirm against the original commit. */ + public void loadTrainingData(String path); + public void loadRawLog(String path); + public Instances getTrainingInstances(); + public void 
setPreprocessedInstances(Instances instances); + public Instances getPreprocessedInstances(); + public void savePreprocessedInstances(String path); + public void addPredictor(String shortName); + public void crossValidatePredictors(int numFold); + public void crossValidatePredictors(int numFold, long seed); + public void selectTrainingMethod(); + public void trainPredictors() throws Exception; + public void benchmark(int rounds, String filename) throws Exception; + /* String-returning accessors and save operations that take a file name; their implementations are not visible in this text. */ + public String getPredictorNames(); + public String toString(); + public void saveSettings(String filename) throws Exception; + public void saveResults(String filename) throws Exception; +} + -- cgit 1.2.3-korg