Darwin  1.10(beta)
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
drwnRandomForest.h
1 /******************************************************************************
2 ** DARWIN: A FRAMEWORK FOR MACHINE LEARNING RESEARCH AND DEVELOPMENT
3 ** Distributed under the terms of the BSD license (see the LICENSE file)
4 ** Copyright (c) 2007-2015, Stephen Gould
5 ** All rights reserved.
6 **
7 ******************************************************************************
8 ** FILENAME: drwnRandomForest.h
9 ** AUTHOR(S): Stephen Gould <stephen.gould@anu.edu.au>
10 **
11 *****************************************************************************/
12 
13 #pragma once
14 
15 #include <cstdlib>
16 #include <vector>
17 
18 #include "drwnBase.h"
19 #include "drwnClassifier.h"
20 #include "drwnDecisionTree.h"
21 
22 using namespace std;
23 
24 // drwnRandomForest ----------------------------------------------------------
48 
// Interior of the drwnRandomForest class as rendered by the Doxygen source
// listing (the leading integers are the listing's own line numbers).
// NOTE(review): listing line 49, where the member index places the
// drwnRandomForest class declaration, is missing from this extract.
50  public:
51  // default training parameters
52  static int NUM_TREES; // default number of trees in the forest (used during construction)
53  static int MAX_DEPTH; // default depth of each tree (used during construction)
54  static int MAX_FEATURES; // maximum number of features to use at each iteration (used during construction)
55 
56  protected:
57  // actual training parameters
58  int _numTrees; // number of trees to learn
59  int _maxDepth; // maximum depth of each decision tree
// NOTE(review): listing line 60 is missing here; the member index places
// `int _maxFeatures` (maximum number of features to use at each iteration)
// at drwnRandomForest.h:60.
61 
63  vector<drwnDecisionTree *> _forest; // the decision trees making up the forest (held as raw pointers)
65  vector<double> _alphas; // weight for each weak tree
66 
67  public:
// NOTE(review): listing lines 68-70 and 72-74 are not shown, so any other
// constructors/destructor declared there are not visible in this extract.
// Meaning of n and k is not shown here -- presumably feature dimension and
// number of classes; confirm against drwnClassifier.
71  drwnRandomForest(unsigned n, unsigned k = 2);
75 
76  // access functions
// returns the object type as a string
77  virtual const char *type() const { return "drwnRandomForest"; }
// returns a heap-allocated copy made with the copy constructor
78  virtual drwnRandomForest *clone() const { return new drwnRandomForest(*this); }
79 
80  // initialization
// takes the same (n, k = 2) arguments as the constructor above
81  virtual void initialize(unsigned n, unsigned k = 2);
82 
83  // i/o
// save/load operate on a drwnXMLNode; the bool result presumably signals
// success -- confirm against the implementation.
84  virtual bool save(drwnXMLNode& node) const;
85  virtual bool load(drwnXMLNode& node);
86 
87  // training
// train the parameters of the classifier from a drwnClassifierDataset object
// (meaning of the returned double is not shown in this extract)
89  virtual double train(const drwnClassifierDataset& dataset);
90 
91  // evaluation (log-probability)
// compute the unnormalized log-probability for a single feature vector;
// results are written to outputScores
93  virtual void getClassScores(const vector<double>& features,
94  vector<double>& outputScores) const;
95 };
96 
virtual double train(const drwnClassifierDataset &dataset)=0
train the parameters of the classifier from a drwnClassifierDataset object
int _maxDepth
maximum depth of each decision tree
Definition: drwnRandomForest.h:59
static int MAX_DEPTH
default depth of each tree (used during construction)
Definition: drwnRandomForest.h:53
virtual drwnRandomForest * clone() const
Returns a copy of the object; usually implemented as virtual Foo* clone() { return new Foo(*this); } ...
Definition: drwnRandomForest.h:78
vector< double > _alphas
weight for each weak tree
Definition: drwnRandomForest.h:65
Implements the interface for a generic machine learning classifier.
Definition: drwnClassifier.h:31
static int MAX_FEATURES
maximum number of features to use at each iteration (used during construction)
Definition: drwnRandomForest.h:54
int _maxFeatures
maximum number of features to use at each iteration
Definition: drwnRandomForest.h:60
int _numTrees
number of trees to learn
Definition: drwnRandomForest.h:58
Implements a cacheable dataset containing feature vectors, labels and optional weights.
Definition: drwnDataset.h:43
virtual void getClassScores(const vector< double > &features, vector< double > &outputScores) const =0
compute the unnormalized log-probability for a single feature vector
vector< drwnDecisionTree * > _forest
the forest: pointers to the decision trees in the ensemble
Definition: drwnRandomForest.h:63
static int NUM_TREES
default number of trees in the forest (used during construction)
Definition: drwnRandomForest.h:52
Implements a random forest classifier, i.e. an ensemble of decision trees. See L. Breiman, "Random Forests"...
Definition: drwnRandomForest.h:49
virtual const char * type() const
returns object type as a string (e.g., Foo::type() { return "Foo"; })
Definition: drwnRandomForest.h:77