diff options
author | Hai Liu <hai.liu@huawei.com> | 2016-04-20 15:34:00 +0800 |
---|---|---|
committer | Hai Liu <hai.liu@huawei.com> | 2016-04-20 15:35:01 +0800 |
commit | 9f71917d9e283195424fce4427508cc85bb36d22 (patch) | |
tree | d5842ef0d4c46584988acf35b42fae32a091a911 | |
parent | 538c8e4d65b460db537ab81159f83e1092626275 (diff) |
Add a predictor server
JIRA:PREDICTION-44
Change-Id: I97595d8857572c7e524ab459c93687af2169b8ad
-rw-r--r-- | src/experiments/RunPredictionServer.java | 183 |
1 files changed, 183 insertions, 0 deletions
diff --git a/src/experiments/RunPredictionServer.java b/src/experiments/RunPredictionServer.java new file mode 100644 index 0000000..6612a86 --- /dev/null +++ b/src/experiments/RunPredictionServer.java @@ -0,0 +1,183 @@ +package experiments; + +import java.io.*; +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.util.*; + +import weka.filters.Filter; +import model.Model; +import org.apache.log4j.Logger; +import weka.core.Instances; +import weka.core.OptionHandler; + +/** + * Created by hailiu on 2016/2/22. + */ +public class RunPredictionServer { + protected Logger logger = Logger.getLogger(RunPredictionServer.class); + + final String configFile = "data/config.txt";//input file path + String[] option = new String[50]; + String filter; + ArrayList filterList; + String DataFile; + protected static DateFormat dateFormat = new SimpleDateFormat("yyyyMMdd_HHmmss"); + protected static Date date = new Date(); + protected static String resultFilename = dateFormat.format(date)+"_results"; + Map map; + + public RunPredictionServer() { + this.filterList = new ArrayList(); + setClassName(); + } + + public static void main(String[] args) + { + RunPredictionServer runServer = new RunPredictionServer(); + //read config file + runServer.loadConfigFile(); + Model tempModel = new Model(); + tempModel.loadTrainingData(runServer.DataFile); + tempModel.addPredictor("ZEROR"); + tempModel.addPredictor("PART"); + tempModel.addPredictor("ONER"); + tempModel.addPredictor("JRIP"); + tempModel.addPredictor("IBK"); + tempModel.addPredictor("NBM"); + tempModel.addPredictor("RF"); + tempModel.addPredictor("LWL"); + tempModel.addPredictor("NBC"); + tempModel.addPredictor("BN"); + tempModel.addPredictor("REPTREE"); + tempModel.addPredictor("DT"); + tempModel.addPredictor("J48"); + tempModel.addPredictor("SMO"); + tempModel.addPredictor("MP"); + tempModel.addPredictor("SL"); + tempModel.addPredictor("LOG"); + tempModel.addPredictor("SGD"); + tempModel.addPredictor("VP"); + tempModel.addPredictor("SVM"); + tempModel.addPredictor("KSTAR"); + + Instances tempInstances=tempModel.getTrainingInstances(); + + Iterator ite=runServer.filterList.iterator(); + + while (ite.hasNext()) + { + ArrayList filterOption=(ArrayList) ite.next(); + Iterator tempIte = filterOption.iterator(); + String tFilter=null; + if (tempIte.hasNext()) + { + tFilter=(String) tempIte.next(); + } + String[] option = new String[filterOption.size()-1]; + int i=0; + while (tempIte.hasNext()) + { + option[i]=(String) tempIte.next(); + i++; + } + tempInstances = runServer.addFilter(tempInstances,tFilter,option); + } + + //tempInstances.setClassIndex(tempInstances.numAttributes()-2); + tempModel.setPreprocessedInstances(tempInstances); + tempModel.savePreprocessedInstances("preprocessed.arff"); + + try { + tempModel.benchmark(2,resultFilename); + + }catch (Exception e) + { + runServer.logger.warn(e.toString()); + } + + tempModel.crossValidatePredictors(10); + try { + tempModel.saveResults(resultFilename); + }catch (Exception e) + { + runServer.logger.warn(e.toString()); + } + } + + //read config file + private void loadConfigFile() { + logger.debug("Reading configure from " + configFile); + try { + BufferedReader in = new BufferedReader(new FileReader(configFile)); + DataFile = in.readLine(); + System.out.println(DataFile); + + String tempStr=null; + while((tempStr=in.readLine())!=null) + { + String[] tempString = tempStr.split(" "); + ArrayList<String> filterOption = new ArrayList<>(); + for(String ts:tempString) + { + filterOption.add(ts); + } + this.filterList.add(filterOption); + } + } catch (Exception e) { + logger.warn(e.toString()); + } + } + + private Instances addFilter(Instances tempInstances,String filter, String[] option) + { + Class<?> tempClass=null; + try { + tempClass = Class.forName(this.map.get(filter).toString()); + }catch (ClassNotFoundException e) + { + logger.warn(e.toString()); + } + + if(tempClass==null) return null; + Object tempFilter=null; + try { + tempFilter = tempClass.newInstance(); + }catch (InstantiationException e) + { + logger.warn(e.toString()); + } + catch (IllegalAccessException e) + { + logger.warn(e.toString()); + } + + try { + ((OptionHandler)tempFilter).setOptions(option); + ((Filter)tempFilter).setInputFormat(tempInstances); + }catch (Exception e) + { + logger.warn(e.toString()); + } + + Instances newInstances=null; + try { + newInstances = Filter.useFilter(tempInstances, (Filter) (tempFilter)); + }catch (Exception e) + { + logger.warn(e.toString()); + } + return newInstances; + } + + private void setClassName() + { + map=new HashMap(); + map.put("addexpression","weka.filters.unsupervised.attribute.AddExpression"); + map.put("remove","weka.filters.unsupervised.attribute.Remove"); + map.put("classassigner","weka.filters.unsupervised.attribute.ClassAssigner"); + map.put("numerictonominal","weka.filters.unsupervised.attribute.NumericToNominal"); + map.put("mathexpression","weka.filters.unsupervised.attribute.MathExpression"); + } +} + |