alphabetaswap.hxx

Go to the documentation of this file.
00001 #pragma once
00002 #ifndef OPENGM_ALPHABEATSWAP_HXX
00003 #define OPENGM_ALPHABETASWAP_HXX
00004 
00005 #include <vector>
00006 
00007 #include "opengm/inference/inference.hxx"
00008 #include "opengm/inference/visitors/visitor.hxx"
00009 
00010 namespace opengm {
00011 
00014 template<class GM, class INF>
00015 class AlphaBetaSwap : public Inference<GM, typename INF::AccumulationType> {
00016 public:
00017    typedef GM GraphicalModelType;
00018    typedef INF InferenceType;
00019    typedef typename INF::AccumulationType AccumulationType;
00020    OPENGM_GM_TYPE_TYPEDEFS;
00021    typedef VerboseVisitor<AlphaBetaSwap<GM,INF> >        VerboseVisitorType;
00022    typedef TimingVisitor<AlphaBetaSwap<GM,INF> >         TimingVisitorType;
00023    typedef EmptyVisitor<AlphaBetaSwap<GM,INF> >          EmptyVisitorType;
00024 
00025    struct Parameter {
00026       Parameter() {
00027          maxNumberOfIterations_ = 1000;
00028       }
00029 
00030       typename InferenceType::Parameter parameter_; 
00031       size_t maxNumberOfIterations_; 
00032    };
00033 
00034    AlphaBetaSwap(const GraphicalModelType&, Parameter = Parameter());
00035    std::string name() const;
00036    const GraphicalModelType& graphicalModel() const;
00037    InferenceTermination infer();
00038    template<class VISITOR>
00039    InferenceTermination infer(VISITOR & );
00040    void reset();
00041    void setStartingPoint(typename std::vector<LabelType>::const_iterator);
00042    InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
00043 
00044 private:
00045    const GraphicalModelType& gm_;
00046    Parameter parameter_;
00047    std::vector<LabelType> label_;
00048    size_t alpha_;
00049    size_t beta_;
00050    size_t maxState_;
00051    void increment();
00052    void addUnary(INF&, const size_t var, const ValueType v0, const ValueType v1);
00053    void addPairwise(INF&, const size_t var1, const size_t var2, const ValueType v0, const ValueType v1, const ValueType v2, const ValueType v3);
00054 };
00055 
00056 // reset assumes that the structure of the graphical model has not changed
00057 template<class GM, class INF>
00058 inline void
00059 AlphaBetaSwap<GM, INF>::reset() {
00060    alpha_ = 0;
00061    beta_ = 0;
00062    std::fill(label_.begin(),label_.end(),0);
00063 }
00064 
00065 template<class GM, class INF>
00066 inline void
00067 AlphaBetaSwap<GM, INF>::increment() {
00068    if (++beta_ >= maxState_) {
00069       if (++alpha_ >= maxState_ - 1) {
00070          alpha_ = 0;
00071       }
00072       beta_ = alpha_ + 1;
00073    }
00074    OPENGM_ASSERT(alpha_ < maxState_);
00075    OPENGM_ASSERT(beta_ < maxState_);
00076    OPENGM_ASSERT(alpha_ < beta_);
00077 }
00078 
00079 template<class GM, class INF>
00080 inline std::string
00081 AlphaBetaSwap<GM, INF>::name() const {
00082    return "Alpha-Beta-Swap";
00083 }
00084 
00085 template<class GM, class INF>
00086 inline const typename AlphaBetaSwap<GM, INF>::GraphicalModelType&
00087 AlphaBetaSwap<GM, INF>::graphicalModel() const {
00088    return gm_;
00089 }
00090 
00091 template<class GM, class INF>
00092 inline AlphaBetaSwap<GM, INF>::AlphaBetaSwap
00093 (
00094    const GraphicalModelType& gm,
00095    Parameter para
00096 )
00097 :  gm_(gm)
00098 {
00099    parameter_ = para;
00100    label_.resize(gm_.numberOfVariables(), 0);
00101    alpha_ = 0;
00102    beta_ = 0;
00103    for (size_t j = 0; j < gm_.numberOfFactors(); ++j) {
00104       if (gm_[j].numberOfVariables() > 2) {
00105          throw RuntimeError("This implementation of Alpha-Beta-Swap supports only factors of order <= 2.");
00106       }
00107    }
00108    maxState_ = 0;
00109    for (size_t i = 0; i < gm_.numberOfVariables(); ++i) {
00110       size_t numSt = gm_.numberOfLabels(i);
00111       if (numSt > maxState_)
00112          maxState_ = numSt;
00113    }
00114 }
00115 
00116 template<class GM, class INF>
00117 inline void
00118 AlphaBetaSwap<GM,INF>::setStartingPoint
00119 (
00120    typename std::vector<typename AlphaBetaSwap<GM,INF>::LabelType>::const_iterator begin
00121 ) {
00122    try{
00123       label_.assign(begin, begin+gm_.numberOfVariables());
00124    }
00125    catch(...) {
00126       throw RuntimeError("unsuitable starting point");
00127    }
00128 }
00129 
00130 template<class GM, class INF>
00131 inline void
00132 AlphaBetaSwap<GM, INF>::addUnary
00133 (
00134    INF& inf,
00135    const size_t var1,
00136    const ValueType v0,
00137    const ValueType v1
00138 ) {
00139    const size_t shape[] = {2};
00140    const size_t vars[] = {var1};
00141    opengm::IndependentFactor<ValueType,IndexType,LabelType> fac(vars, vars + 1, shape, shape + 1);
00142    fac(0) = v0;
00143    fac(1) = v1;
00144    inf.addFactor(fac);
00145 }
00146 
00147 template<class GM, class INF>
00148 inline void
00149 AlphaBetaSwap<GM, INF>::addPairwise
00150 (
00151    INF& inf,
00152    const size_t var1,
00153    const size_t var2,
00154    const ValueType v0,
00155    const ValueType v1,
00156    const ValueType v2,
00157    const ValueType v3
00158 ) {
00159    const size_t shape[] = {2, 2};
00160    const size_t vars[] = {var1, var2};
00161    opengm::IndependentFactor<ValueType,IndexType,LabelType> fac(vars, vars + 2, shape, shape + 2);
00162    fac(0, 0) = v0;
00163    fac(0, 1) = v1;
00164    fac(1, 0) = v2;
00165    fac(1, 1) = v3;
00166    OPENGM_ASSERT(v1 + v2 - v0 - v3 >= 0);
00167    inf.addFactor(fac);
00168 }
00169 template<class GM, class INF>
00170 InferenceTermination
00171 AlphaBetaSwap<GM, INF>::infer() {
00172    EmptyVisitorType v;
00173    return infer(v);
00174 }
00175 
00176 template<class GM, class INF>
00177 template<class VISITOR>
00178 InferenceTermination
00179 AlphaBetaSwap<GM, INF>::infer
00180 (
00181    VISITOR & visitor
00182 ) {
00183    visitor.begin(*this,0,0);
00184    //visitor(*this,0,0);
00185    size_t it = 0;
00186    size_t countUnchanged = 0;
00187    size_t numberOfVariables = gm_.numberOfVariables();
00188    std::vector<size_t> variable2Node(numberOfVariables, 0);
00189    ValueType energy = gm_.evaluate(label_);
00190    size_t vecA[1];
00191    size_t vecB[1];
00192    size_t vecAA[2];
00193    size_t vecAB[2];
00194    size_t vecBA[2];
00195    size_t vecBB[2];
00196    size_t vecAX[2];
00197    size_t vecBX[2];
00198    size_t vecXA[2];
00199    size_t vecXB[2];
00200    size_t numberOfLabelPairs = maxState_*(maxState_ - 1)/2;
00201    while (it++ < parameter_.maxNumberOfIterations_ && countUnchanged < numberOfLabelPairs) {
00202       increment();
00203       size_t counter = 0;
00204       std::vector<size_t> numFacDim(4, 0);
00205       for (size_t i = 0; i < numberOfVariables; ++i) {
00206          if (label_[i] == alpha_ || label_[i] == beta_) {
00207             variable2Node[i] = counter++;
00208          }
00209       }
00210       if (counter == 0) {
00211          continue;
00212       }
00213       INF inf(counter, numFacDim);
00214       vecA[0] = alpha_;
00215       vecB[0] = beta_;
00216       vecAA[0] = alpha_;
00217       vecAA[1] = alpha_;
00218       vecBB[0] = beta_;
00219       vecBB[1] = beta_;
00220       vecBA[0] = beta_;
00221       vecBA[1] = alpha_;
00222       vecAB[0] = alpha_;
00223       vecAB[1] = beta_;
00224       vecAX[0] = alpha_;
00225       vecBX[0] = beta_;
00226       vecXA[1] = alpha_;
00227       vecXB[1] = beta_;
00228       for (size_t k = 0; k < gm_.numberOfFactors(); ++k) {
00229          const FactorType& factor = gm_[k];
00230          if (factor.numberOfVariables() == 1) {
00231             size_t var = factor.variableIndex(0);
00232             size_t node = variable2Node[var];
00233             if (label_[var] == alpha_ || label_[var] == beta_) {
00234                OPENGM_ASSERT(alpha_ < gm_.numberOfLabels(var));
00235                OPENGM_ASSERT(beta_ < gm_.numberOfLabels(var));
00236                addUnary(inf, node, factor(vecA), factor(vecB));
00237                //inf.addUnary(node, factor(vecA), factor(vecB));
00238             }
00239          } else if (factor.numberOfVariables() == 2) {
00240             size_t var1 = factor.variableIndex(0);
00241             size_t var2 = factor.variableIndex(1);
00242             size_t node1 = variable2Node[var1];
00243             size_t node2 = variable2Node[var2];
00244 
00245             if ((label_[var1] == alpha_ || label_[var1] == beta_) && (label_[var2] == alpha_ || label_[var2] == beta_)) {
00246                addPairwise(inf, node1, node2, factor(vecAA), factor(vecAB), factor(vecBA), factor(vecBB));
00247                //inf.addPairwise(node1, node2, factor(vecAA), factor(vecAB), factor(vecBA), factor(vecBB));
00248             } else if ((label_[var1] == alpha_ || label_[var1] == beta_) && (label_[var2] != alpha_ && label_[var2] != beta_)) {
00249                vecAX[1] = vecBX[1] = label_[var2];
00250                addUnary(inf, node1, factor(vecAX), factor(vecBX));
00251                //inf.addUnary(node1, factor(vecAX), factor(vecBX));
00252             } else if ((label_[var2] == alpha_ || label_[var2] == beta_) && (label_[var1] != alpha_ && label_[var1] != beta_)) {
00253                vecXA[0] = vecXB[0] = label_[var1];
00254                addUnary(inf, node2, factor(vecXA), factor(vecXB));
00255                //inf.addUnary(node2, factor(vecXA), factor(vecXB));
00256             }
00257          }
00258       }
00259       std::vector<LabelType> state; //(counter);
00260       inf.infer();
00261       inf.arg(state);
00262       OPENGM_ASSERT(state.size() == counter);
00263       for (size_t var = 0; var < numberOfVariables; ++var) {
00264          if (label_[var] == alpha_ || label_[var] == beta_) {
00265             if (state[variable2Node[var]] == 0)
00266                label_[var] = alpha_;
00267             else
00268                label_[var] = beta_;
00269          } else {
00270             //do nothing
00271          }
00272       }
00273       ValueType energy2 = gm_.evaluate(label_);
00274       visitor(*this,energy2,energy);
00275       OPENGM_ASSERT(!AccumulationType::ibop(energy2, energy));
00276       if (AccumulationType::bop(energy2, energy)) {
00277          energy = energy2;
00278       } else {
00279          ++countUnchanged;
00280       }
00281    }
00282    visitor.end(*this,energy,energy);
00283    return NORMAL;
00284 }
00285 
00286 template<class GM, class INF>
00287 inline InferenceTermination
00288 AlphaBetaSwap<GM, INF>::arg(std::vector<LabelType>& arg, const size_t n) const {
00289    if (n > 1) {
00290       return UNKNOWN;
00291    } else {
00292       OPENGM_ASSERT(label_.size() == gm_.numberOfVariables());
00293       arg.resize(label_.size());
00294       for (size_t i = 0; i < label_.size(); ++i)
00295          arg[i] = label_[i];
00296       return NORMAL;
00297    }
00298 }
00299 
00300 } // namespace opengm
00301 
00302 #endif // #ifndef OPENGM_ALPHABEATSWAP_HXX
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
Generated on Mon Jun 17 16:31:01 2013 for OpenGM by  doxygen 1.6.3