inference.hxx

Go to the documentation of this file.
00001 #pragma once
00002 #ifndef OPENGM_INFERENCE_HXX
00003 #define OPENGM_INFERENCE_HXX
00004 
00005 #include <vector>
00006 #include <string>
00007 #include <list>
00008 #include <limits>
00009 #include <exception>
00010 
00011 #include "opengm/opengm.hxx"
00012 
// OPENGM_GM_TYPE_TYPEDEFS
//   Re-exports the commonly used member types of `GraphicalModelType` into the
//   scope where the macro is expanded. That scope must already declare a type
//   named `GraphicalModelType`. Because every line uses `typename`, the macro
//   is meant to be expanded inside a template (pre-C++11 compilers reject
//   `typename` elsewhere).
//   The expansion intentionally has NO trailing semicolon: call sites write
//   `OPENGM_GM_TYPE_TYPEDEFS;`.
#define OPENGM_GM_TYPE_TYPEDEFS                                                       \
   typedef typename GraphicalModelType::IndexType             IndexType;              \
   typedef typename GraphicalModelType::LabelType             LabelType;              \
   typedef typename GraphicalModelType::ValueType             ValueType;              \
   typedef typename GraphicalModelType::FactorType            FactorType;             \
   typedef typename GraphicalModelType::IndependentFactorType IndependentFactorType;  \
   typedef typename GraphicalModelType::OperatorType          OperatorType;           \
   typedef typename GraphicalModelType::FunctionIdentifier    FunctionIdentifier

