Darwin 1.10 (beta)
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
drwnInference.h
1 /*****************************************************************************
2 ** DARWIN: A FRAMEWORK FOR MACHINE LEARNING RESEARCH AND DEVELOPMENT
3 ** Distributed under the terms of the BSD license (see the LICENSE file)
4 ** Copyright (c) 2007-2015, Stephen Gould
5 ** All rights reserved.
6 **
7 ******************************************************************************
8 ** FILENAME: drwnInference.h
9 ** AUTHOR(S): Stephen Gould <stephen.gould@anu.edu.au>
10 **
11 *****************************************************************************/
12 
13 #pragma once
14 
15 #include <vector>
16 #include <set>
17 #include <map>
18 
19 #include "drwnBase.h"
20 #include "drwnVarUniverse.h"
21 #include "drwnVarAssignment.h"
22 #include "drwnFactorGraph.h"
23 #include "drwnTableFactorOps.h"
24 
25 // drwnInference class -----------------------------------------------------
36 
// Members of drwnInference: the interface for various (marginal) inference
// algorithms. The class header line itself is not part of this listing.
38  protected:
//! factor graph (initial clique potentials) that inference runs over; it is
//! held by const reference, so the graph must outlive this inference object.
//! Because this is a reference member, the implicit copy-assignment operator
//! is deleted: instances can be copy-constructed but not copy-assigned.
39  const drwnFactorGraph& _graph;
40 
41  public:
//! construct an inference object for the given factor graph
42  drwnInference(const drwnFactorGraph& graph);
//! copy constructor
43  drwnInference(const drwnInference& inf);
44  virtual ~drwnInference();
45 
//! clear internally cached data (e.g., computation graph); the base-class
//! implementation does nothing
47  virtual void clear() { /* do nothing */ };
//! run inference (or resume for iterative algorithms) and return true if
//! converged
50  virtual bool inference() = 0;
//! fill `belief` with the belief over the variables in the given factor;
//! the factor must be one of the cliques in the original factor graph
53  virtual void marginal(drwnTableFactor& belief) const = 0;
//! returns marginals for each variable in the factor graph's universe
55  virtual drwnFactorGraph varMarginals() const;
56 
58  inline drwnTableFactor operator[](int varIndx) const {
59  drwnTableFactor factor(_graph.getUniverse());
60  factor.addVariable(varIndx);
61  marginal(factor);
62  return factor;
63  }
65  inline drwnTableFactor operator[](const char* varName) const {
66  return (*this)[_graph.getUniverse()->findVariable(varName)];
67  }
68 };
69 
70 // drwnMessagePassingInference class ----------------------------------------
73 
75 {
// Members of drwnMessagePassingInference: generic message-passing algorithms
// on factor graphs. Derived classes supply the computation graph by
// implementing buildComputationGraph().
76  public:
//! maximum number of iterations (static, so one value is shared by every
//! instance)
77  static unsigned MAX_ITERATIONS;
78 
79  protected:
80  // forward and backward messages during each iteration
// NOTE(review): these vectors hold raw pointers and their ownership is not
// visible in this header; presumably they are freed by clear() or the
// destructor in drwnInference.cpp -- verify.
81  vector<drwnTableFactor *> _forwardMessages;
82  vector<drwnTableFactor *> _backwardMessages;
// presumably the previous iteration's messages (e.g., for a convergence
// test) -- TODO confirm against drwnInference.cpp
83  vector<drwnTableFactor *> _oldForwardMessages;
84  vector<drwnTableFactor *> _oldBackwardMessages;
85 
86  // computation tree: intermediate factors, and (atomic) factor operations
87  vector<drwnTableFactor *> _intermediateFactors;
88  vector<drwnFactorOperation *> _computations;
89 
90  // shared storage for intermediate factors
91  vector<drwnTableFactorStorage *> _sharedStorage;
92 
93  public:
// NOTE(review): with this copy constructor commented out, the compiler
// generates a member-wise copy that duplicates every raw pointer above. If
// the destructor deletes them (verify in drwnInference.cpp), copying an
// instance leads to a double delete; consider making copying inaccessible
// (a private declaration, or `= delete` if the project targets C++11).
95  //drwnMessagePassingInference(const drwnMessagePassingInference& inf);
96  virtual ~drwnMessagePassingInference();
97 
//! clear internally cached data (e.g., computation graph)
98  void clear();
//! run inference (or resume for iterative algorithms) and return true if
//! converged
99  bool inference();
100 
101  protected:
//! overridable hook for setting up the messages; its definition is outside
//! this header
102  virtual void initializeMessages();
//! pure virtual: each derived algorithm builds its own computation graph
//! (presumably populating _intermediateFactors and _computations)
103  virtual void buildComputationGraph() = 0;
104 };
105 
106 // drwnSumProductInference class --------------------------------------------
108 
// Members of drwnSumProductInference, which implements sum-product inference.
110  public:
113 
//! return the belief over the variables in the given factor, which must be
//! one of the cliques in the original factor graph
114  void marginal(drwnTableFactor& belief) const;
115 
116  protected:
//! builds the sum-product computation graph (presumably overriding the pure
//! virtual in drwnMessagePassingInference; the base-class list is not shown
//! in this listing)
117  void buildComputationGraph();
118 };
119 
120 // drwnAsyncSumProductInference class ---------------------------------------
122 
// Members of drwnAsyncSumProductInference, which implements asynchronous
// sum-product inference. No marginal() is declared here, so it must come
// from a base class (not shown in this listing; presumably
// drwnSumProductInference -- verify).
124  public:
127 
128  protected:
//! builds the computation graph for asynchronous sum-product inference
129  void buildComputationGraph();
130 };
const drwnVarUniversePtr & getUniverse() const
return the universe of variables for this factor graph
Definition: drwnFactorGraph.h:62
drwnInference(const drwnFactorGraph &graph)
reference to initial clique potentials
Definition: drwnInference.cpp:24
bool inference()
run inference (or resume for iterative algorithms) and return true if converged
Definition: drwnInference.cpp:99
drwnTableFactor operator[](int varIndx) const
return the marginal distribution over variable varIndx
Definition: drwnInference.h:58
virtual void marginal(drwnTableFactor &belief) const =0
return the belief over the variables in the given factor, which must be one of the cliques in the original factor graph
virtual bool inference()=0
run inference (or resume for iterative algorithms) and return true if converged
Implements asynchronous sum-product inference.
Definition: drwnInference.h:123
void marginal(drwnTableFactor &belief) const
return the belief over the variables in the given factor, which must be one of the cliques in the original factor graph
Definition: drwnInference.cpp:209
Container and utility functions for factor graphs.
Definition: drwnFactorGraph.h:40
drwnTableFactor operator[](const char *varName) const
return the marginal distribution over variable varName
Definition: drwnInference.h:65
Factor which stores the value of each assignment explicitly in table form.
Definition: drwnTableFactor.h:144
Implements generic message-passing algorithms on factor graphs. See derived classes for specific algorithms.
Definition: drwnInference.h:74
virtual drwnFactorGraph varMarginals() const
returns marginals for each variable in the factor graph's universe
Definition: drwnInference.cpp:40
void clear()
clear internally cached data (e.g., computation graph)
Definition: drwnInference.cpp:68
int findVariable(const char *name) const
returns the index of the variable with the given name
Definition: drwnVarUniverse.cpp:95
Data structures and utilities for encoding assignments to variables.
Interface for various (marginal) inference algorithms.
Definition: drwnInference.h:37
void addVariable(int var)
add variable by id
Definition: drwnTableFactor.cpp:236
virtual void clear()
clear internally cached data (e.g., computation graph)
Definition: drwnInference.h:47
Implements sum-product inference.
Definition: drwnInference.h:109
static unsigned MAX_ITERATIONS
maximum number of iterations
Definition: drwnInference.h:77