19 #include "drwnClassifier.h"
//! Impurity criterion used to score candidate splits while learning a
//! decision tree (selected via drwnDecisionTree::SPLIT_CRITERION).
typedef enum _drwnTreeSplitCriterion {
    DRWN_DT_SPLIT_ENTROPY,   //!< entropy-based impurity
    DRWN_DT_SPLIT_MISCLASS,  //!< misclassification-rate impurity
    DRWN_DT_SPLIT_GINI       //!< Gini impurity
} drwnTreeSplitCriterion;
95 virtual const char *
type()
const {
return "drwnDecisionTree"; }
99 virtual void initialize(
unsigned n,
unsigned k = 2);
102 virtual bool save(drwnXMLNode& node)
const;
103 virtual bool load(drwnXMLNode& node);
108 virtual double train(
const vector<vector<double> >& features,
109 const vector<int>& targets);
110 virtual double train(
const vector<vector<double> >& features,
111 const vector<int>& targets,
const vector<double>& weights);
116 vector<double>& outputScores)
const;
124 const Eigen::VectorXd &evaluate(
const Eigen::VectorXd& x)
const;
127 static void computeSortedFeatureIndex(
const vector<vector<double> >& x,
128 const drwnBitArray& sampleIndex,
int featureIndx, vector<int>& featureSortIndex);
129 static void computeSortedFeatureIndex(
const vector<vector<double> >& x,
130 vector<vector<int> >& sortIndex);
131 void learnDecisionTree(
const vector<vector<double> >& x,
const vector<int>& y,
132 const vector<double>& w,
const vector<vector<int> >& sortIndex,
virtual double train(const drwnClassifierDataset &dataset)=0
train the parameters of the classifier from a drwnClassifierDataset object
int _splitIndx
variable index on which to split
Definition: drwnDecisionTree.h:76
virtual void initialize(unsigned n, unsigned k=2)
initialize the classifier object for n features and k classes
Definition: drwnRandomForest.cpp:75
Definition: drwnDecisionTree.cpp:25
virtual void getClassScores(const vector< double > &features, vector< double > &outputScores) const
compute the unnormalized log-probability for a single feature vector
Definition: drwnRandomForest.cpp:179
static bool CACHE_SORTED_INDEXES
pre-cache indexes of sorted features
Definition: drwnDecisionTree.h:73
int _maxDepth
maximum depth of decision tree
Definition: drwnDecisionTree.h:83
int _predictedClass
argmax of _scores
Definition: drwnDecisionTree.h:81
virtual drwnDecisionTree * clone() const
returns a copy of the object; usually implemented as virtual Foo* clone() { return new Foo(*this); } ...
Definition: drwnDecisionTree.h:96
static double LEAKAGE
probability that a training sample leaks to both splits
Definition: drwnDecisionTree.h:71
virtual int getClassification(const vector< double > &features) const
return the most likely class label for a single feature vector
Definition: drwnClassifier.cpp:150
virtual double train(const drwnClassifierDataset &dataset)
train the parameters of the classifier from a drwnClassifierDataset object
Definition: drwnRandomForest.cpp:113
Implements a multi-class boosted decision-tree classifier. See Zhu et al., Multi-class AdaBoost...
Definition: drwnBoostedClassifier.h:61
Implements the interface for a generic machine learning classifier.
Definition: drwnClassifier.h:31
static int MIN_SAMPLES
minimum number of samples (after first split)
Definition: drwnDecisionTree.h:70
Eigen::VectorXd _scores
log-marginal for each class at this node
Definition: drwnDecisionTree.h:80
drwnDecisionTree * _rightChild
right child (or NULL)
Definition: drwnDecisionTree.h:79
Implements an efficient packed array of bits.
Definition: drwnBitArray.h:42
Definition: drwnDecisionTree.cpp:571
drwnDecisionTree * _leftChild
left child (or NULL)
Definition: drwnDecisionTree.h:78
static int MAX_DEPTH
default maximum tree depth
Definition: drwnDecisionTree.h:68
virtual bool save(drwnXMLNode &node) const
write object to XML node (see also write)
Definition: drwnRandomForest.cpp:86
static int MAX_FEATURE_THRESHOLDS
maximum number of thresholds to try during learning
Definition: drwnDecisionTree.h:69
virtual const char * type() const
returns object type as a string (e.g., Foo::type() { return "Foo"; })
Definition: drwnDecisionTree.h:95
Implements a cacheable dataset containing feature vectors, labels and optional weights.
Definition: drwnDataset.h:43
virtual void getClassScores(const vector< double > &features, vector< double > &outputScores) const =0
compute the unnormalized log-probability for a single feature vector
static drwnTreeSplitCriterion SPLIT_CRITERION
tree split criteria during learning
Definition: drwnDecisionTree.h:72
Implements a (binary-split) decision tree classifier of arbitrary depth.
Definition: drwnDecisionTree.h:60
virtual bool load(drwnXMLNode &node)
read object from XML node (see also read)
Definition: drwnRandomForest.cpp:99
double _splitValue
split value (go left if less than)
Definition: drwnDecisionTree.h:77
Implements a random forest (ensemble of decision trees) classifier. See L. Breiman, "Random Forests"...
Definition: drwnRandomForest.h:49