/*
 * Model -- the ModelInterface implementation in package "model".
 *
 * State held per instance:
 *   - datapath: path last passed to loadTrainingData ("" until then).
 *   - trainingInstances: the Instances returned by ARFFReader.read(path)
 *     (presumably weka.core.Instances, given the imports -- confirm).
 *   - preprocessedInstances: the data actually handed to the predictors for
 *     training and cross-validation.
 *   - predictors: the list filled by addPredictor.
 * Static state: resultFilename is the class-initialisation time formatted as
 * "yyyyMMdd_HHmmss"; it is computed once and used as the prefix of every ROC
 * output file written by crossValidatePredictors.
 *
 * NOTE(review): in this copy of the file the class sits on a few very long
 * physical lines, so every "//" comment runs to the end of its physical line
 * and one string literal is split across a line break. Generic type arguments
 * also appear to be missing (see the list declarations in benchmark), and the
 * text stops in the middle of benchmark. It looks like the original line
 * structure was lost -- restore from version control before compiling
 * (TODO confirm).
 *
 * Members, in source order:
 *   Model()
 *       Empty datapath, null instance sets, empty predictor list.
 *   Model(ModelInterface model)
 *       Copies the datapath, wraps both instance sets of the argument in new
 *       Instances objects and puts its predictors into a new ArrayList (the
 *       predictor objects themselves are shared with the argument).
 *       NOTE(review): no null check -- behaviour when the argument has not
 *       loaded any data yet depends on the Instances constructor; verify.
 *   getPreprocessedInstances / getTrainingInstances / getDatapath / getPredictors
 *       Return the corresponding field directly (no copy is made).
 *   loadTrainingData(String path)
 *       Stores the path, reads it with ARFFReader.read and assigns the result
 *       to trainingInstances; preprocessedInstances is then set to the SAME
 *       reference, not a copy. Any exception is logged at WARN and the method
 *       returns normally; datapath has already been overwritten at that point.
 *   loadRawLog(String path)
 *       Opens the file, reads only its first line and prints it to System.out.
 *       Exceptions are logged at WARN.
 *       NOTE(review): the BufferedReader is never closed.
 *   setPreprocessedInstances(Instances instances)
 *       Stores new Instances(instances) rather than the argument itself.
 *   savePreprocessedInstances(String path)
 *       Writes preprocessedInstances to the given file through an ArffSaver.
 *       On failure only a fixed ERROR message with the path is logged; the
 *       exception itself is not included in the log.
 *   addPredictor(String shortName)
 *       Scans PredictorFactory.PredictionTechnique.values() for the first entry
 *       whose short name equals the argument, adds a predictor created by
 *       PredictorFactory for it and returns. If nothing matches, two WARN
 *       messages are logged and the list is left unchanged.
 *   selectTrainingMethod()
 *       Empty body.
 *   trainPredictors()
 *       Throws Exception("No training data") when preprocessedInstances is
 *       null and Exception("No predictors selected") when the list is empty.
 *       Otherwise trains every predictor on preprocessedInstances; a failure
 *       of one predictor is logged at ERROR and the loop continues.
 *   crossValidatePredictors(int numFold)
 *       Delegates to the two-argument overload with seed 1.
 *   crossValidatePredictors(int numFold, long seed)
 *       For each predictor: creates a fresh Random(seed), calls
 *       p.crossValidate(preprocessedInstances, numFold, rand), builds a
 *       ThresholdCurve from p.getEvaluationPredictions() with class index 1
 *       and writes it to
 *       resultFilename + "_" + <predictor name, spaces replaced by "_"> + "_ROC.arff".
 *       Exceptions are logged at ERROR per predictor.
 *       NOTE(review): unlike trainPredictors there is no null/empty check, and
 *       the writer is closed only on the success path.
 *   benchmark(int rounds, String filename)
 *       Declares throws Exception and begins by creating the trainingTime and
 *       predictionTime lists, sized by the number of predictors. The rest of
 *       the method is not visible in this copy, so nothing more is documented.
 */
