trws_base.hxx

Go to the documentation of this file.
00001 #ifndef TRWS_BASE_HXX_
00002 #define TRWS_BASE_HXX_
00003 #include <iostream>
00004 #include <time.h>
00005 #include <opengm/inference/trws/trws_decomposition.hxx>
00006 #include <opengm/inference/trws/trws_subproblemsolver.hxx>
00007 #include <functional>
00008 #include <opengm/functions/view_fix_variables_function.hxx>
00009 #include <opengm/inference/inference.hxx>
00010 #include <opengm/inference/visitors/visitor.hxx>
00011 
00012 namespace trws_base{
00013 
00014 template<class GM>
00015 class DecompositionStorage
00016 {
00017 public:
00018    typedef GM GraphicalModelType;
00019    typedef SequenceStorage<GM> SubModel;
00020    typedef typename GM::ValueType ValueType;
00021    typedef typename GM::IndexType IndexType;
00022    typedef typename GM::LabelType LabelType;
00023    typedef typename MonotoneChainsDecomposition<GM>::SubVariable SubVariable;
00024    typedef typename MonotoneChainsDecomposition<GM>::SubVariableListType SubVariableListType;
00025    typedef typename SubModel::UnaryFactor UnaryFactor;
00026    typedef enum {GRIDSTRUCTURE, GENERALSTRUCTURE} StructureType;
00027    typedef VariableToFactorMapping<GM> VariableToFactorMap;
00028 
00029    DecompositionStorage(const GM& gm,StructureType structureType=GENERALSTRUCTURE);
00030    ~DecompositionStorage();
00031 
00032    const GM& masterModel()const{return _gm;}
00033    LabelType numberOfLabels(IndexType varId)const{return _gm.numberOfLabels(varId);}
00034    IndexType numberOfModels()const{return (IndexType)_subModels.size();}
00035    IndexType numberOfSharedVariables()const{return (IndexType)_variableDecomposition.size();}
00036    SubModel& subModel(IndexType modelId){return *_subModels[modelId];}
00037    const SubModel& subModel(IndexType modelId)const{return *_subModels[modelId];}
00038    IndexType size(IndexType subModelId)const{return (IndexType)_subModels[subModelId]->size();}
00039 
00040    const SubVariableListType& getSubVariableList(IndexType varId)const{return _variableDecomposition[varId];}
00041    StructureType getStructureType()const{return _structureType;}
00042 #ifdef TRWS_DEBUG_OUTPUT
00043    void PrintTestData(std::ostream& fout)const;
00044    void PrintVariableDecompositionConsistency(std::ostream& fout)const;
00045 #endif
00046 
00047 private:
00048    void _InitSubModels();
00049    const GM& _gm;
00050    StructureType  _structureType;
00051    std::vector<SubModel*> _subModels;
00052    std::vector<SubVariableListType> _variableDecomposition;
00053    VariableToFactorMap _var2FactorMap;
00054 };
00055 
/*
 * VisitorWrapper: adapts an opengm-style visitor, whose callbacks take the
 * inference object as first argument, to a callback interface that only
 * passes (value, bound). The wrapped inference object is supplied on every
 * call. Neither pointer is owned.
 */
template<class VISITOR, class INFERENCE_TYPE>
class VisitorWrapper
{
public:
   typedef VISITOR VisitorType;
   typedef INFERENCE_TYPE InferenceType;
   typedef typename InferenceType::ValueType ValueType;

   VisitorWrapper(VISITOR* pvisitor,INFERENCE_TYPE* pinference)
      : visitor_(pvisitor)
      , inference_(pinference)
   {
   }

   /// Forwards to visitor.begin(inference, value, bound).
   void begin(ValueType value,ValueType bound)
   {
      VisitorType& target = *visitor_;
      target.begin(*inference_,value,bound);
   }

   /// Forwards to visitor.end(inference, value, bound).
   void end(ValueType value,ValueType bound)
   {
      VisitorType& target = *visitor_;
      target.end(*inference_,value,bound);
   }

   /// Forwards to visitor(inference, value, bound).
   void operator() (ValueType value,ValueType bound)
   {
      VisitorType& target = *visitor_;
      target(*inference_,value,bound);
   }

private:
   VisitorType* visitor_;      // not owned
   InferenceType* inference_;  // not owned
};
00074 
/*
 * Common TRW-S solver settings.
 *  maxIternum                 - upper limit on the number of iterations
 *  precision                  - tolerance for the duality-gap test
 *  absolutePrecision          - true: absolute gap, false: gap relative to the dual value
 *  minRelativeDualImprovement - convergence threshold on the dual improvement (negative disables it)
 *  fastComputations           - forwarded to the sub-solvers
 */
template<class ValueType>
struct TRWSPrototype_Parameters
{
   size_t maxNumberOfIterations_;
   ValueType precision_;
   bool absolutePrecision_;//true for absolute precision, false for relative w.r.t. dual value
   ValueType minRelativeDualImprovement_;
   bool fastComputations_;

   TRWSPrototype_Parameters(size_t maxIternum,
                            ValueType precision=1.0,
                            bool absolutePrecision=true,
                            ValueType minRelativeDualImprovement=-1.0,
                            bool fastComputations=true)
      : maxNumberOfIterations_      (maxIternum)
      , precision_                  (precision)
      , absolutePrecision_          (absolutePrecision)
      , minRelativeDualImprovement_ (minRelativeDualImprovement)
      , fastComputations_           (fastComputations)
   {
   }
};
00096 
00097 template<class GM>
00098 class PreviousFactorTable
00099 {
00100 public:
00101    typedef typename GM::IndexType IndexType;
00102    typedef SequenceStorage<GM> Storage;
00103    typedef typename Storage::MoveDirection MoveDirection;
00104    struct FactorVarID
00105    {
00106       FactorVarID(){};
00107       FactorVarID(IndexType fID,IndexType vID,IndexType lID):
00108          factorId(fID),varId(vID),localId(lID){};
00109 
00110 #ifdef TRWS_DEBUG_OUTPUT
00111       void print(std::ostream& out)const{out <<"("<<factorId<<","<<varId<<","<<localId<<"),";}
00112 #endif
00113 
00114       IndexType factorId;
00115       IndexType varId;
00116       IndexType localId;//local index of varId
00117    };
00118    typedef std::vector<FactorVarID> FactorList;
00119    typedef typename FactorList::const_iterator const_iterator;
00120 
00121    PreviousFactorTable(const GM& gm);
00122    const_iterator begin(IndexType varId,MoveDirection md)const{return (md==Storage::Direct ? _forwardFactors[varId].begin() : _backwardFactors[varId].begin());}
00123    const_iterator end(IndexType varId,MoveDirection md)const{return (md==Storage::Direct ? _forwardFactors[varId].end() : _backwardFactors[varId].end());}
00124 #ifdef TRWS_DEBUG_OUTPUT
00125    void PrintTestData(std::ostream& fout);
00126 #endif
00127 private:
00128    std::vector<FactorList> _forwardFactors;
00129    std::vector<FactorList> _backwardFactors;
00130 };
00131 
00132 template<class GM>
00133 PreviousFactorTable<GM>::PreviousFactorTable(const GM& gm):
00134 _forwardFactors(gm.numberOfVariables()),
00135 _backwardFactors(gm.numberOfVariables())
00136 {
00137  std::vector<IndexType> varIDs(2);
00138  for (IndexType factorId=0;factorId<gm.numberOfFactors();++factorId)
00139  {
00140    switch (gm[factorId].numberOfVariables())
00141    {
00142     case 1: break;
00143     case 2:
00144        gm[factorId].variableIndices(varIDs.begin());
00145        if (varIDs[0] < varIDs[1])
00146        {
00147           _forwardFactors[varIDs[1]].push_back(FactorVarID(factorId,varIDs[0],0));
00148           _backwardFactors[varIDs[0]].push_back(FactorVarID(factorId,varIDs[1],1));
00149        }
00150        else
00151        {
00152           _forwardFactors[varIDs[0]].push_back(FactorVarID(factorId,varIDs[1],1));
00153           _backwardFactors[varIDs[1]].push_back(FactorVarID(factorId,varIDs[0],0));
00154        }
00155        break;
00156     default: throw std::runtime_error("PreviousFactor::PreviousFactor: only the factors of order <=2 are supported!");
00157    }
00158  }
00159 }
00160 
#ifdef TRWS_DEBUG_OUTPUT
/*
 * Debug dump: prints the forward table, then the backward table. Each
 * variable's entries go on one line.
 */
template<class GM>
void PreviousFactorTable<GM>::PrintTestData(std::ostream& fout)
{
   const char* const headers[2]={"Forward factors:","Backward factors:"};
   const std::vector<FactorList>* const tables[2]={&_forwardFactors,&_backwardFactors};

   for (size_t t=0;t<2;++t)
   {
      fout << headers[t]<<std::endl;
      const std::vector<FactorList>& table=*tables[t];
      for (size_t varId=0;varId<table.size();++varId)
      {
         fout << "varId="<<varId<<", ";
         const FactorList& entries=table[varId];
         for (typename FactorList::const_iterator it=entries.begin();it!=entries.end();++it)
            it->print(fout);
         fout <<std::endl;
      }
   }
}
#endif
00184 
00185 template <class SubSolver>
00186 class TRWSPrototype
00187 {
00188 public:
00189    typedef typename SubSolver::GMType GM;//TODO: remove me
00190    typedef GM GraphicalModelType;
00191    typedef typename SubSolver::ACCType ACC;//TODO: remove me
00192    typedef ACC AccumulationType;
00193    typedef SubSolver SubSolverType;
00194    typedef FunctionParameters<GM> FactorProperties;
00195    typedef opengm::EmptyVisitor< TRWSPrototype<SubSolverType> >  EmptyVisitorParent;
00196    typedef VisitorWrapper<EmptyVisitorParent,TRWSPrototype<SubSolver>  > EmptyVisitorType;
00197 
00198    typedef typename SubSolver::const_iterators_pair const_marginals_iterators_pair;
00199    typedef typename GM::ValueType ValueType;
00200    typedef typename GM::IndexType IndexType;
00201    typedef typename GM::LabelType LabelType;
00202    typedef opengm::InferenceTermination InferenceTermination;
00203    typedef typename std::vector<ValueType> OutputContainerType;
00204    typedef typename OutputContainerType::iterator OutputIteratorType;//TODO: make a template parameter
00205 
00206    typedef TRWSPrototype_Parameters<ValueType> Parameters;
00207 
00208    typedef SequenceStorage<GM> SubModel;
00209    typedef DecompositionStorage<GM> Storage;
00210    typedef typename Storage::UnaryFactor UnaryFactor;
00211 
00212    TRWSPrototype(Storage& storage,const Parameters& params
00213 #ifdef TRWS_DEBUG_OUTPUT
00214          ,std::ostream& fout=std::cout
00215 #endif
00216          );
00217    virtual ~TRWSPrototype();
00218 
00219    virtual ValueType GetBestIntegerBound()const{return _bestIntegerBound;};
00220    virtual ValueType value()const{return _bestIntegerBound;}
00221    virtual ValueType bound()const{return _dualBound;}
00222    virtual const std::vector<LabelType>& arg()const{return _bestIntegerLabeling;}
00223 
00224 #ifdef TRWS_DEBUG_OUTPUT
00225    virtual void PrintTestData(std::ostream& fout)const;
00226 #endif
00227 
00228    bool CheckDualityGap(ValueType primalBound,ValueType dualBound);
00229    virtual std::pair<ValueType,ValueType> GetMarginals(IndexType variable, OutputIteratorType begin){return std::make_pair((ValueType)0,(ValueType)0);};
00230    void GetMarginalsMove();
00231    void BackwardMove();//optimization move, also estimates a primal bound
00232 
00233    ValueType getBound(size_t i)const{return _subSolvers[i]->GetObjectiveValue();}
00234    virtual InferenceTermination infer(){EmptyVisitorParent vis; EmptyVisitorType visitor(&vis,this);  return infer(visitor);};
00235    template<class VISITOR> InferenceTermination infer(VISITOR&);
00236    void ForwardMove();
00237    ValueType lastDualUpdate()const{return _lastDualUpdate;}
00238 
00239    template<class VISITOR> InferenceTermination infer_visitor_updates(VISITOR&);
00240    InferenceTermination core_infer(){EmptyVisitorParent vis; EmptyVisitorType visitor(&vis,this);  return _core_infer(visitor);};
00241    const FactorProperties& getFactorProperties()const{return _factorProperties;}
00242 protected:
00243    void _EstimateIntegerLabeling();
00244    template <class VISITOR> InferenceTermination _core_infer(VISITOR&);
00245    virtual ValueType _GetPrimalBound(){_EvaluateIntegerBounds(); return GetBestIntegerBound();}
00246    virtual void _postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end)=0;
00247    virtual void _normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver)=0;
00248    void _EvaluateIntegerBounds();
00249 
00250    /*
00251     * Integer labeling computation functions
00252     */
00253    virtual void _SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair)=0;
00254    void _EstimateIntegerLabel(IndexType varId,const std::vector<ValueType>& sumMarginal)
00255    {_integerLabeling[varId]=std::max_element(sumMarginal.begin(),sumMarginal.end(),ACC::template ibop<ValueType>)-sumMarginal.begin();}
00256 
00257    void _InitSubSolvers();
00258    void _ForwardMove();
00259    void _FinalizeMove();
00260    ValueType _GetObjectiveValue();
00261    IndexType _order(IndexType i);
00262    IndexType _core_order(IndexType i,IndexType totalSize);
00263    bool _CheckConvergence(ValueType relativeThreshold);
00264    virtual bool _CheckStoppingCondition(InferenceTermination* pterminationCode);
00265    virtual void _EstimateTRWSBound(){};
00266 
00267    virtual void _InitMove()=0;
00268 
00269    Storage&    _storage;
00270    FactorProperties _factorProperties;
00271    PreviousFactorTable<GM> _ftable;
00272    Parameters _parameters;
00273 
00274 #ifdef TRWS_DEBUG_OUTPUT
00275    std::ostream& _fout;
00276 #endif
00277 
00278    ValueType _dualBound;
00279    ValueType _oldDualBound;
00280    ValueType _lastDualUpdate;
00281 
00282    typename SubModel::MoveDirection _moveDirection;
00283    std::vector<SubSolver*> _subSolvers;
00284 
00285    std::vector<std::vector<ValueType> > _marginals;
00286 
00287    ValueType _integerBound;
00288    ValueType _bestIntegerBound;
00289 
00290    std::vector<LabelType> _integerLabeling;
00291    std::vector<LabelType> _bestIntegerLabeling;
00292 
00293    /* Computation optimization */
00294    std::vector<ValueType> _sumMarginal;
00295    mutable typename FactorProperties::ParameterStorageType _factorParameters;
00296 
00297 private:
00298    TRWSPrototype(TRWSPrototype&);
00299    TRWSPrototype& operator =(TRWSPrototype&);
00300 };
00301 
00302 template<class ValueType>
00303 struct SumProdTRWS_Parameters : public TRWSPrototype_Parameters<ValueType>
00304 {
00305    typedef TRWSPrototype_Parameters<ValueType> parent;
00306    ValueType smoothingValue_;
00307    SumProdTRWS_Parameters(size_t maxIternum,
00308             ValueType smValue,
00309             ValueType precision=1.0,
00310             bool absolutePrecision=true,
00311             ValueType minRelativeDualImprovement=2*std::numeric_limits<ValueType>::epsilon(),
00312             bool fastComputations=true)
00313    :parent(maxIternum,precision,absolutePrecision,minRelativeDualImprovement,fastComputations),
00314     smoothingValue_(smValue){};
00315 };
00316 
00317 template<class GM,class ACC>
00318 class SumProdTRWS : public TRWSPrototype<SumProdSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> >
00319 {
00320 public:
00321    typedef TRWSPrototype<SumProdSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> > parent;
00322    typedef ACC AccumulationType;
00323    typedef GM GraphicalModelType;
00324    typedef typename parent::SubSolverType SubSolver;
00325    typedef typename parent::const_marginals_iterators_pair const_marginals_iterators_pair;
00326    typedef typename parent::ValueType ValueType;
00327    typedef typename parent::IndexType IndexType;
00328    typedef typename parent::LabelType LabelType;
00329    typedef typename parent::InferenceTermination InferenceTermination;
00330    typedef SequenceStorage<GM> SubModel;
00331    typedef DecompositionStorage<GM> Storage;
00332    typedef typename parent::OutputContainerType OutputContainerType;
00333    typedef typename OutputContainerType::iterator OutputIteratorType;
00334 
00335    typedef SumProdTRWS_Parameters<ValueType> Parameters;
00336 
00337    SumProdTRWS(Storage& storage,const Parameters& params
00338 #ifdef TRWS_DEBUG_OUTPUT
00339          ,std::ostream& fout=std::cout
00340 #endif
00341          ):
00342       parent(storage,params
00343 #ifdef TRWS_DEBUG_OUTPUT
00344             ,fout
00345 #endif
00346       ),
00347       _smoothingValue(params.smoothingValue_)
00348       {};
00349    ~SumProdTRWS(){};
00350 
00351 #ifdef TRWS_DEBUG_OUTPUT
00352    void PrintTestData(std::ostream& fout)const;
00353 #endif
00354 
00355    void SetSmoothing(ValueType smoothingValue){_smoothingValue=smoothingValue;_InitMove();}
00356    ValueType GetSmoothing()const{return _smoothingValue;}
00357    /*
00358     * returns "averaged" over subsolvers marginals
00359     * and pair of (ell_2 norm,ell_infty norm)
00360     */
00361    std::pair<ValueType,ValueType> GetMarginals(IndexType variable, OutputIteratorType begin);
00362    ValueType GetMarginalsAndDerivativeMove();
00363    ValueType getDerivative(size_t i)const{return parent::_subSolvers[i]->getDerivative();}
00364 
00365 protected:
00366    void _SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair);
00367    void _postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end);
00368    void _normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver);
00369    void _InitMove();
00370    //bool _CheckConvergence();
00371    //bool _CheckStoppingCondition(InferenceTermination* pterminationCode);
00372    ValueType _smoothingValue;
00373 };
00374 
00375 template<class ValueType>
00376 struct MaxSumTRWS_Parameters : public TRWSPrototype_Parameters<ValueType>
00377 {
00378    typedef TRWSPrototype_Parameters<ValueType> parent;
00379    MaxSumTRWS_Parameters(size_t maxIternum,
00380             ValueType precision=1.0,
00381             bool absolutePrecision=true,
00382             ValueType minRelativeDualImprovement=-1.0,
00383             bool fastComputations=true,
00384             bool canonicalNormalization=false):
00385       parent(maxIternum,precision,absolutePrecision,minRelativeDualImprovement,fastComputations),
00386       canonicalNormalization_(canonicalNormalization){};
00387 
00388    bool canonicalNormalization_;
00389 };
00390 
00391 template<class GM,class ACC>
00392 class MaxSumTRWS : public TRWSPrototype<MaxSumSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> >
00393 {
00394 public:
00395    typedef TRWSPrototype<MaxSumSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> > parent;
00396    //typedef typename parent::Parameters Parameters;
00397    typedef typename parent::SubSolverType SubSolver;
00398    typedef typename parent::const_marginals_iterators_pair const_marginals_iterators_pair;
00399    typedef typename parent::ValueType ValueType;
00400    typedef typename parent::IndexType IndexType;
00401    typedef typename parent::LabelType LabelType;
00402    typedef typename parent::InferenceTermination InferenceTermination;
00403    typedef typename parent::EmptyVisitorType EmptyVisitorType;
00404    typedef typename parent::UnaryFactor UnaryFactor;
00405    typedef ACC AccumulationType;
00406    typedef GM GraphicalModelType;
00407    typedef typename parent::OutputContainerType OutputContainerType;
00408 
00409    typedef SequenceStorage<GM> SubModel;
00410    typedef DecompositionStorage<GM> Storage;
00411 
00412    typedef MaxSumTRWS_Parameters<ValueType> Parameters;
00413 
00414    MaxSumTRWS(Storage& storage,const Parameters& params
00415 #ifdef TRWS_DEBUG_OUTPUT
00416          ,std::ostream& fout=std::cout
00417 #endif
00418    ):
00419       parent(storage,params
00420 #ifdef TRWS_DEBUG_OUTPUT
00421             ,fout
00422 #endif
00423             ),
00424       _canonicalNormalization(params.canonicalNormalization_),
00425       _pseudoBoundValue(0.0),
00426       _localConsistencyCounter(0)
00427    {}
00428    ~MaxSumTRWS(){};
00429 
00430    void getTreeAgreement(std::vector<bool>& out,std::vector<LabelType>* plabeling=0);
00431    bool CheckTreeAgreement(InferenceTermination* pterminationCode);
00432 protected:
00433    void _SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair);
00434    void _postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end);
00435    void _normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver);
00436    void _InitMove();
00437    void _EstimateTRWSBound();
00438    bool _CheckStoppingCondition(InferenceTermination* pterminationCode);
00439 
00440    bool _canonicalNormalization;
00441    ValueType _pseudoBoundValue;
00442    size_t _localConsistencyCounter;
00443    /*
00444     * computaton optimization
00445     */
00446    std::vector<bool> _treeAgree;
00447    std::vector<bool> _mask;
00448    std::vector<bool> _nodeMask;
00449 };
00450 //============ TRWSPrototype IMPLEMENTATION ======================================
00451 
00452 template <class SubSolver>
00453 TRWSPrototype<SubSolver>::TRWSPrototype(Storage& storage,const Parameters& params
00454 #ifdef TRWS_DEBUG_OUTPUT
00455       ,std::ostream& fout
00456 #endif
00457 ):
00458 _storage(storage),
00459 _factorProperties(storage.masterModel()),
00460 _ftable(storage.masterModel()),
00461 _parameters(params),
00462 #ifdef TRWS_DEBUG_OUTPUT
00463 _fout(fout),
00464 #endif
00465 _dualBound(ACC::template ineutral<ValueType>()),
00466 _oldDualBound(ACC::template ineutral<ValueType>()),
00467 _lastDualUpdate(0),
00468 _moveDirection(SubModel::Direct),
00469 _subSolvers(),
00470 _marginals(),
00471 _integerBound(ACC::template neutral<ValueType>()),
00472 _bestIntegerBound(ACC::template neutral<ValueType>()),
00473 _integerLabeling(storage.masterModel().numberOfVariables(),0),
00474 _bestIntegerLabeling(storage.masterModel().numberOfVariables(),0),
00475 _sumMarginal()
00476 {
00477 #ifdef TRWS_DEBUG_OUTPUT
00478    _fout.precision(16);
00479 #endif
00480    _InitSubSolvers();
00481    _marginals.resize(_storage.numberOfModels());
00482 #ifdef TRWS_DEBUG_OUTPUT
00483    _factorProperties.PrintStatusData(fout);
00484 #endif
00485 }
00486 
00487 template <class SubSolver>
00488 TRWSPrototype<SubSolver>::~TRWSPrototype()
00489 {
00490    for_each(_subSolvers.begin(),_subSolvers.end(),DeallocatePointer<SubSolver>);
00491 };
00492 
00493 template <class SubSolver>
00494 void TRWSPrototype<SubSolver>::_InitSubSolvers()
00495 {
00496    _subSolvers.resize(_storage.numberOfModels());
00497    for (size_t modelId=0;modelId<_subSolvers.size();++modelId)
00498       _subSolvers[modelId]= new SubSolver(_storage.subModel(modelId),_factorProperties,_parameters.fastComputations_);
00499 }
00500 
00501 template <class SubSolver>
00502 bool TRWSPrototype<SubSolver>::CheckDualityGap(ValueType primalBound,ValueType dualBound)
00503 {
00504    //TODO: check that primal bound > dualBound if (bop(primalBound,dualBound)
00505 
00506    if (_parameters.absolutePrecision_)
00507    {
00508       if (fabs(primalBound-dualBound) <= _parameters.precision_)
00509       {
00510          return true;
00511       }
00512    }
00513    else
00514    {
00515       if (fabs((primalBound-dualBound)/dualBound)<= _parameters.precision_ )
00516          return true;
00517    }
00518    return false;
00519 }
00520 
00521 template <class SubSolver>
00522 bool TRWSPrototype<SubSolver>::_CheckConvergence(ValueType relativeThreshold)
00523 {
00524    if (relativeThreshold >=0.0)
00525    {
00526    ValueType mul; ACC::iop(-1.0,1.0,mul);
00527    if (ACC::bop(_dualBound, (_oldDualBound + _dualBound*mul*relativeThreshold)))
00528       return true;
00529    }
00530    return false;
00531 }
00532 
00533 template <class SubSolver>
00534 bool TRWSPrototype<SubSolver>::_CheckStoppingCondition(InferenceTermination* pterminationCode)
00535 {
00536    _lastDualUpdate=fabs(_dualBound-_oldDualBound);
00537 
00538    if (CheckDualityGap(_bestIntegerBound,_dualBound))
00539    {
00540 #ifdef TRWS_DEBUG_OUTPUT
00541       _fout << "TRWSPrototype::_CheckStoppingCondition(): duality gap <= specified precision!" <<std::endl;
00542 #endif
00543       *pterminationCode=opengm::CONVERGENCE;
00544       return true;
00545    }
00546 
00547    if (_CheckConvergence(_parameters.minRelativeDualImprovement_))
00548    {
00549 #ifdef TRWS_DEBUG_OUTPUT
00550       _fout << "TRWSPrototype::_CheckStoppingCondition(): Dual update is smaller than the specified threshold. Stopping"<<std::endl;
00551 #endif
00552       *pterminationCode=opengm::NORMAL;
00553       return true;
00554    }
00555 
00556    _oldDualBound=_dualBound;
00557 
00558    return false;
00559 }
00560 
00561 template <class SubSolver>
00562 template <class VISITOR>
00563 typename TRWSPrototype<SubSolver>::InferenceTermination TRWSPrototype<SubSolver>::_core_infer(VISITOR& visitor)
00564 {
00565    for (size_t iterationCounter=0;iterationCounter<_parameters.maxNumberOfIterations_;++iterationCounter)
00566    {
00567 #ifdef TRWS_DEBUG_OUTPUT
00568       _fout <<"Iteration Nr."<<iterationCounter<<"-------------------------------------"<<std::endl;
00569 #endif
00570 
00571       BackwardMove();
00572 
00573 #ifdef TRWS_DEBUG_OUTPUT
00574       _fout << "dualBound=" << _dualBound <<", primalBound="<<_GetPrimalBound() <<std::endl;
00575 #endif
00576       _EstimateTRWSBound();
00577       visitor(value(),bound());
00578       InferenceTermination returncode;
00579       if (_CheckStoppingCondition(&returncode))
00580           return returncode;
00581    }
00582    return opengm::TIMEOUT;
00583 }
00584 
00585 template <class SubSolver>
00586 typename TRWSPrototype<SubSolver>::ValueType TRWSPrototype<SubSolver>::_GetObjectiveValue()
00587 {
00588    ValueType   dualBound=0;
00589    for (size_t i=0;i<_subSolvers.size();++i)
00590       dualBound+=_subSolvers[i]->GetObjectiveValue();
00591 
00592    return dualBound;
00593 }
00594 
00595 template <class SubSolver>
00596 void TRWSPrototype<SubSolver>::_ForwardMove()
00597 {
00598    std::for_each(_subSolvers.begin(), _subSolvers.end(), std::mem_fun(&SubSolver::Move));
00599    _moveDirection=SubModel::ReverseDirection(_moveDirection);
00600    _dualBound=_GetObjectiveValue();
00601 }
00602 
00603 template <class SubSolver>
00604 void TRWSPrototype<SubSolver>::GetMarginalsMove()
00605 {
00606    std::for_each(_subSolvers.begin(), _subSolvers.end(), std::mem_fun(&SubSolver::MoveBack));
00607    _moveDirection=SubModel::ReverseDirection(_moveDirection);
00608 }
00609 
00610 template <class SubSolver>
00611 typename TRWSPrototype<SubSolver>::IndexType TRWSPrototype<SubSolver>::_core_order(IndexType i,IndexType totalSize)
00612 {
00613    return (_moveDirection==SubModel::Direct ? i : totalSize-i-1);
00614 }
00615 
00616 template <class SubSolver>
00617 typename TRWSPrototype<SubSolver>::IndexType TRWSPrototype<SubSolver>::_order(IndexType i)
00618 {
00619    return _core_order(i,_storage.numberOfSharedVariables());
00620 }
00621 
00622 template <class SubSolver>
00623 void TRWSPrototype<SubSolver>::_FinalizeMove()
00624 {
00625    std::for_each(_subSolvers.begin(), _subSolvers.end(), std::mem_fun(&SubSolver::FinalizeMove));
00626    _moveDirection=SubModel::ReverseDirection(_moveDirection);
00627    _EstimateIntegerLabeling();
00628 }
00629 
#ifdef TRWS_DEBUG_OUTPUT
/*
 * Debug dump of the bound bookkeeping and both labelings.
 * Printing the labelings relies on an operator<< for std::vector provided elsewhere.
 */
