Darwin  1.10(beta)
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
drwnDecisionTree.h
1 /******************************************************************************
2 ** DARWIN: A FRAMEWORK FOR MACHINE LEARNING RESEARCH AND DEVELOPMENT
3 ** Distributed under the terms of the BSD license (see the LICENSE file)
4 ** Copyright (c) 2007-2015, Stephen Gould
5 ** All rights reserved.
6 **
7 ******************************************************************************
8 ** FILENAME: drwnDecisionTree.h
9 ** AUTHOR(S): Stephen Gould <stephen.gould@anu.edu.au>
10 **
11 *****************************************************************************/
12 
13 #pragma once
14 
15 #include <cstdlib>
16 #include <vector>
17 
18 #include "drwnBase.h"
19 #include "drwnClassifier.h"
20 
21 using namespace std;
22 
23 // drwnTreeSplitCriterion ----------------------------------------------------
24 
25 typedef enum _drwnTreeSplitCriterion { // type of the static drwnDecisionTree::SPLIT_CRITERION setting (tree split criterion used during learning)
26  DRWN_DT_SPLIT_ENTROPY, DRWN_DT_SPLIT_MISCLASS, DRWN_DT_SPLIT_GINI // NOTE(review): presumably entropy, misclassification-rate and Gini-impurity measures -- confirm against drwnDecisionTree.cpp
27 } drwnTreeSplitCriterion;
28 
29 // drwnDecisionTree ---------------------------------------------------------
59 
61  public:
62  friend class drwnDecisionTreeThread;
63  friend class drwnDecisionTreeConfig;
64  friend class drwnBoostedClassifier;
65  friend class drwnRandomForest;
66 
67  public:
68  static int MAX_DEPTH;
70  static int MIN_SAMPLES;
71  static double LEAKAGE;
72  static drwnTreeSplitCriterion SPLIT_CRITERION;
73  static bool CACHE_SORTED_INDEXES;
74 
75  protected:
76  int _splitIndx;
77  double _splitValue;
80  Eigen::VectorXd _scores;
82 
83  int _maxDepth;
84 
85  public:
89  drwnDecisionTree(unsigned n, unsigned k = 2);
93 
94  // access functions
95  virtual const char *type() const { return "drwnDecisionTree"; }
96  virtual drwnDecisionTree *clone() const { return new drwnDecisionTree(*this); }
97 
98  // initialization
99  virtual void initialize(unsigned n, unsigned k = 2);
100 
101  // i/o
102  virtual bool save(drwnXMLNode& node) const;
103  virtual bool load(drwnXMLNode& node);
104 
105  // training
106  using drwnClassifier::train;
107  virtual double train(const drwnClassifierDataset& dataset);
108  virtual double train(const vector<vector<double> >& features,
109  const vector<int>& targets);
110  virtual double train(const vector<vector<double> >& features,
111  const vector<int>& targets, const vector<double>& weights);
112 
113  // evaluation (log-probability)
115  virtual void getClassScores(const vector<double>& features,
116  vector<double>& outputScores) const;
117 
118  // evaluation (classification)
120  virtual int getClassification(const vector<double>& features) const;
121 
122  protected:
123  // evaluation
124  const Eigen::VectorXd &evaluate(const Eigen::VectorXd& x) const;
125 
126  // training
127  static void computeSortedFeatureIndex(const vector<vector<double> >& x,
128  const drwnBitArray& sampleIndex, int featureIndx, vector<int>& featureSortIndex);
129  static void computeSortedFeatureIndex(const vector<vector<double> >& x,
130  vector<vector<int> >& sortIndex);
131  void learnDecisionTree(const vector<vector<double> >& x, const vector<int>& y,
132  const vector<double>& w, const vector<vector<int> >& sortIndex,
133  const drwnBitArray& sampleIndex);
134 };
135 
virtual double train(const drwnClassifierDataset &dataset)=0
train the parameters of the classifier from a drwnClassifierDataset object
int _splitIndx
variable index on which to split
Definition: drwnDecisionTree.h:76
virtual void initialize(unsigned n, unsigned k=2)
initialize the classifier object for n features and k classes
Definition: drwnRandomForest.cpp:75
Definition: drwnDecisionTree.cpp:25
virtual void getClassScores(const vector< double > &features, vector< double > &outputScores) const
compute the unnormalized log-probability for a single feature vector
Definition: drwnRandomForest.cpp:179
static bool CACHE_SORTED_INDEXES
pre-cache indexes of sorted features
Definition: drwnDecisionTree.h:73
int _maxDepth
maximum depth of decision tree
Definition: drwnDecisionTree.h:83
int _predictedClass
argmax of _scores
Definition: drwnDecisionTree.h:81
virtual drwnDecisionTree * clone() const
returns a copy of the object; usually implemented as virtual Foo* clone() { return new Foo(*this); } ...
Definition: drwnDecisionTree.h:96
static double LEAKAGE
probability that a training sample leaks to both splits
Definition: drwnDecisionTree.h:71
virtual int getClassification(const vector< double > &features) const
return the most likely class label for a single feature vector
Definition: drwnClassifier.cpp:150
virtual double train(const drwnClassifierDataset &dataset)
train the parameters of the classifier from a drwnClassifierDataset object
Definition: drwnRandomForest.cpp:113
Implements a multi-class boosted decision-tree classifier. See Zhu et al., Multi-class AdaBoost...
Definition: drwnBoostedClassifier.h:61
Implements the interface for a generic machine learning classifier.
Definition: drwnClassifier.h:31
static int MIN_SAMPLES
minimum number of samples (after first split)
Definition: drwnDecisionTree.h:70
Eigen::VectorXd _scores
log-marginal for each class at this node
Definition: drwnDecisionTree.h:80
drwnDecisionTree * _rightChild
right child (or NULL)
Definition: drwnDecisionTree.h:79
Implements an efficient packed array of bits.
Definition: drwnBitArray.h:42
Definition: drwnDecisionTree.cpp:571
drwnDecisionTree * _leftChild
left child (or NULL)
Definition: drwnDecisionTree.h:78
static int MAX_DEPTH
default maximum tree depth
Definition: drwnDecisionTree.h:68
virtual bool save(drwnXMLNode &node) const
write object to XML node (see also write)
Definition: drwnRandomForest.cpp:86
static int MAX_FEATURE_THRESHOLDS
maximum number of thresholds to try during learning
Definition: drwnDecisionTree.h:69
virtual const char * type() const
returns object type as a string (e.g., Foo::type() { return "Foo"; })
Definition: drwnDecisionTree.h:95
Implements a cacheable dataset containing feature vectors, labels and optional weights.
Definition: drwnDataset.h:43
virtual void getClassScores(const vector< double > &features, vector< double > &outputScores) const =0
compute the unnormalized log-probability for a single feature vector
static drwnTreeSplitCriterion SPLIT_CRITERION
tree split criterion used during learning
Definition: drwnDecisionTree.h:72
Implements a (binary-split) decision tree classifier of arbitrary depth.
Definition: drwnDecisionTree.h:60
virtual bool load(drwnXMLNode &node)
read object from XML node (see also read)
Definition: drwnRandomForest.cpp:99
double _splitValue
split value (go left if less than)
Definition: drwnDecisionTree.h:77
Implements a random forest classifier (an ensemble of decision trees). See L. Breiman, "Random Forests"...
Definition: drwnRandomForest.h:49