dynamicprogramming.hxx

Go to the documentation of this file.
00001 #pragma once
00002 #ifndef OPENGM_DYNAMICPROGRAMMING_HXX
00003 #define OPENGM_DYNAMICPROGRAMMING_HXX
00004 
#include <cstdlib>
#include <limits>
#include <new>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>

#include "opengm/inference/inference.hxx"
#include "opengm/inference/visitors/visitor.hxx"
00009 
00010 namespace opengm {
00011 
00015   template<class GM, class ACC>
00016   class DynamicProgramming : public Inference<GM, ACC> {
00017   public:
00018     typedef ACC AccumulationType;
00019     typedef ACC AccumulatorType;
00020     typedef GM GraphicalModelType;
00021     OPENGM_GM_TYPE_TYPEDEFS;
00022     typedef unsigned char MyStateType;
00023     typedef double        MyValueType;
00024     typedef VerboseVisitor<DynamicProgramming<GM,ACC> >        VerboseVisitorType;
00025     typedef TimingVisitor<DynamicProgramming<GM,ACC> >         TimingVisitorType;
00026     typedef EmptyVisitor<DynamicProgramming<GM,ACC> >          EmptyVisitorType;
00027     struct Parameter {
00028       std::vector<IndexType> roots_;
00029     };
00030 
00031     DynamicProgramming(const GraphicalModelType&, const Parameter& = Parameter());
00032     ~DynamicProgramming();
00033 
00034     std::string name() const;
00035     const GraphicalModelType& graphicalModel() const;
00036     InferenceTermination infer();
00037     template< class VISITOR>
00038     InferenceTermination infer(VISITOR &);
00039     InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
00040     
00041     
00042     void getNodeInfo(const IndexType Inode, std::vector<ValueType>& values, std::vector<std::vector<LabelType> >& substates, std::vector<IndexType>& nodes) const;
00043     
00044 
00045   private:
00046     const GraphicalModelType& gm_;
00047     Parameter para_;
00048     MyValueType* valueBuffer_;
00049     MyStateType* stateBuffer_;
00050     std::vector<MyValueType*> valueBuffers_;
00051     std::vector<MyStateType*> stateBuffers_;
00052     std::vector<size_t> nodeOrder_; 
00053     std::vector<size_t> orderedNodes_;
00054   };
00055 
00056   template<class GM, class ACC>
00057   inline std::string
00058   DynamicProgramming<GM, ACC>::name() const {
00059     return "DynamicProgramming";
00060   }
00061 
00062   template<class GM, class ACC>
00063   inline const typename DynamicProgramming<GM, ACC>::GraphicalModelType&
00064   DynamicProgramming<GM, ACC>::graphicalModel() const {
00065     return gm_;
00066   }
00067 
00068   template<class GM, class ACC>
00069   DynamicProgramming<GM, ACC>::~DynamicProgramming()
00070   {
00071     free(valueBuffer_);
00072     free(stateBuffer_);
00073   }
00074   
00075   template<class GM, class ACC>
00076   inline DynamicProgramming<GM, ACC>::DynamicProgramming
00077   (
00078   const GraphicalModelType& gm, 
00079   const Parameter& para
00080   ) 
00081   :  gm_(gm)
00082   {
00083     para_ = para;
00084     
00085     // Set nodeOrder 
00086     std::vector<size_t> numChildren(gm_.numberOfVariables(),0);
00087     std::vector<size_t> nodeList;
00088     size_t orderCount = 0;
00089     size_t varCount   = 0;
00090     nodeOrder_.resize(gm_.numberOfVariables(),std::numeric_limits<std::size_t>::max());
00091     size_t rootCounter=0;
00092     while(varCount < gm_.numberOfVariables() && orderCount < gm_.numberOfVariables()){
00093       if(rootCounter<para_.roots_.size()){
00094         nodeOrder_[para_.roots_[rootCounter]] = orderCount++;
00095         nodeList.push_back(para_.roots_[rootCounter]);
00096         ++rootCounter;
00097       }
00098       else if(nodeOrder_[varCount]==std::numeric_limits<std::size_t>::max()){
00099         nodeOrder_[varCount] = orderCount++;
00100         nodeList.push_back(varCount);
00101       }
00102       ++varCount;
00103       while(nodeList.size()>0){
00104         size_t node = nodeList.back();
00105         nodeList.pop_back();
00106         for(typename GM::ConstFactorIterator it=gm_.factorsOfVariableBegin(node); it !=gm_.factorsOfVariableEnd(node); ++it){
00107           const typename GM::FactorType& factor = gm_[(*it)];
00108           if( factor.numberOfVariables() == 2 ){
00109             if( factor.variableIndex(1) == node && nodeOrder_[factor.variableIndex(0)]==std::numeric_limits<std::size_t>::max() ){
00110               nodeOrder_[factor.variableIndex(0)] = orderCount++;
00111               nodeList.push_back(factor.variableIndex(0));
00112               ++numChildren[node];
00113             }
00114             if( factor.variableIndex(0) == node && nodeOrder_[factor.variableIndex(1)]==std::numeric_limits<std::size_t>::max() ){
00115               nodeOrder_[factor.variableIndex(1)] = orderCount++;
00116               nodeList.push_back(factor.variableIndex(1));
00117               ++numChildren[node];                       
00118             }
00119           }
00120         }
00121       }
00122     }
00123 
00124     // Allocate memmory
00125     size_t memSizeValue = 0;
00126     size_t memSizeState = 0;
00127     for(size_t i=0; i<gm_.numberOfVariables();++i){
00128       memSizeValue += gm_.numberOfLabels(i);
00129       memSizeState += gm.numberOfLabels(i) * numChildren[i];
00130     }
00131     valueBuffer_ = (MyValueType*) malloc(memSizeValue*sizeof(MyValueType));
00132     stateBuffer_ = (MyStateType*) malloc(memSizeState*sizeof(MyStateType));
00133     valueBuffers_.resize(gm_.numberOfVariables());
00134     stateBuffers_.resize(gm_.numberOfVariables()); 
00135     
00136     MyValueType* valuePointer =  valueBuffer_;
00137     MyStateType* statePointer =  stateBuffer_;
00138     for(size_t i=0; i<gm_.numberOfVariables();++i){
00139       valueBuffers_[i] = valuePointer;
00140       valuePointer += gm.numberOfLabels(i);
00141       stateBuffers_[i] = statePointer;
00142       statePointer +=  gm.numberOfLabels(i) * numChildren[i];
00143     }
00144     
00145     orderedNodes_.resize(gm_.numberOfVariables(),std::numeric_limits<std::size_t>::max());
00146     for(size_t i=0; i<gm_.numberOfVariables(); ++i)
00147       orderedNodes_[nodeOrder_[i]] = i;
00148     
00149   }
00150   
00151   template<class GM, class ACC>
00152   inline InferenceTermination 
00153   DynamicProgramming<GM, ACC>::infer(){
00154     EmptyVisitorType v;
00155     return infer(v);
00156   }
00157   
00158   template<class GM, class ACC>
00159   template<class VISITOR>
00160   inline InferenceTermination 
00161   DynamicProgramming<GM, ACC>::infer
00162   (
00163   VISITOR & visitor
00164   ){
00165     for(size_t i=1; i<=gm_.numberOfVariables();++i){
00166       const size_t node = orderedNodes_[gm_.numberOfVariables()-i];
00167       // set buffer neutral
00168       for(size_t n=0; n<gm_.numberOfLabels(node); ++n){
00169         OperatorType::neutral(valueBuffers_[node][n]);
00170       }
00171       // accumulate messages
00172       size_t childrenCounter = 0;
00173       for(typename GM::ConstFactorIterator it=gm_.factorsOfVariableBegin(node); it !=gm_.factorsOfVariableEnd(node); ++it){
00174         const typename GM::FactorType& factor = gm_[(*it)];
00175 
00176         // unary
00177         if(factor.numberOfVariables()==1 ){
00178           for(size_t n=0; n<gm_.numberOfLabels(node); ++n){
00179             const ValueType fac = factor(&n);
00180             OperatorType::op(fac, valueBuffers_[node][n]); 
00181           } 
00182         }
00183         
00184         //pairwise
00185         if( factor.numberOfVariables()==2 ){
00186           size_t vec[] = {0,0};
00187           if(factor.variableIndex(0) == node && nodeOrder_[factor.variableIndex(1)]>nodeOrder_[node] ){
00188             const size_t node2 = factor.variableIndex(1);
00189             MyStateType s;
00190             MyValueType v,v2;
00191             for(vec[0]=0; vec[0]<gm_.numberOfLabels(node); ++vec[0]){
00192               ACC::neutral(v);
00193               for(vec[1]=0; vec[1]<gm_.numberOfLabels(node2); ++vec[1]){ 
00194                 const ValueType fac = factor(vec);
00195                 OperatorType::op(fac,valueBuffers_[node2][vec[1]],v2) ;
00196                 if(ACC::bop(v2,v)){
00197                   v=v2;
00198                   s=vec[1];
00199                 }
00200               }
00201               stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+vec[0]] = s;
00202               OperatorType::op(v,valueBuffers_[node][vec[0]]);
00203             }  
00204             ++childrenCounter;
00205             
00206           }
00207           if(factor.variableIndex(1) == node && nodeOrder_[factor.variableIndex(0)]>nodeOrder_[node]){ 
00208             const size_t node2 = factor.variableIndex(0);
00209             MyStateType s;
00210             MyValueType v,v2;
00211             for(vec[1]=0; vec[1]<gm_.numberOfLabels(node); ++vec[1]){
00212               ACC::neutral(v);
00213               for(vec[0]=0; vec[0]<gm_.numberOfLabels(node2); ++vec[0]){
00214                 const ValueType fac = factor(vec);
00215                 OperatorType::op(fac,valueBuffers_[node2][vec[0]],v2); 
00216                 if(ACC::bop(v2,v)){
00217                   v=v2;
00218                   s=vec[0];
00219                 }
00220               }  
00221               stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+vec[1]] = s;
00222               OperatorType::op(v,valueBuffers_[node][vec[1]]); 
00223             }
00224             ++childrenCounter;                      
00225           }
00226         }
00227       } 
00228     }
00229     return NORMAL;
00230   }
00231 
00232   template<class GM, class ACC>
00233   inline InferenceTermination DynamicProgramming<GM, ACC>::arg
00234   (
00235   std::vector<LabelType>& arg, 
00236   const size_t n
00237   ) const {
00238     if(n > 1) {
00239       return UNKNOWN;
00240     } 
00241     else {
00242       std::vector<size_t> nodeList;
00243       arg.assign(gm_.numberOfVariables(), std::numeric_limits<LabelType>::max() );
00244       size_t var = 0;
00245       while(var < gm_.numberOfVariables()){
00246         if(arg[var]==std::numeric_limits<LabelType>::max()){
00247           MyValueType v; ACC::neutral(v);             
00248           for(size_t i=0; i<gm_.numberOfLabels(var); ++i){
00249             if(ACC::bop(valueBuffers_[var][i], v)){
00250               v = valueBuffers_[var][i];
00251               arg[var]=i;      
00252             }
00253           }
00254           nodeList.push_back(var);
00255         }
00256         ++var;
00257         while(nodeList.size()>0){
00258           size_t node = nodeList.back();
00259           size_t childrenCounter = 0;
00260           nodeList.pop_back();
00261           for(typename GM::ConstFactorIterator it=gm_.factorsOfVariableBegin(node); it !=gm_.factorsOfVariableEnd(node); ++it){
00262             const typename GM::FactorType& factor = gm_[(*it)];
00263             if(factor.numberOfVariables()==2 ){
00264               if(factor.variableIndex(1)==node && nodeOrder_[factor.variableIndex(0)] > nodeOrder_[node] ){
00265                 arg[factor.variableIndex(0)] = stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]];
00266                 nodeList.push_back(factor.variableIndex(0));
00267                 ++childrenCounter;             
00268               }
00269               if(factor.variableIndex(0)==node && nodeOrder_[factor.variableIndex(1)] > nodeOrder_[node] ){
00270                 arg[factor.variableIndex(1)] = stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]];
00271                 nodeList.push_back(factor.variableIndex(1)); 
00272                 ++childrenCounter;                                 
00273               }
00274             }
00275           }
00276         }
00277       }
00278       return NORMAL;
00279     }
00280   }
00281 
00282   template<class GM, class ACC>
00283   inline void DynamicProgramming<GM, ACC>::getNodeInfo(const IndexType Inode, std::vector<ValueType>& values, std::vector<std::vector<LabelType> >& substates, std::vector<IndexType>& nodes) const{
00284     values.clear();
00285     substates.clear();
00286     nodes.clear();
00287     values.resize(gm_.numberOfLabels(Inode));
00288     substates.resize(gm_.numberOfLabels(Inode));
00289     std::vector<LabelType> arg;
00290     bool firstround = true;
00291     std::vector<size_t> nodeList;
00292     for(IndexType i=0;i<gm_.numberOfLabels(Inode); ++i){
00293       arg.assign(gm_.numberOfVariables(), std::numeric_limits<LabelType>::max() );
00294       arg[Inode]=i;
00295       values[i]=valueBuffers_[Inode][i];
00296       nodeList.push_back(Inode);
00297       if(i!=0){
00298         firstround=false;
00299       }
00300       
00301       while(nodeList.size()>0){
00302         size_t node = nodeList.back();
00303         size_t childrenCounter = 0;
00304         nodeList.pop_back();
00305         for(typename GM::ConstFactorIterator it=gm_.factorsOfVariableBegin(node); it !=gm_.factorsOfVariableEnd(node); ++it){
00306           const typename GM::FactorType& factor = gm_[(*it)];
00307           if(factor.numberOfVariables()==2 ){
00308             if(factor.variableIndex(1)==node && nodeOrder_[factor.variableIndex(0)] > nodeOrder_[node] ){
00309               arg[factor.variableIndex(0)] = stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]];
00310               substates[i].push_back(stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]]);
00311               if(firstround==true){              
00312                 nodes.push_back(factor.variableIndex(0));
00313               }
00314               nodeList.push_back(factor.variableIndex(0));
00315               ++childrenCounter;             
00316             }
00317             if(factor.variableIndex(0)==node && nodeOrder_[factor.variableIndex(1)] > nodeOrder_[node] ){
00318               arg[factor.variableIndex(1)] = stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]];
00319               substates[i].push_back(stateBuffers_[node][childrenCounter*gm_.numberOfLabels(node)+arg[node]]);
00320               if(firstround==true){
00321                 nodes.push_back(factor.variableIndex(1));
00322               }
00323               nodeList.push_back(factor.variableIndex(1)); 
00324               ++childrenCounter;                                 
00325             }
00326           }
00327         }
00328       }
00329     }
00330   }
00331   
00332   
00333 } // namespace opengm
00334 
00335 #endif // #ifndef OPENGM_DYNAMICPROGRAMMING_HXX
00336 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
Generated on Mon Jun 17 16:31:02 2013 for OpenGM by  doxygen 1.6.3