template <class SubSolver>
void TRWSPrototype<SubSolver>::PrintTestData(std::ostream& fout)const
{
   fout << "_dualBound:" << _dualBound <<std::endl
        << "_oldDualBound:" << _oldDualBound <<std::endl
        << "_lastDualUpdate=" << _lastDualUpdate << std::endl
        << "_moveDirection:" << _moveDirection <<std::endl
        << "_integerBound=" << _integerBound << std::endl
        << "_bestIntegerBound=" << _bestIntegerBound << std::endl
        << "_integerLabeling=" << _integerLabeling;
   fout << "_bestIntegerLabeling=" << _bestIntegerLabeling;
}
#endif
00644 
00645 //template <class SubSolver>
00646 //template <class VISITOR>
00647 //typename TRWSPrototype<SubSolver>::InferenceTermination TRWSPrototype<SubSolver>::infer(VISITOR& visitor)
00648 //{
00649 // _InitMove();
00650 // _ForwardMove();
00651 // _oldDualBound=_dualBound;
00652 // visitor.begin(value(),bound());
00653 //#ifdef TRWS_DEBUG_OUTPUT
00654 // _fout << "ForwardMove: dualBound=" << _dualBound <<std::endl;
00655 //#endif
00656 // InferenceTermination returncode;
00657 // returncode=_core_infer(visitor);
00658 // visitor.end(value(), bound());
00659 // return returncode;
00660 //}
00661 
00662 //template <class SubSolver>
00663 //template <class VISITOR>
00664 //typename TRWSPrototype<SubSolver>::InferenceTermination TRWSPrototype<SubSolver>::infer(VISITOR& visitor)
00665 //{
00666 // visitor.begin(value(),bound());
00667 // _InitMove();
00668 // _ForwardMove();
00669 // visitor(value(),bound());
00670 // _oldDualBound=_dualBound;
00671 //#ifdef TRWS_DEBUG_OUTPUT
00672 // _fout << "ForwardMove: dualBound=" << _dualBound <<std::endl;
00673 //#endif
00674 // InferenceTermination returncode;
00675 // returncode=_core_infer(visitor);
00676 // visitor.end(value(), bound());
00677 // return returncode;
00678 //}
00679 template <class SubSolver>
00680 template <class VISITOR>
00681 typename TRWSPrototype<SubSolver>::InferenceTermination TRWSPrototype<SubSolver>::infer(VISITOR& visitor)
00682 {
00683    visitor.begin(value(),bound());
00684    InferenceTermination returncode=infer_visitor_updates(visitor);
00685    visitor.end(value(), bound());
00686    return returncode;
00687 }
00688 
00689 template <class SubSolver>
00690 template <class VISITOR>
00691 typename TRWSPrototype<SubSolver>::InferenceTermination TRWSPrototype<SubSolver>::infer_visitor_updates(VISITOR& visitor)
00692 {
00693    _InitMove();
00694    _ForwardMove();
00695    visitor(value(),bound());
00696    _oldDualBound=_dualBound;
00697 #ifdef TRWS_DEBUG_OUTPUT
00698    _fout << "ForwardMove: dualBound=" << _dualBound <<std::endl;
00699 #endif
00700    InferenceTermination returncode;
00701    returncode=_core_infer(visitor);
00702    return returncode;
00703 }
00704 
00705 template <class SubSolver>
00706 void TRWSPrototype<SubSolver>::ForwardMove()
00707 {
00708    _InitMove();
00709    _ForwardMove();
00710    _dualBound=_GetObjectiveValue();
00711 }
00712 
00713 
00714 template <class SubSolver>
00715 void TRWSPrototype<SubSolver>::BackwardMove()
00716 {
00717    std::vector<ValueType> averageMarginal;
00718 
00719    for (IndexType i=0;i<_storage.numberOfSharedVariables();++i)
00720    {
00721       IndexType varId=_order(i);
00722       const typename Storage::SubVariableListType& varList=_storage.getSubVariableList(varId);
00723       averageMarginal.assign(_storage.numberOfLabels(varId),0.0);
00724 
00725       //<!computing average marginals
00726       for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
00727       {
00728          SubSolver& subSolver=*_subSolvers[modelIt->subModelId_];
00729          std::vector<ValueType>& marginals=_marginals[modelIt->subModelId_];
00730          marginals.resize(_storage.numberOfLabels(varId));
00731 
00732          IndexType startNodeIndex=_core_order(0,_storage.size(modelIt->subModelId_));
00733 
00734          if (modelIt->subVariableId_!=startNodeIndex)
00735             subSolver.PushBack();
00736 
00737          typename SubSolver::const_iterators_pair marginalsit=subSolver.GetMarginals();
00738 
00739          std::copy(marginalsit.first,marginalsit.second,marginals.begin());
00740          _normalizeMarginals(marginals.begin(),marginals.end(),&subSolver);
00741          std::transform(marginals.begin(),marginals.end(),averageMarginal.begin(),averageMarginal.begin(),std::plus<ValueType>());
00742       }
00743       transform_inplace(averageMarginal.begin(),averageMarginal.end(),std::bind1st(std::multiplies<ValueType>(),-1.0/varList.size()));
00744 
00745 
00746       //<!reweighting submodels
00747 
00748       for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
00749       {
00750          SubSolver& subSolver=*_subSolvers[modelIt->subModelId_];
00751          std::vector<ValueType>& marginals=_marginals[modelIt->subModelId_];
00752 
00753          std::transform(marginals.begin(),marginals.end(),averageMarginal.begin(),marginals.begin(),std::plus<ValueType>());
00754 
00755          _postprocessMarginals(marginals.begin(),marginals.end());
00756 
00757          subSolver.IncreaseUnaryWeights(marginals.begin(),marginals.end());
00758 
00759          IndexType startNodeIndex=_core_order(0,_storage.size(modelIt->subModelId_));
00760 
00761          if (modelIt->subVariableId_!=startNodeIndex)
00762             subSolver.UpdateMarginals();
00763              else subSolver.InitReverseMove();
00764       }
00765    }
00766 
00767    _FinalizeMove();
00768    _EvaluateIntegerBounds();
00769    _dualBound=_GetObjectiveValue();
00770 }
00771 
00772 template <class SubSolver>
00773 void TRWSPrototype<SubSolver>::_EstimateIntegerLabeling()
00774 {
00775    for (IndexType i=0;i<_storage.numberOfSharedVariables();++i)
00776    {
00777       IndexType varId=_order(i);
00778 
00779       const typename Storage::SubVariableListType& varList=_storage.getSubVariableList(varId);
00780       _sumMarginal.assign(_storage.masterModel().numberOfLabels(varId),0.0);
00781       for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
00782       {
00783        const_marginals_iterators_pair itpair=_subSolvers[modelIt->subModelId_]->GetMarginals(modelIt->subVariableId_);
00784        _SumUpForwardMarginals(&_sumMarginal,itpair);
00785       }
00786 
00787        typename PreviousFactorTable<GM>::const_iterator begin=_ftable.begin(varId,_moveDirection);
00788        typename PreviousFactorTable<GM>::const_iterator end=_ftable.end(varId,_moveDirection);
00789       for (;begin!=end;++begin)
00790       {
00791        if ((_factorProperties.getFunctionType(begin->factorId)==FunctionParameters<GM>::POTTS) && _parameters.fastComputations_)
00792        {
00793           _sumMarginal[_integerLabeling[begin->varId]]-=_factorProperties.getFunctionParameters(begin->factorId)[0];//instead of adding everywhere the same we just subtract the difference
00794        }else
00795        {
00796        const typename GM::FactorType& pwfactor=_storage.masterModel()[begin->factorId];
00797        IndexType localVarIndx = begin->localId;
00798        LabelType fixedLabel=_integerLabeling[begin->varId];
00799 
00800          opengm::ViewFixVariablesFunction<GM> pencil(pwfactor,
00801                std::vector<opengm::PositionAndLabel<IndexType,LabelType> >(1,
00802                      opengm::PositionAndLabel<IndexType,LabelType>(localVarIndx,
00803                            fixedLabel)));
00804 
00805          for (LabelType j=0;j<_sumMarginal.size();++j)
00806             _sumMarginal[j]+=pencil(&j);
00807        }
00808       }
00809       _EstimateIntegerLabel(varId,_sumMarginal);
00810    }
00811 }
00812 
00813 template <class SubSolver>
00814 void TRWSPrototype<SubSolver>::_EvaluateIntegerBounds()
00815 {
00816    _integerBound=_storage.masterModel().evaluate(_integerLabeling.begin());
00817 
00818    if (ACC::bop(_integerBound,_bestIntegerBound))
00819    {
00820       _bestIntegerLabeling=_integerLabeling;
00821       _bestIntegerBound=_integerBound;
00822    }
00823 
00824 }
00825 
00826 //================================= DecompositionStorage IMPLEMENTATION =================================================
00827 template<class GM>
00828 DecompositionStorage<GM>::DecompositionStorage(const GM& gm,StructureType structureType):
00829 _gm(gm),
00830 _structureType(structureType),
00831 _subModels(),
00832 _variableDecomposition(),
00833 _var2FactorMap(gm)
00834 {
00835    _InitSubModels();
00836 }
00837 
00838 template<class GM>
00839 DecompositionStorage<GM>::~DecompositionStorage()
00840 {
00841    for_each(_subModels.begin(),_subModels.end(),DeallocatePointer<SubModel>);
00842 }
00843 
00844 template<class GM>
00845 void DecompositionStorage<GM>::_InitSubModels()
00846 {
00847    std::auto_ptr<Decomposition<GM> > pdecomposition;
00848 
00849    if (_structureType==GRIDSTRUCTURE)
00850       pdecomposition=std::auto_ptr<Decomposition<GM> >(new GridDecomposition<GM>(_gm));
00851    else
00852       pdecomposition=std::auto_ptr<Decomposition<GM> >(new MonotoneChainsDecomposition<GM>(_gm));
00853 
00854    try{
00855       pdecomposition->ComputeVariableDecomposition(&_variableDecomposition);
00856       size_t numberOfModels=pdecomposition->getNumberOfSubModels();
00857       _subModels.resize(numberOfModels);
00858       for (size_t modelId=0;modelId<numberOfModels;++modelId)
00859       {
00860          const typename SubModel::IndexList& varList=pdecomposition->getVariableList(modelId);
00861          typename SubModel::IndexList numOfSubModelsPerVar(varList.size());
00862          // Initialize numOfSubModelsPerVar
00863          for (size_t varIndx=0;varIndx<varList.size();++varIndx)
00864             numOfSubModelsPerVar[varIndx]=_variableDecomposition[varList[varIndx]].size();
00865 
00866          _subModels[modelId]= new SubModel(_gm,_var2FactorMap,varList,pdecomposition->getFactorList(modelId),numOfSubModelsPerVar);
00867       };
00868    }catch(std::runtime_error& err)
00869    {
00870       throw err;
00871    }
00872 };
00873 
#ifdef TRWS_DEBUG_OUTPUT
// Debug helper: prints, one line per variable, the list of sub-variables
// (sub-model copies) stored in _variableDecomposition.
template<class GM>
void DecompositionStorage<GM>::PrintTestData(std::ostream& fout)const
{
   fout << "_variableDecomposition: "<<std::endl;
   for (size_t variableId=0;variableId<_variableDecomposition.size();++variableId)
   {
      std::for_each(_variableDecomposition[variableId].begin(),_variableDecomposition[variableId].end(),printSubVariable<typename MonotoneChainsDecomposition<GM>::SubVariable>(fout));
      fout << std::endl;
   }
}

