/** Copyright 2016 Huawei Technologies Co. Ltd. Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. **/ package model; import java.io.*; import java.util.*; import java.text.DateFormat; import java.text.SimpleDateFormat; import org.apache.log4j.*; import org.apache.commons.math3.stat.descriptive.moment.*; import org.apache.commons.math3.distribution.NormalDistribution; import weka.core.*; import weka.classifiers.evaluation.*; import weka.core.converters.*; import input.ARFFReader; import predictor.*; public class Model implements ModelInterface { protected Logger logger = Logger.getLogger(Model.class); protected String datapath; protected Instances trainingInstances; protected Instances preprocessedInstances; protected ArrayList predictors; protected static DateFormat dateFormat = new SimpleDateFormat("yyyyMMdd_HHmmss"); protected static Date date = new Date(); protected static String resultFilename = dateFormat.format(date); public Model(){ this.datapath = ""; this.trainingInstances = null; this.preprocessedInstances = null; this.predictors = new ArrayList(); } public Model (ModelInterface model) { this.datapath = model.getDatapath(); this.trainingInstances = new Instances(model.getTrainingInstances()); this.preprocessedInstances = new Instances(model.getPreprocessedInstances()); this.predictors = new ArrayList(model.getPredictors()); } @Override public Instances getPreprocessedInstances() { return this.preprocessedInstances; } @Override public Instances getTrainingInstances() { return this.trainingInstances; } @Override public String getDatapath() { return this.datapath; } @Override public ArrayList getPredictors() { return this.predictors; } @Override public void loadTrainingData(String path) { logger.debug("Reading training data from " + path); this.datapath = path; try { trainingInstances = ARFFReader.read(path); preprocessedInstances = trainingInstances; logger.debug("Training data is read"); logger.trace(trainingInstances); } catch (Exception e) { logger.warn(e.toString()); } } @Override public void loadRawLog(String path) { logger.debug("Reading logfile from " + path); try { BufferedReader in = new BufferedReader(new FileReader(path)); String line = in.readLine(); System.out.println(line); } catch (Exception e) { logger.warn(e.toString()); } } @Override public void setPreprocessedInstances(Instances instances){ this.preprocessedInstances=new Instances(instances); } @Override public void savePreprocessedInstances(String path) { ArffSaver saver = new ArffSaver(); saver.setInstances(this.preprocessedInstances); try { saver.setFile(new File(path)); saver.writeBatch(); } catch (Exception e) { logger.error("Cannot save preprocessed instances to file " + path); } } @Override public void addPredictor(String shortName) { for (PredictorFactory.PredictionTechnique pTechnique : PredictorFactory.PredictionTechnique.values()) { if (pTechnique.getShortName().equals(shortName)) { predictors.add(PredictorFactory.createPredictor(pTechnique)); logger.info("Added predictor: " + pTechnique.getName()); return; } } logger.warn("Added predictor: None"); logger.warn(shortName + " is not in the list."); } @Override public void selectTrainingMethod() { } @Override public void trainPredictors() throws Exception{ if (preprocessedInstances == null) { throw new Exception("No training data"); } if (predictors.size() == 0) { throw new Exception("No predictors selected"); } for (PredictorInterface p : predictors) { try { logger.info("Training " + p.getName()); p.train(preprocessedInstances); logger.debug(p.toString()); } catch (Exception e) { logger.error(e.toString()); } } } @Override public void crossValidatePredictors(int numFold) { long seed = 1; this.crossValidatePredictors(numFold, seed); } @Override public void crossValidatePredictors(int numFold, long seed) { for (PredictorInterface p : predictors) { try { logger.debug(numFold + "-fold cross-validating: " + p.getName()); Random rand = new Random(seed); p.crossValidate(preprocessedInstances, numFold, rand); ///* ThresholdCurve tc = new ThresholdCurve(); int classIndex = 1; Instances result = tc.getCurve(p.getEvaluationPredictions(), classIndex); // Save ROC BufferedWriter br = new BufferedWriter(new FileWriter(resultFilename + "_" + p.getName().replace(" ", "_") + "_ROC.arff")); br.write(result.toString()); br.close(); } catch (Exception e) { logger.error(e.toString()); } } } @Override public void benchmark(int rounds, String filename) throws Exception { ArrayList> trainingTime = new ArrayList>(this.predictors.size()); ArrayList> predictionTime = new ArrayList>(this.predictors.size()); Runtime runtime = Runtime.getRuntime(); for (int i=0; i(rounds)); predictionTime.add(new ArrayList(rounds)); } // Benchmark - using preprocessed instances long startTime; long endTime; double elapsedTime; for (int pIndex=0; pIndex