Darwin  1.10(beta)
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
drwnNNGraphLearn.h
1 /*****************************************************************************
2 ** DARWIN: A FRAMEWORK FOR MACHINE LEARNING RESEARCH AND DEVELOPMENT
3 ** Distributed under the terms of the BSD license (see the LICENSE file)
4 ** Copyright (c) 2007-2015, Stephen Gould
5 ** All rights reserved.
6 **
7 ******************************************************************************
8 ** FILENAME: drwnNNGraphLearn.h
9 ** AUTHOR(S): Stephen Gould <stephen.gould@anu.edu.au>
10 **
11 *****************************************************************************/
12 
13 #pragma once
14 
15 #include "drwnNNGraph.h"
16 
17 using namespace std;
18 using namespace Eigen;
19 
20 // drwnNNGraphLearner --------------------------------------------------------
23 
25  public:
26  static double ALPHA_ZERO;
27  static unsigned METRIC_ITERATIONS;
28  static unsigned SEARCH_ITERATIONS;
29 
30  protected:
31  const drwnNNGraph& _graph;
32  double _lambda;
33  vector<double> _labelWeights;
34 
35  drwnNNGraph _posGraph;
36  drwnNNGraph _negGraph;
37 
38  unsigned _dim;
39  MatrixXd _X;
40 
41  public:
42  drwnNNGraphLearner(const drwnNNGraph& graph, double lambda);
43  virtual ~drwnNNGraphLearner();
44 
45  virtual void setTransform(const MatrixXd& Lt) = 0;
46  virtual MatrixXd getTransform() const = 0;
47  virtual double learn(unsigned maxCycles);
48 
49  const drwnNNGraph& getSrcGraph() const { return _graph; }
50  const drwnNNGraph& getPosGraph() const { return _posGraph; }
51  const drwnNNGraph& getNegGraph() const { return _negGraph; }
52 
53  void clearLabelWeights() { _labelWeights.clear(); }
54  void setLabelWeights(const vector<double>& w) {
55  DRWN_LOG_VERBOSE("...setting label weights to " << toString(w));
56  _labelWeights = w;
57  };
58  const vector<double>& getLabelWeights() const { return _labelWeights; }
59 
60  protected:
61  virtual double computeObjective() const;
62  virtual double computeLossFunction() const;
63  virtual MatrixXd computeSubGradient() = 0;
64  virtual void subGradientStep(const MatrixXd& G, double alpha) = 0;
65 
66  virtual void startMetricCycle();
67  virtual void endMetricCycle();
68 
69  void updateGraphFeatures();
70  void nearestNeighbourUpdate(unsigned nCycle, unsigned maxIterations);
71 
73  MatrixXd initializeTransform() const;
74 };
75 
76 // drwnNNGraphMLearner -------------------------------------------------------
78 
79 typedef set<pair<drwnNNGraphNodeIndex, drwnNNGraphNodeIndex> > drwnNNGraphLearnViolatedConstraints;
80 
82  protected:
83  MatrixXd _M;
84 
85  MatrixXd _G;
87 
88  public:
89  drwnNNGraphMLearner(const drwnNNGraph& graph, double lambda);
91 
92  void setTransform(const MatrixXd& Lt);
93  MatrixXd getTransform() const;
94 
95  protected:
96  MatrixXd computeSubGradient();
97  void subGradientStep(const MatrixXd& G, double alpha);
98 
99  void startMetricCycle();
100 };
101 
102 // drwnNNGraphLLearner -------------------------------------------------------
104 
106  protected:
107  MatrixXd _Lt;
108 
109  MatrixXd _G;
111 
112  public:
113  drwnNNGraphLLearner(const drwnNNGraph& graph, double lambda);
115 
116  void setTransform(const MatrixXd& Lt);
117  MatrixXd getTransform() const;
118 
119  protected:
120  MatrixXd computeSubGradient();
121  void subGradientStep(const MatrixXd& G, double alpha);
122 
123  void startMetricCycle();
124 };
125 
126 // drwnNNGraphSparseLearner --------------------------------------------------
129 
131  public:
132  drwnNNGraphSparseLearner(const drwnNNGraph& graph, double lambda);
134 
135  protected:
136  virtual double computeLossFunction() const;
137 };
138 
139 // drwnNNGraphLSparseLearner -------------------------------------------------
141 
144 
146  protected:
147  MatrixXd _Lt;
148 
149  MatrixXd _G;
151 
152  public:
153  drwnNNGraphLSparseLearner(const drwnNNGraph& graph, double lambda);
155 
156  void setTransform(const MatrixXd& Lt);
157  MatrixXd getTransform() const;
158 
159  protected:
160  MatrixXd computeSubGradient();
161  void subGradientStep(const MatrixXd& G, double alpha);
162 
163  void startMetricCycle();
164 };
unsigned _dim
feature vector dimensions
Definition: drwnNNGraphLearn.h:38
Learn the distance metric M = LL^T as L^T.
Definition: drwnNNGraphLearn.h:105
Learn the distance metric base class with full set of constraints (i.e., loss function over all targe...
Definition: drwnNNGraphLearn.h:24
Learn the distance metric base class with sparse set of constraints (i.e., loss function over further...
Definition: drwnNNGraphLearn.h:130
std::string toString(const T &v)
Templated function to make conversion from simple data types like int and double to strings easy for ...
Definition: drwnStrUtils.h:134
MatrixXd _X
cached (x_u - x_v) (x_u - x_v)^T
Definition: drwnNNGraphLearn.h:39
Basic datatype for holding three objects of arbitrary type. Similar to the STL pair<> class...
Definition: drwnTriplet.h:33
Class for maintaining a nearest neighbour graph over superpixel images. Search moves are implemented ...
Definition: drwnNNGraph.h:309
Definition: drwnNNGraphLearn.h:145
Definition: drwnNNGraphLearn.h:81
void clear()
clear the entire graph
Definition: drwnNNGraph.h:335