// Debug helper: for every variable, adds up the unary factors of all of its
// sub-model copies, subtracts the values of _gm[varId] label-wise, and prints
// the sum of the differences (expected to be 0 for a consistent decomposition).
template<class GM>
void DecompositionStorage<GM>::PrintVariableDecompositionConsistency(std::ostream& fout)const
{
   fout << "Variable decomposition consistency:" <<std::endl;
   for (size_t varId=0;varId<_gm.numberOfVariables();++varId)
   {
      fout << varId<<": ";
      const SubVariableListType& varList=_variableDecomposition[varId];
      std::vector<ValueType> sum(_gm.numberOfLabels(varId),0.0);
      for (typename SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
      {
         const SubModel& subModel=*_subModels[modelIt->subModelId_];
         // Bind the unaries once: begin() and end() must refer to the same
         // object. With two separate unaryFactors() calls that is only true if
         // it returns a reference (a by-value return would be undefined behaviour).
         const UnaryFactor& unaries=subModel.unaryFactors(modelIt->subVariableId_);
         std::transform(unaries.begin(),unaries.end(),sum.begin(),sum.begin(),std::plus<ValueType>());
      }
      // NOTE(review): assumes factor number varId of _gm is the unary factor of
      // variable varId — confirm for models built in a different order.
      std::vector<ValueType> originalFactor(_gm.numberOfLabels(varId),0.0);
      _gm[varId].copyValues(originalFactor.begin());

      std::transform(sum.begin(),sum.end(),originalFactor.begin(),sum.begin(),std::minus<ValueType>());
      fout << std::accumulate(sum.begin(),sum.end(),(ValueType)0.0)<<std::endl;
   }
}
#endif
00912 //================================= MaxSumTRWS IMPLEMENTATION =================================================
00913 
00914 template<class GM,class ACC>
00915 void MaxSumTRWS<GM,ACC>::_InitMove()
00916 {
00917    parent::_moveDirection=SubModel::Direct;
00918    std::for_each(parent::_subSolvers.begin(), parent::_subSolvers.end(), std::mem_fun_t<void,SubSolver>(&SubSolver::InitMove));
00919 }
00920 
00921 template<class GM,class ACC>
00922 void MaxSumTRWS<GM,ACC>::_postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end)
00923 {
00924    transform_inplace(begin,end,std::bind1st(std::multiplies<ValueType>(),-1.0));
00925 }
00926 
00927 template<class GM,class ACC>
00928 void MaxSumTRWS<GM,ACC>::_SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair)
00929 {
00930    std::transform(itpair.first,itpair.second,pout->begin(),pout->begin(),std::plus<ValueType>());
00931 }
00932 
00933 template<class GM,class ACC>
00934 void MaxSumTRWS<GM,ACC>::_EstimateTRWSBound()
00935 {
00936    if (_canonicalNormalization) return;
00937    std::vector<ValueType> bounds(parent::_storage.numberOfModels());
00938    for (size_t i=0;i<bounds.size();++i)
00939       bounds[i]=parent::_subSolvers[i]->GetObjectiveValue();
00940 
00941    ValueType min=*std::min_element(bounds.begin(),bounds.end());
00942    ValueType max=*std::max_element(bounds.begin(),bounds.end());
00943    ValueType eps; ACC::iop(max-min,min-max,eps);
00944    ACC::iop(min,max,_pseudoBoundValue);
00945 #ifdef TRWS_DEBUG_OUTPUT
00946    parent::_fout <<"min="<<min<<", max="<<max<<", eps="<<eps<<", pseudo bound="<<bounds.size()*_pseudoBoundValue<<std::endl;
00947 #endif
00948 }
00949 
00950 
00951 template<class GM,class ACC>
00952 void MaxSumTRWS<GM,ACC>::_normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver)
00953 {
00954    if (!_canonicalNormalization) return;
00955    ValueType maxVal=*std::max_element(begin,end,ACC::template bop<ValueType>);
00956    transform_inplace(begin,end,std::bind2nd(std::plus<ValueType>(),-maxVal));
00957 }
00958 
00959 template<class GM,class ACC>
00960 void MaxSumTRWS<GM,ACC>::getTreeAgreement(std::vector<bool>& out,std::vector<LabelType>* plabeling)
00961 {
00962    if (plabeling!=0)
00963       plabeling->resize(parent::_storage.masterModel().numberOfVariables());
00964 
00965    out.assign(parent::_storage.masterModel().numberOfVariables(),true);
00966    for (size_t varId=0;varId<parent::_storage.masterModel().numberOfVariables();++varId)
00967    {
00968       const typename Storage::SubVariableListType& varList=parent::_storage.getSubVariableList(varId);
00969       size_t label;
00970       for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin()
00971                                           ;modelIt!=varList.end();++modelIt)
00972       {
00973          size_t check_label=parent::_subSolvers[modelIt->subModelId_]->arg()[modelIt->subVariableId_];
00974 
00975          if (plabeling!=0) (*plabeling)[varId]=check_label;
00976 
00977          if (modelIt==varList.begin())
00978          {
00979             label=check_label;
00980          }else if (check_label!=label)
00981           {
00982             out[varId]=false;
00983             break;
00984           }
00985       }
00986 
00987    }
00988 }
00989 
00990 
00991 
00992 template<class GM,class ACC>
00993 bool MaxSumTRWS<GM,ACC>::CheckTreeAgreement(InferenceTermination* pterminationCode)
00994 {
00995      getTreeAgreement(_treeAgree);
00996      size_t agree_count=count(_treeAgree.begin(),_treeAgree.end(),true);
00997 #ifdef TRWS_DEBUG_OUTPUT
00998      parent::_fout << "tree-agreement: " << agree_count <<" out of "<<_treeAgree.size() <<", ="<<100*(double)agree_count/_treeAgree.size()<<"%"<<std::endl;
00999 #endif
01000 
01001      if (_treeAgree.size()==agree_count)
01002      {
01003 #ifdef TRWS_DEBUG_OUTPUT
01004         parent::_fout <<"Problem solved."<<std::endl;
01005 #endif
01006         *pterminationCode=opengm::CONVERGENCE;
01007         return true;
01008      }else
01009         return false;
01010 }
01011 
01012 
01013 template<class GM,class ACC>
01014 bool MaxSumTRWS<GM,ACC>::_CheckStoppingCondition(InferenceTermination* pterminationCode)
01015 {
01016   if (CheckTreeAgreement(pterminationCode)) return true;
01017 
01018   return parent::_CheckStoppingCondition(pterminationCode);
01019 }
01020 
01021 //================================= SumProdTRWS IMPLEMENTATION =================================================
#ifdef TRWS_DEBUG_OUTPUT
// Debug helper: prints the smoothing value, then the base-class debug data.
template<class GM,class ACC>
void SumProdTRWS<GM,ACC>::PrintTestData(std::ostream& fout)const
{
   fout << "_smoothingValue:" << _smoothingValue << std::endl;
   parent::PrintTestData(fout);
}
#endif
01030 
01031 template<class GM,class ACC>
01032 void SumProdTRWS<GM,ACC>::_InitMove()//(ValueType smoothingValue)
01033 {
01034    parent::_moveDirection=SubModel::Direct;
01035    std::for_each(parent::_subSolvers.begin(), parent::_subSolvers.end(), std::bind2nd(std::mem_fun(&SubSolver::InitMove),_smoothingValue));
01036 }
01037 
01038 template<class GM,class ACC>
01039 void SumProdTRWS<GM,ACC>::_normalizeMarginals(typename std::vector<ValueType>::iterator begin,
01040                                    typename std::vector<ValueType>::iterator end,SubSolver* subSolver)
01041 {
01042    ValueType logPartition=subSolver->ComputeObjectiveValue();
01043    //normalizing marginals - subtracting log-partition function value/smoothing
01044    transform_inplace(begin,end,std::bind2nd(std::plus<ValueType>(),-logPartition/_smoothingValue));
01045 }
01046 
01047 template<class GM,class ACC>
01048 void SumProdTRWS<GM,ACC>::_postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end)
01049 {
01050    transform_inplace(begin,end,std::bind1st(std::multiplies<ValueType>(),-_smoothingValue));
01051 }
01052 
01053 template<class GM,class ACC>
01054 void SumProdTRWS<GM,ACC>::_SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair)
01055 {
01056    std::transform(pout->begin(),pout->end(),itpair.first,pout->begin(),plus2ndMul<ValueType>(_smoothingValue));
01057 }
01058 
01059 template<class GM,class ACC>
01060 std::pair<typename SumProdTRWS<GM,ACC>::ValueType,typename SumProdTRWS<GM,ACC>::ValueType>
01061 SumProdTRWS<GM,ACC>::GetMarginals(IndexType varId, OutputIteratorType begin)
01062 {
01063   std::fill_n(begin,parent::_storage.numberOfLabels(varId),0.0);
01064   const typename Storage::SubVariableListType& varList=parent::_storage.getSubVariableList(varId);
01065 
01066   OPENGM_ASSERT(varList.size()>0);
01067 
01068   for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
01069   {
01070      typename SubSolver::const_iterators_pair marginalsit=parent::_subSolvers[modelIt->subModelId_]->GetMarginals(modelIt->subVariableId_);
01071       std::vector<ValueType>& normMarginals=parent::_marginals[modelIt->subModelId_];
01072       normMarginals.resize(parent::_storage.numberOfLabels(varId));
01073      //normalize
01074      std::copy(marginalsit.first,marginalsit.second,normMarginals.begin());
01075      _normalizeMarginals(normMarginals.begin(),normMarginals.end(),parent::_subSolvers[modelIt->subModelId_]);
01076      ValueType mul; ACC::op(1.0,-1.0,mul);
01077      transform_inplace(normMarginals.begin(),normMarginals.end(),mulAndExp<ValueType>(mul));
01078      std::transform(normMarginals.begin(),normMarginals.end(),begin,begin,std::plus<ValueType>());
01079   }
01080   transform_inplace(begin,begin+parent::_storage.numberOfLabels(varId),std::bind1st(std::multiplies<ValueType>(),1.0/varList.size()));
01081 
01082   ValueType ell2Norm=0, ellInftyNorm=0;
01083   for (typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
01084   {
01085      std::vector<ValueType>& normMarginals=parent::_marginals[modelIt->subModelId_];
01086      OutputIteratorType begin0=begin;
01087      for (typename std::vector<ValueType>::const_iterator bm=normMarginals.begin(); bm!=normMarginals.end();++bm)
01088      {
01089         ValueType diff=(*bm-*begin0); ++begin0;
01090         ell2Norm+=diff*diff;
01091         ellInftyNorm=std::max((ValueType)fabs(diff),ellInftyNorm);
01092      }
01093   }
01094 
01095   return std::make_pair(sqrt(ell2Norm),ellInftyNorm);
01096 }
01097 
01098 template<class GM,class ACC>
01099 typename SumProdTRWS<GM,ACC>::ValueType
01100 SumProdTRWS<GM,ACC>::GetMarginalsAndDerivativeMove()
01101 {
01102    ValueType derivativeValue=0.0;
01103    //std::for_each(parent::_subSolvers.begin(), parent::_subSolvers.end(), std::(&SubSolver::MoveBackGetDerivative()));
01104    for (size_t i=0;i<parent::_subSolvers.size();++i)
01105       derivativeValue+=parent::_subSolvers[i]->MoveBackGetDerivative();
01106 
01107    parent::_moveDirection=SubModel::ReverseDirection(parent::_moveDirection);
01108    return derivativeValue;
01109 }
01110 
01111 };//DD
01112 
01113 #endif /* ADSAL_H_ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
Generated on Mon Jun 17 16:31:06 2013 for OpenGM by  doxygen 1.6.3