Darwin 1.10 (beta)
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
drwnBoostedClassifier.h
1 /******************************************************************************
2 ** DARWIN: A FRAMEWORK FOR MACHINE LEARNING RESEARCH AND DEVELOPMENT
3 ** Distributed under the terms of the BSD license (see the LICENSE file)
4 ** Copyright (c) 2007-2015, Stephen Gould
5 ** All rights reserved.
6 **
7 ******************************************************************************
8 ** FILENAME: drwnBoostedClassifier.h
9 ** AUTHOR(S): Stephen Gould <stephen.gould@anu.edu.au>
10 **
11 *****************************************************************************/
12 
13 #pragma once
14 
15 #include <cstdlib>
16 #include <vector>
17 
18 #include "drwnBase.h"
19 #include "drwnClassifier.h"
20 #include "drwnDecisionTree.h"
21 
22 using namespace std;
23 
24 // drwnBoostingMethod --------------------------------------------------------
25 
//! Selects the boosting variant used by drwnBoostedClassifier; it controls how
//! data samples are re-weighted at the end of each training iteration.
//! NOTE(review): the enumerator names presumably correspond to the Discrete,
//! Gentle and Real AdaBoost variants -- confirm against the implementation.
enum _drwnBoostingMethod {
    DRWN_BOOST_DISCRETE,  // first enumerator (value 0)
    DRWN_BOOST_GENTLE,    // value 1
    DRWN_BOOST_REAL       // value 2
};
typedef enum _drwnBoostingMethod drwnBoostingMethod;
29 
30 // drwnBoostedClassifier -----------------------------------------------------
60 
62  public:
63  // default training parameters
// NOTE(review): these statics presumably seed the per-instance _method/_numRounds/
// _maxDepth/_shrinkage members below when a classifier is constructed -- confirm in the .cpp.
65  static drwnBoostingMethod METHOD; //!< controls the re-weighting of data samples at the end of each training iteration
66  static int NUM_ROUNDS; //!< maximum number of boosting rounds
67  static int MAX_DEPTH; //!< maximum depth of each decision tree
68  static double SHRINKAGE; //!< boosting shrinkage
69 
70  protected:
71  // actual training parameters
72  drwnBoostingMethod _method; //!< boosting method
73  int _numRounds; //!< number of rounds of boosting
74  int _maxDepth; //!< maximum depth of each decision tree
75  double _shrinkage; //!< boosting shrinkage
76 
// NOTE(review): raw pointers -- ownership/deletion of the trees is handled by code not visible here.
78  vector<drwnDecisionTree *> _weakLearners; //!< weak learners
80  vector<double> _alphas; //!< weight for each weak learner (one entry per element of _weakLearners)
81 
82  public:
// NOTE(review): n and k presumably are the feature dimension and number of classes
// (k defaults to 2, i.e. binary) -- confirm against drwnClassifier.
86  drwnBoostedClassifier(unsigned n, unsigned k = 2); //!< construct a boosted classifier for n features and k classes (see note above)
90 
91  // access functions
92  virtual const char *type() const { return "drwnBoostedClassifier"; }
93  virtual drwnBoostedClassifier *clone() const { return new drwnBoostedClassifier(*this); }
94 
95  // initialization
// NOTE(review): arguments presumably mirror the (n, k) constructor -- confirm.
96  virtual void initialize(unsigned n, unsigned k = 2);
97 
98  // i/o
// NOTE(review): the bool result presumably reports success -- confirm in the .cpp.
99  virtual bool save(drwnXMLNode& node) const; //!< write this classifier to an XML node
100  virtual bool load(drwnXMLNode& node); //!< read this classifier from an XML node
101 
102  // training
// Re-expose the base class's other train() overloads; without this, declaring
// train(const drwnClassifierDataset&) below would hide them.
103  using drwnClassifier::train;
104  virtual double train(const drwnClassifierDataset& dataset); //!< train the parameters of the classifier from a drwnClassifierDataset object
105 
// NOTE(review): presumably truncates the model to its first numRounds weak learners -- confirm.
110  void pruneRounds(unsigned numRounds);
111 
112  // evaluation (log-probability)
// Computes the unnormalized log-probability for a single feature vector and
// writes the per-class results into outputScores.
114  virtual void getClassScores(const vector<double>& features,
115  vector<double>& outputScores) const;
116 
117  protected:
118 
119 };
virtual double train(const drwnClassifierDataset &dataset)=0
train the parameters of the classifier from a drwnClassifierDataset object
int _numRounds
number of rounds of boosting
Definition: drwnBoostedClassifier.h:73
vector< drwnDecisionTree * > _weakLearners
weak learners
Definition: drwnBoostedClassifier.h:78
static double SHRINKAGE
boosting shrinkage
Definition: drwnBoostedClassifier.h:68
vector< double > _alphas
weight for each weak learner
Definition: drwnBoostedClassifier.h:80
virtual drwnBoostedClassifier * clone() const
returns a copy of the object (usually implemented as virtual Foo* clone() { return new Foo(*this); }) ...
Definition: drwnBoostedClassifier.h:93
static int NUM_ROUNDS
maximum number of boosting rounds
Definition: drwnBoostedClassifier.h:66
static drwnBoostingMethod METHOD
controls the re-weighting of data samples at the end of each training iteration
Definition: drwnBoostedClassifier.h:65
Implements a multi-class boosted decision-tree classifier. See Zhu et al., Multi-class AdaBoost...
Definition: drwnBoostedClassifier.h:61
double _shrinkage
boosting shrinkage
Definition: drwnBoostedClassifier.h:75
Implements the interface for a generic machine learning classifier.
Definition: drwnClassifier.h:31
drwnBoostingMethod _method
boosting method
Definition: drwnBoostedClassifier.h:72
virtual const char * type() const
returns object type as a string (e.g., Foo::type() { return "Foo"; })
Definition: drwnBoostedClassifier.h:92
Implements a cacheable dataset containing feature vectors, labels and optional weights.
Definition: drwnDataset.h:43
virtual void getClassScores(const vector< double > &features, vector< double > &outputScores) const =0
compute the unnormalized log-probability for a single feature vector
int _maxDepth
maximum depth of each decision tree
Definition: drwnBoostedClassifier.h:74
static int MAX_DEPTH
maximum depth of each decision tree
Definition: drwnBoostedClassifier.h:67