00022 namespace opengm {
00023 
/// Status reported by inference algorithms and by the helper members of
/// `Inference`. The numeric values are spelled out explicitly so they stay
/// stable.
enum InferenceTermination
{
   UNKNOWN         = 0, ///< returned by the default (non-overridden) implementations
   NORMAL          = 1, ///< success; the helpers in Inference treat anything else as failure
   TIMEOUT         = 2, ///< presumably: stopped because a time limit was hit (confirm per algorithm)
   CONVERGENCE     = 3, ///< presumably: stopped because a convergence criterion was met
   INFERENCE_ERROR = 4  ///< presumably: the algorithm failed
};
00031 
00033 template <class GM, class ACC>
00034 class Inference
00035 {
00036 public:
00037    typedef GM GraphicalModelType;
00038    typedef ACC AccumulationType;
00039    typedef typename GraphicalModelType::LabelType LabelType;
00040    typedef typename GraphicalModelType::IndexType IndexType;
00041    typedef typename GraphicalModelType::ValueType ValueType;
00042    typedef typename GraphicalModelType::OperatorType OperatorType;
00043    typedef typename GraphicalModelType::FactorType FactorType;
00044    typedef typename GraphicalModelType::IndependentFactorType IndependentFactorType;
00045    typedef typename GraphicalModelType::FunctionIdentifier FunctionIdentifier;
00046 
00047    virtual ~Inference() {}
00048 
00049    virtual std::string name() const = 0;
00050    virtual const GraphicalModelType& graphicalModel() const = 0;
00051    virtual InferenceTermination infer() = 0;
00055 
00056    // member functions with default definition
00057    virtual void setStartingPoint(typename std::vector<LabelType>::const_iterator);
00058    virtual InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
00059    virtual InferenceTermination args(std::vector<std::vector<LabelType> >&) const;
00060    virtual InferenceTermination marginal(const size_t, IndependentFactorType&) const;
00061    virtual InferenceTermination factorMarginal(const size_t, IndependentFactorType&) const;
00062    virtual ValueType bound() const;
00063    virtual ValueType value() const;
00064    InferenceTermination constrainedOptimum(std::vector<IndexType>&,std::vector<LabelType>&, std::vector<LabelType>&) const;
00065    InferenceTermination modeFromMarginal(std::vector<LabelType>&) const;
00066    InferenceTermination modeFromFactorMarginal(std::vector<LabelType>&) const;
00067 };
00068 
00072 template<class GM, class ACC>
00073 inline InferenceTermination
00074 Inference<GM, ACC>::arg(
00075    std::vector<LabelType>& arg,
00076    const size_t argIndex
00077 ) const
00078 {
00079    return UNKNOWN;
00080 }
00081    
00084 template<class GM, class ACC>   
00085 inline void 
00086 Inference<GM, ACC>::setStartingPoint(
00087    typename std::vector<LabelType>::const_iterator begin
00088 ) 
00089 {}
00090    
00091 template<class GM, class ACC>
00092 inline InferenceTermination
00093 Inference<GM, ACC>::args(
00094    std::vector<std::vector<LabelType> >& out
00095 ) const
00096 {
00097    return UNKNOWN;
00098 }
00099 
00103 template<class GM, class ACC>
00104 inline InferenceTermination
00105 Inference<GM, ACC>::marginal(
00106    const size_t variableIndex,
00107    IndependentFactorType& out
00108    ) const
00109 {
00110    return UNKNOWN;
00111 }
00112 
00116 template<class GM, class ACC>
00117 inline InferenceTermination
00118 Inference<GM, ACC>::factorMarginal(
00119    const size_t factorIndex,
00120    IndependentFactorType& out
00121 ) const
00122 {
00123    return UNKNOWN;
00124 }
00125 
00126 template<class GM, class ACC>
00127 InferenceTermination
00128 Inference<GM, ACC>::constrainedOptimum(
00129    std::vector<IndexType>& variableIndices,
00130    std::vector<LabelType>& givenLabels,
00131    std::vector<LabelType>& conf
00132 ) const
00133 {
00134    const GM& gm = graphicalModel();
00135    std::vector<IndexType> waitingVariables;
00136    size_t variableId = 0;
00137    size_t numberOfVariables = gm.numberOfVariables();
00138    size_t numberOfFixedVariables = 0;
00139    conf.assign(gm.numberOfVariables(),std::numeric_limits<LabelType>::max());
00140    OPENGM_ASSERT(variableIndices.size()>=givenLabels.size());
00141    for(size_t i=0; i<givenLabels.size() ;++i) {
00142       OPENGM_ASSERT( variableIndices[i]<gm.numberOfVariables());
00143       OPENGM_ASSERT( givenLabels[i]<gm.numberOfLabels(variableIndices[i]));
00144       conf[variableIndices[i]] = givenLabels[i];
00145       waitingVariables.push_back(variableIndices[i]);
00146       ++numberOfFixedVariables;
00147    }
00148    while(variableId<gm.numberOfVariables() && numberOfFixedVariables<numberOfVariables) {
00149       while(waitingVariables.size()>0 && numberOfFixedVariables<numberOfVariables) {
00150          size_t var = waitingVariables.back();
00151          waitingVariables.pop_back();
00152 
00153          //Search unset neighbourd variable
00154          for(size_t i=0; i<gm.numberOfFactors(var); ++i) { 
00155             size_t var2=var;
00156             size_t factorId = gm.factorOfVariable(var,i);
00157             for(size_t n=0; n<gm[factorId].numberOfVariables();++n) {
00158                if(conf[gm[factorId].variableIndex(n)] == std::numeric_limits<LabelType>::max()) {
00159                   var2=gm[factorId].variableIndex(n);
00160                   break;
00161                }
00162             }
00163             if(var2 != var) { 
00164                //Set this variable
00165                IndependentFactorType t;
00166                //marginal(var2, t);
00167                for(size_t i=0; i<gm.numberOfFactors(var2); ++i) {
00168                   size_t factorId = gm.factorOfVariable(var2,i);
00169                   std::vector<IndexType> knownVariables;
00170                   std::vector<LabelType> knownStates;
00171                   std::vector<IndexType> unknownVariables; 
00172                   IndependentFactorType out;
00173                   InferenceTermination term = factorMarginal(factorId, out);
00174                   if(NORMAL != term) {
00175                      return term;
00176                   }
00177                   for(size_t n=0; n<gm[factorId].numberOfVariables();++n) {
00178                      if(gm[factorId].variableIndex(n)!=var2) {
00179                         if(conf[gm[factorId].variableIndex(n)] < std::numeric_limits<LabelType>::max()) {
00180                            knownVariables.push_back(gm[factorId].variableIndex(n));
00181                            knownStates.push_back(conf[gm[factorId].variableIndex(n)]);
00182                         }else{
00183                            unknownVariables.push_back(gm[factorId].variableIndex(n));
00184                         }
00185                      }
00186                   } 
00187                      
00188                   out.fixVariables(knownVariables.begin(), knownVariables.end(), knownStates.begin()); 
00189                   if(unknownVariables.size()>0)
00190                      out.template accumulate<AccumulationType>(unknownVariables.begin(),unknownVariables.end());
00191                   OperatorType::op(out,t); 
00192                } 
00193                ValueType value;
00194                std::vector<LabelType> state(t.numberOfVariables());
00195                t.template accumulate<AccumulationType>(value,state);
00196                conf[var2] = state[0];
00197                ++numberOfFixedVariables;
00198                waitingVariables.push_back(var2);
00199             }
00200          }
00201       }
00202       if(conf[variableId]==std::numeric_limits<LabelType>::max()) {
00203          //Set variable state
00204          IndependentFactorType out;
00205          InferenceTermination term = marginal(variableId, out);
00206          if(NORMAL != term) {
00207             return term;
00208          } 
00209          ValueType value;
00210          std::vector<LabelType> state(out.numberOfVariables());
00211          out.template accumulate<AccumulationType>(value,state);
00212          conf[variableId] = state[0];
00213          waitingVariables.push_back(variableId);
00214       }
00215       ++variableId;
00216    }
00217    return NORMAL;
00218 }
00219 
00220 template<class GM, class ACC>
00221 InferenceTermination
00222 Inference<GM, ACC>::modeFromMarginal(
00223    std::vector<LabelType>& conf
00224    ) const
00225 {
00226    const GM&         gm = graphicalModel();
00227    //const space_type& space = gm.space();
00228    size_t            numberOfNodes = gm.numberOfVariables();
00229    conf.resize(gm.numberOfVariables());
00230    IndependentFactorType out;
00231    for(size_t node=0; node<numberOfNodes; ++node) {
00232       InferenceTermination term = marginal(node, out);
00233       if(NORMAL != term) {
00234          return term;
00235       }
00236       ValueType value = out(0);
00237       size_t state = 0;
00238       for(size_t i=1; i<gm.numberOfLabels(node); ++i) {
00239          if(ACC::bop(out(i), value)) {
00240             value = out(i);
00241             state = i;
00242          }
00243       }
00244       conf[node] = state;
00245    }
00246    return NORMAL;
00247 }
00248 
00249 template<class GM, class ACC>
00250 InferenceTermination
00251 Inference<GM, ACC>::modeFromFactorMarginal(
00252    std::vector<LabelType>& conf
00253 ) const
00254 {
00255    const GM& gm = graphicalModel();
00256    std::vector<IndexType> knownVariables;
00257    std::vector<LabelType> knownStates;
00258    IndependentFactorType out;
00259    for(size_t node=0; node<gm.numberOfVariables(); ++node) {
00260       InferenceTermination term = marginal(node, out);
00261       if(NORMAL != term) {
00262          return term;
00263       }
00264       ValueType value = out(0);
00265       size_t state = 0;
00266       bool unique = true;
00267       for(size_t i=1; i<gm.numberOfLabels(node); ++i) {
00268 
00269          //ValueType q = out(i)/value;
00270          //if(q<1.001 && q>0.999) {
00271          //   unique=false;
00272          //}
00273          if(fabs(out(i) - value)<0.00001) {
00274             unique=false;
00275          }
00276          else if(ACC::bop(out(i), value)) {
00277             value = out(i);
00278             state = i;
00279             unique=true;
00280          }
00281       }
00282       if(unique) {
00283          knownVariables.push_back(node);
00284          knownStates.push_back(state);
00285       }
00286    }
00287    return constrainedOptimum( knownVariables, knownStates, conf);
00288 }
00289 
00291 template<class GM, class ACC>
00292 typename GM::ValueType
00293 Inference<GM, ACC>::value() const 
00294 {
00295    if(ACC::hasbop()){ 
00296       // Default implementation if ACC defines an ordering  
00297       std::vector<LabelType> s;
00298       const GM& gm = graphicalModel();
00299       if(NORMAL == arg(s)) {
00300          return gm.evaluate(s);
00301       }
00302       else {
00303          return ACC::template neutral<ValueType>();
00304       }
00305    }else{
00306       //TODO: Maybe throw an exception here 
00307       //throw std::runtime_error("There is no default implementation for this type of semi-ring");
00308       return std::numeric_limits<ValueType>::quiet_NaN();
00309    }
00310 }
00311 
00313 template<class GM, class ACC>
00314 typename GM::ValueType
00315 Inference<GM, ACC>::bound() const { 
00316    if(ACC::hasbop()){
00317       // Default implementation if ACC defines an ordering
00318       return ACC::template ineutral<ValueType>();
00319    }else{
00320       //TODO: Maybe throw an exception here 
00321       //throw std::runtime_error("There is no default implementation for this type of semi-ring");
00322       return std::numeric_limits<ValueType>::quiet_NaN();
00323    }
00324 }
00325 
00326 } // namespace opengm
00327 
00328 #endif // #ifndef OPENGM_INFERENCE_HXX
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
Generated on Mon Jun 17 16:31:03 2013 for OpenGM by  doxygen 1.6.3