15 #include "drwnNNGraph.h"
18 using namespace Eigen;
26 static double ALPHA_ZERO;
27 static unsigned METRIC_ITERATIONS;
28 static unsigned SEARCH_ITERATIONS;
33 vector<double> _labelWeights;
45 virtual void setTransform(
const MatrixXd& Lt) = 0;
46 virtual MatrixXd getTransform()
const = 0;
47 virtual double learn(
unsigned maxCycles);
49 const drwnNNGraph& getSrcGraph()
const {
return _graph; }
50 const drwnNNGraph& getPosGraph()
const {
return _posGraph; }
51 const drwnNNGraph& getNegGraph()
const {
return _negGraph; }
53 void clearLabelWeights() { _labelWeights.
clear(); }
54 void setLabelWeights(
const vector<double>& w) {
55 DRWN_LOG_VERBOSE(
"...setting label weights to " <<
toString(w));
58 const vector<double>& getLabelWeights()
const {
return _labelWeights; }
61 virtual double computeObjective()
const;
62 virtual double computeLossFunction()
const;
63 virtual MatrixXd computeSubGradient() = 0;
64 virtual void subGradientStep(
const MatrixXd& G,
double alpha) = 0;
66 virtual void startMetricCycle();
67 virtual void endMetricCycle();
69 void updateGraphFeatures();
70 void nearestNeighbourUpdate(
unsigned nCycle,
unsigned maxIterations);
73 MatrixXd initializeTransform()
const;
79 typedef set<pair<drwnNNGraphNodeIndex, drwnNNGraphNodeIndex> > drwnNNGraphLearnViolatedConstraints;
92 void setTransform(
const MatrixXd& Lt);
93 MatrixXd getTransform()
const;
96 MatrixXd computeSubGradient();
97 void subGradientStep(
const MatrixXd& G,
double alpha);
99 void startMetricCycle();
116 void setTransform(
const MatrixXd& Lt);
117 MatrixXd getTransform()
const;
120 MatrixXd computeSubGradient();
121 void subGradientStep(
const MatrixXd& G,
double alpha);
123 void startMetricCycle();
136 virtual double computeLossFunction()
const;
156 void setTransform(
const MatrixXd& Lt);
157 MatrixXd getTransform()
const;
160 MatrixXd computeSubGradient();
161 void subGradientStep(
const MatrixXd& G,
double alpha);
163 void startMetricCycle();
unsigned _dim
feature vector dimensions
Definition: drwnNNGraphLearn.h:38
Learn the distance metric M = LL^T as L^T.
Definition: drwnNNGraphLearn.h:105
Learn the distance metric base class with full set of constraints (i.e., loss function over all targe...
Definition: drwnNNGraphLearn.h:24
Learn the distance metric base class with sparse set of constraints (i.e., loss function over further...
Definition: drwnNNGraphLearn.h:130
std::string toString(const T &v)
Templated function to make conversion from simple data types like int and double to strings easy for ...
Definition: drwnStrUtils.h:134
MatrixXd _X
cached (x_u - x_v) (x_u - x_v)^T
Definition: drwnNNGraphLearn.h:39
Basic datatype for holding three objects of arbitrary type. Similar to the STL pair<> class...
Definition: drwnTriplet.h:33
Class for maintaining a nearest neighbour graph over superpixel images. Search moves are implemented ...
Definition: drwnNNGraph.h:309
Definition: drwnNNGraphLearn.h:145
Definition: drwnNNGraphLearn.h:81
void clear()
clear the entire graph
Definition: drwnNNGraph.h:335