package model; import java.io.*; import java.util.*; import java.text.DateFormat; import java.text.SimpleDateFormat; import org.apache.log4j.*; import org.apache.commons.math3.stat.descriptive.moment.*; import org.apache.commons.math3.distribution.NormalDistribution; import weka.core.*; import weka.classifiers.evaluation.*; import weka.core.converters.*; import input.ARFFReader; import predictor.*; public class Model implements ModelInterface { protected Logger logger = Logger.getLogger(Model.class); protected String datapath; protected Instances trainingInstances; protected Instances preprocessedInstances; protected ArrayList predictors; protected static DateFormat dateFormat = new SimpleDateFormat("yyyyMMdd_HHmmss"); protected static Date date = new Date(); protected static String resultFilename = dateFormat.format(date); public Model(){ this.datapath = ""; this.trainingInstances = null; this.preprocessedInstances = null; this.predictors = new ArrayList(); } public Model (ModelInterface model) { this.datapath = model.getDatapath(); this.trainingInstances = new Instances(model.getTrainingInstances()); this.preprocessedInstances = new Instances(model.getPreprocessedInstances()); this.predictors = new ArrayList(model.getPredictors()); } @Override public Instances getPreprocessedInstances() { return this.preprocessedInstances; } @Override public Instances getTrainingInstances() { return this.trainingInstances; } @Override public String getDatapath() { return this.datapath; } @Override public ArrayList getPredictors() { return this.predictors; } @Override public void loadTrainingData(String path) { logger.debug("Reading training data from " + path); this.datapath = path; try { trainingInstances = ARFFReader.read(path); preprocessedInstances = trainingInstances; logger.debug("Training data is read"); logger.trace(trainingInstances); } catch (Exception e) { logger.warn(e.toString()); } } @Override public void loadRawLog(String path) { logger.debug("Reading logfile from 
" + path); try { BufferedReader in = new BufferedReader(new FileReader(path)); String line = in.readLine(); System.out.println(line); } catch (Exception e) { logger.warn(e.toString()); } } @Override public void setPreprocessedInstances(Instances instances){ this.preprocessedInstances=new Instances(instances); } @Override public void savePreprocessedInstances(String path) { ArffSaver saver = new ArffSaver(); saver.setInstances(this.preprocessedInstances); try { saver.setFile(new File(path)); saver.writeBatch(); } catch (Exception e) { logger.error("Cannot save preprocessed instances to file " + path); } } @Override public void addPredictor(String shortName) { for (PredictorFactory.PredictionTechnique pTechnique : PredictorFactory.PredictionTechnique.values()) { if (pTechnique.getShortName().equals(shortName)) { predictors.add(PredictorFactory.createPredictor(pTechnique)); logger.info("Added predictor: " + pTechnique.getName()); return; } } logger.warn("Added predictor: None"); logger.warn(shortName + " is not in the list."); } @Override public void selectTrainingMethod() { } @Override public void trainPredictors() throws Exception{ if (preprocessedInstances == null) { throw new Exception("No training data"); } if (predictors.size() == 0) { throw new Exception("No predictors selected"); } for (PredictorInterface p : predictors) { try { logger.info("Training " + p.getName()); p.train(preprocessedInstances); logger.debug(p.toString()); } catch (Exception e) { logger.error(e.toString()); } } } @Override public void crossValidatePredictors(int numFold) { long seed = 1; this.crossValidatePredictors(numFold, seed); } @Override public void crossValidatePredictors(int numFold, long seed) { for (PredictorInterface p : predictors) { try { logger.debug(numFold + "-fold cross-validating: " + p.getName()); Random rand = new Random(seed); p.crossValidate(preprocessedInstances, numFold, rand); ///* ThresholdCurve tc = new ThresholdCurve(); int classIndex = 1; Instances result = 
tc.getCurve(p.getEvaluationPredictions(), classIndex); // Save ROC BufferedWriter br = new BufferedWriter(new FileWriter(resultFilename + "_" + p.getName().replace(" ", "_") + "_ROC.arff")); br.write(result.toString()); br.close(); } catch (Exception e) { logger.error(e.toString()); } } } @Override public void benchmark(int rounds, String filename) throws Exception { ArrayList> trainingTime = new ArrayList>(this.predictors.size()); ArrayList> predictionTime = new ArrayList>(this.predictors.size()); Runtime runtime = Runtime.getRuntime(); for (int i=0; i(rounds)); predictionTime.add(new ArrayList(rounds)); } // Benchmark - using preprocessed instances long startTime; long endTime; double elapsedTime; for (int pIndex=0; pIndex