#ifndef HORUS_BELIEFPROP_H
#define HORUS_BELIEFPROP_H

#include <set>
#include <vector>

#include <sstream>

#include "GroundSolver.h"
#include "FactorGraph.h"


using namespace std;


/**
 * Message-update schedule used by BeliefProp; selected globally through
 * BeliefProp::setMsgSchedule(). Enumerator values are fixed explicitly.
 */
enum MsgSchedule
{
  SEQ_FIXED    = 0,  // presumably: sequential updates in a fixed order
  SEQ_RANDOM   = 1,  // presumably: sequential updates in a random order
  PARALLEL     = 2,  // presumably: all messages computed, then all updated
  MAX_RESIDUAL = 3   // presumably driven by BeliefProp::maxResidualSchedule()
};


/**
 * Edge between a factor node and a variable node that carries the
 * factor-to-variable message exchanged by belief propagation.
 *
 * Two message buffers are kept: the published one (message()) and a
 * scratch one that callers fill through nextMessage(). updateMessage()
 * swaps their roles. The link does not own the nodes it points to.
 */
class BpLink
{
  public:
    /**
     * Connects factor node `fn` to variable node `vn`. Both message
     * buffers start with vn->range() entries, each equal to
     * LogAware::log (1.0 / vn->range()). The residual starts at zero.
     */
    BpLink (FacNode* fn, VarNode* vn)
      : fac_ (fn), var_ (vn), currMsg_ (&v1_), nextMsg_ (&v2_), residual_ (0.0)
    {
      v1_.resize (var_->range(), LogAware::log (1.0 / var_->range()));
      // the second buffer starts out identical to the first one
      v2_ = v1_;
    }

    virtual ~BpLink (void) { }

    /// Factor endpoint of this link.
    FacNode* facNode (void) const
    {
      return fac_;
    }

    /// Variable endpoint of this link.
    VarNode* varNode (void) const
    {
      return var_;
    }

    /// The currently published message.
    const Params& message (void) const
    {
      return *currMsg_;
    }

    /// Scratch buffer that receives the next message before it is published.
    Params& nextMessage (void)
    {
      return *nextMsg_;
    }

    /// Last value stored by updateResidual() (or 0 after clearResidual()).
    double residual (void) const
    {
      return residual_;
    }

    void clearResidual (void)
    {
      residual_ = 0.0;
    }

    /// Stores LogAware::getMaxNorm of the two message buffers as the residual.
    void updateResidual (void)
    {
      residual_ = LogAware::getMaxNorm (v1_, v2_);
    }

    /// Publishes the scratch buffer by exchanging the two buffer pointers.
    virtual void updateMessage (void)
    {
      std::swap (currMsg_, nextMsg_);
    }

    /// Human-readable "<factor label> -- <variable label>" description.
    string toString (void) const
    {
      stringstream out;
      out << fac_->getLabel() << " -- " << var_->label();
      return out.str();
    }

  protected:
    FacNode*  fac_;
    VarNode*  var_;
    Params    v1_;        // storage for one of the two message buffers
    Params    v2_;        // storage for the other message buffer
    Params*   currMsg_;   // points at v1_ or v2_: the published message
    Params*   nextMsg_;   // points at the other buffer: the scratch message
    double    residual_;

  private:
    DISALLOW_COPY_AND_ASSIGN (BpLink);
};

typedef vector<BpLink*> BpLinks;


/**
 * Per-node bookkeeping: the list of BpLinks attached to one factor or
 * variable node. The stored pointers are not deleted by this class
 * (it has no user-defined destructor).
 */
class SPNodeInfo
{
  public:
    SPNodeInfo (void) { }

    /// Appends `link` to this node's link list (duplicates are not checked).
    void addBpLink (BpLink* link) { links_.push_back (link); }

    /// Links registered so far, in insertion order. Const-qualified so it
    /// can also be called through a const SPNodeInfo.
    const BpLinks& getLinks (void) const { return links_; }

  private:
    BpLinks links_;

    DISALLOW_COPY_AND_ASSIGN (SPNodeInfo);
};


class BeliefProp : public GroundSolver
{
  public:
    BeliefProp (const FactorGraph&);

    virtual ~BeliefProp (void);

    Params solveQuery (VarIds);

    virtual void printSolverFlags (void) const;

    virtual Params getPosterioriOf (VarId);

    virtual Params getJointDistributionOf (const VarIds&);

    Params getFactorJoint (FacNode* fn, const VarIds&);

    static double accuracy (void) { return accuracy_; }

    static void setAccuracy (double acc) { accuracy_ = acc; }

    static unsigned maxIterations (void) { return maxIter_; }

    static void setMaxIterations (unsigned mi) { maxIter_ = mi; }

    static MsgSchedule msgSchedule (void) { return schedule_; }

    static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }

  protected:
    SPNodeInfo* ninf (const VarNode* var) const
    {
      return varsI_[var->getIndex()];
    }

    SPNodeInfo* ninf (const FacNode* fac) const
    {
      return facsI_[fac->getIndex()];
    }

    void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
    {
      if (Globals::verbosity > 2) {
        cout << "calculating & updating " << link->toString() << endl;
      }
      calcFactorToVarMsg (link);
      if (calcResidual) {
        link->updateResidual();
      }
      link->updateMessage();
    }

    void calculateMessage (BpLink* link, bool calcResidual = true)
    {
      if (Globals::verbosity > 2) {
        cout << "calculating " << link->toString() << endl;
      }
      calcFactorToVarMsg (link);
      if (calcResidual) {
        link->updateResidual();
      }
    }

    void updateMessage (BpLink* link)
    {
      link->updateMessage();
      if (Globals::verbosity > 2) {
        cout << "updating " << link->toString() << endl;
      }
    }

    struct CompareResidual
    {
      inline bool operator() (const BpLink* link1, const BpLink* link2)
      {
        return link1->residual() > link2->residual();
      }
    };

    void runSolver (void);

    virtual void createLinks (void);

    virtual void maxResidualSchedule (void);

    virtual void calcFactorToVarMsg (BpLink*);

    virtual Params getVarToFactorMsg (const BpLink*) const;

    virtual Params getJointByConditioning (const VarIds&) const;

    BpLinks              links_;
    unsigned             nIters_;
    vector<SPNodeInfo*>  varsI_;
    vector<SPNodeInfo*>  facsI_;
    bool                 runned_;

    typedef multiset<BpLink*, CompareResidual> SortedOrder;
    SortedOrder sortedOrder_;

    typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
    BpLinkMap linkMap_;

    static double       accuracy_;
    static unsigned     maxIter_;
    static MsgSchedule  schedule_;

  private:
    void initializeSolver (void);

    bool converged (void);

    virtual void printLinkInformation (void) const;

    DISALLOW_COPY_AND_ASSIGN (BeliefProp);
};

#endif // HORUS_BELIEFPROP_H