diff --git a/packages/CLPBN/clpbn/bp/BpSolver.cpp b/packages/CLPBN/clpbn/bp/BpSolver.cpp new file mode 100644 index 000000000..d746d2db3 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/BpSolver.cpp @@ -0,0 +1,496 @@ +#include +#include + +#include + +#include + +#include "BpSolver.h" +#include "FactorGraph.h" +#include "Factor.h" +#include "Indexer.h" +#include "Horus.h" + + +BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg) +{ + factorGraph_ = &fg; +} + + + +BpSolver::~BpSolver (void) +{ + for (unsigned i = 0; i < varsI_.size(); i++) { + delete varsI_[i]; + } + for (unsigned i = 0; i < facsI_.size(); i++) { + delete facsI_[i]; + } + for (unsigned i = 0; i < links_.size(); i++) { + delete links_[i]; + } +} + + + +void +BpSolver::runSolver (void) +{ + clock_t start; + if (Constants::COLLECT_STATS) { + start = clock(); + } + runLoopySolver(); + if (Constants::DEBUG >= 2) { + cout << endl; + if (nIters_ < BpOptions::maxIter) { + cout << "Sum-Product converged in " ; + cout << nIters_ << " iterations" << endl; + } else { + cout << "The maximum number of iterations was hit, terminating..." ; + cout << endl; + } + } + unsigned size = factorGraph_->varNodes().size(); + if (Constants::COLLECT_STATS) { + unsigned nIters = 0; + bool loopy = factorGraph_->isTree() == false; + if (loopy) nIters = nIters_; + double time = (double (clock() - start)) / CLOCKS_PER_SEC; + Statistics::updateStatistics (size, loopy, nIters, time); + } +} + + + +Params +BpSolver::getPosterioriOf (VarId vid) +{ + assert (factorGraph_->getVarNode (vid)); + VarNode* var = factorGraph_->getVarNode (vid); + Params probs; + if (var->hasEvidence()) { + probs.resize (var->range(), LogAware::noEvidence()); + probs[var->getEvidence()] = LogAware::withEvidence(); + } else { + probs.resize (var->range(), LogAware::multIdenty()); + const SpLinkSet& links = ninf(var)->getLinks(); + if (Globals::logDomain) { + for (unsigned i = 0; i < links.size(); i++) { + Util::add (probs, links[i]->getMessage()); + } + LogAware::normalize (probs); + Util::fromLog (probs); + } else { + for (unsigned i = 0; i < links.size(); i++) { + Util::multiply (probs, links[i]->getMessage()); + } + LogAware::normalize (probs); + } + } + return probs; +} + + + +Params +BpSolver::getJointDistributionOf (const VarIds& jointVarIds) +{ + int idx = -1; + VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]); + const FactorNodes& factorNodes = vn->neighbors(); + for (unsigned i = 0; i < factorNodes.size(); i++) { + if (factorNodes[i]->factor()->contains (jointVarIds)) { + idx = i; + break; + } + } + if (idx == -1) { + return getJointByConditioning (jointVarIds); + } else { + Factor res (*factorNodes[idx]->factor()); + const SpLinkSet& links = ninf(factorNodes[idx])->getLinks(); + for (unsigned i = 0; i < links.size(); i++) { + Factor msg (links[i]->getVariable()->varId(), + links[i]->getVariable()->range(), + getVar2FactorMsg (links[i])); + res.multiply (msg); + } + res.sumOutAllExcept (jointVarIds); + res.reorderArguments (jointVarIds); + res.normalize(); + Params jointDist = res.params(); + if (Globals::logDomain) { + Util::fromLog (jointDist); + } + return jointDist; + } +} + + + +void +BpSolver::runLoopySolver (void) +{ + initializeSolver(); + nIters_ = 0; + + while (!converged() && nIters_ < BpOptions::maxIter) { + + nIters_ ++; + if (Constants::DEBUG >= 2) { + Util::printHeader (" Iteration " + nIters_); + cout << endl; + } + + switch (BpOptions::schedule) { + case BpOptions::Schedule::SEQ_RANDOM: + random_shuffle (links_.begin(), links_.end()); + // no break + + case BpOptions::Schedule::SEQ_FIXED: + for (unsigned i = 0; i < links_.size(); i++) { + calculateAndUpdateMessage (links_[i]); + } + break; + + case BpOptions::Schedule::PARALLEL: + for (unsigned i = 0; i < links_.size(); i++) { + calculateMessage (links_[i]); + } + for (unsigned i = 0; i < links_.size(); i++) { + updateMessage(links_[i]); + } + break; + + case BpOptions::Schedule::MAX_RESIDUAL: + maxResidualSchedule(); + break; + } + if (Constants::DEBUG >= 2) { + cout << endl; + } + } +} + + + +void +BpSolver::initializeSolver (void) +{ + const VarNodes& varNodes = factorGraph_->varNodes(); + for (unsigned i = 0; i < varsI_.size(); i++) { + delete varsI_[i]; + } + varsI_.reserve (varNodes.size()); + for (unsigned i = 0; i < varNodes.size(); i++) { + varsI_.push_back (new SPNodeInfo()); + } + + const FactorNodes& facNodes = factorGraph_->factorNodes(); + for (unsigned i = 0; i < facsI_.size(); i++) { + delete facsI_[i]; + } + facsI_.reserve (facNodes.size()); + for (unsigned i = 0; i < facNodes.size(); i++) { + facsI_.push_back (new SPNodeInfo()); + } + + for (unsigned i = 0; i < links_.size(); i++) { + delete links_[i]; + } + createLinks(); + + for (unsigned i = 0; i < links_.size(); i++) { + FactorNode* src = links_[i]->getFactor(); + VarNode* dst = links_[i]->getVariable(); + ninf (dst)->addSpLink (links_[i]); + ninf (src)->addSpLink (links_[i]); + } +} + + + +void +BpSolver::createLinks (void) +{ + const FactorNodes& facNodes = factorGraph_->factorNodes(); + for (unsigned i = 0; i < facNodes.size(); i++) { + const VarNodes& neighbors = facNodes[i]->neighbors(); + for (unsigned j = 0; j < neighbors.size(); j++) { + links_.push_back (new SpLink (facNodes[i], neighbors[j])); + } + } +} + + + +bool +BpSolver::converged (void) +{ + if (links_.size() == 0) { + return true; + } + if (nIters_ == 0 || nIters_ == 1) { + return false; + } + bool converged = true; + if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) { + double maxResidual = (*(sortedOrder_.begin()))->getResidual(); + if (maxResidual > BpOptions::accuracy) { + converged = false; + } else { + converged = true; + } + } else { + for (unsigned i = 0; i < links_.size(); i++) { + double residual = links_[i]->getResidual(); + if (Constants::DEBUG >= 2) { + cout << links_[i]->toString() + " residual = " << residual << endl; + } + if (residual > BpOptions::accuracy) { + converged = false; + if (Constants::DEBUG == 0) break; + } + } + } + return converged; +} + + + +void +BpSolver::maxResidualSchedule (void) +{ + if (nIters_ == 1) { + for (unsigned i = 0; i < links_.size(); i++) { + calculateMessage (links_[i]); + SortedOrder::iterator it = sortedOrder_.insert (links_[i]); + linkMap_.insert (make_pair (links_[i], it)); + } + return; + } + + for (unsigned c = 0; c < links_.size(); c++) { + if (Constants::DEBUG >= 2) { + cout << "current residuals:" << endl; + for (SortedOrder::iterator it = sortedOrder_.begin(); + it != sortedOrder_.end(); it ++) { + cout << " " << setw (30) << left << (*it)->toString(); + cout << "residual = " << (*it)->getResidual() << endl; + } + } + + SortedOrder::iterator it = sortedOrder_.begin(); + SpLink* link = *it; + if (link->getResidual() < BpOptions::accuracy) { + return; + } + updateMessage (link); + link->clearResidual(); + sortedOrder_.erase (it); + linkMap_.find (link)->second = sortedOrder_.insert (link); + + // update the messages that depend on message source --> destin + const FactorNodes& factorNeighbors = link->getVariable()->neighbors(); + for (unsigned i = 0; i < factorNeighbors.size(); i++) { + if (factorNeighbors[i] != link->getFactor()) { + const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); + for (unsigned j = 0; j < links.size(); j++) { + if (links[j]->getVariable() != link->getVariable()) { + calculateMessage (links[j]); + SpLinkMap::iterator iter = linkMap_.find (links[j]); + sortedOrder_.erase (iter->second); + iter->second = sortedOrder_.insert (links[j]); + } + } + } + } + if (Constants::DEBUG >= 2) { + Util::printDashedLine(); + } + } +} + + + +void +BpSolver::calculateFactor2VariableMsg (SpLink* link) const +{ + const FactorNode* src = link->getFactor(); + const VarNode* dst = link->getVariable(); + const SpLinkSet& links = ninf(src)->getLinks(); + // calculate the product of messages that were sent + // to factor `src', except from var `dst' + unsigned msgSize = 1; + for (unsigned i = 0; i < links.size(); i++) { + msgSize *= links[i]->getVariable()->range(); + } + unsigned repetitions = 1; + Params msgProduct (msgSize, LogAware::multIdenty()); + if (Globals::logDomain) { + for (int i = links.size() - 1; i >= 0; i--) { + if (links[i]->getVariable() != dst) { + Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions); + repetitions *= links[i]->getVariable()->range(); + } else { + unsigned ds = links[i]->getVariable()->range(); + Util::add (msgProduct, Params (ds, 1.0), repetitions); + repetitions *= ds; + } + } + } else { + for (int i = links.size() - 1; i >= 0; i--) { + if (links[i]->getVariable() != dst) { + if (Constants::DEBUG >= 5) { + cout << " message from " << links[i]->getVariable()->label(); + cout << ": " << endl; + } + Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions); + repetitions *= links[i]->getVariable()->range(); + } else { + unsigned ds = links[i]->getVariable()->range(); + Util::multiply (msgProduct, Params (ds, 1.0), repetitions); + repetitions *= ds; + } + } + } + + Factor result (src->factor()->arguments(), + src->factor()->ranges(), + msgProduct); + result.multiply (*(src->factor())); + if (Constants::DEBUG >= 5) { + cout << " message product: " << msgProduct << endl; + cout << " original factor: " << src->params() << endl; + cout << " factor product: " << result.params() << endl; + } + result.sumOutAllExcept (dst->varId()); + if (Constants::DEBUG >= 5) { + cout << " marginalized: " ; + cout << result.params() << endl; + } + const Params& resultParams = result.params(); + Params& message = link->getNextMessage(); + for (unsigned i = 0; i < resultParams.size(); i++) { + message[i] = resultParams[i]; + } + LogAware::normalize (message); + if (Constants::DEBUG >= 5) { + cout << " curr msg: " << link->getMessage() << endl; + cout << " next msg: " << message << endl; + } +} + + + +Params +BpSolver::getVar2FactorMsg (const SpLink* link) const +{ + const VarNode* src = link->getVariable(); + const FactorNode* dst = link->getFactor(); + Params msg; + if (src->hasEvidence()) { + msg.resize (src->range(), LogAware::noEvidence()); + msg[src->getEvidence()] = LogAware::withEvidence(); + if (Constants::DEBUG >= 5) { + cout << msg; + } + } else { + msg.resize (src->range(), LogAware::one()); + } + if (Constants::DEBUG >= 5) { + cout << msg; + } + const SpLinkSet& links = ninf (src)->getLinks(); + if (Globals::logDomain) { + for (unsigned i = 0; i < links.size(); i++) { + if (links[i]->getFactor() != dst) { + Util::add (msg, links[i]->getMessage()); + } + } + } else { + for (unsigned i = 0; i < links.size(); i++) { + if (links[i]->getFactor() != dst) { + Util::multiply (msg, links[i]->getMessage()); + if (Constants::DEBUG >= 5) { + cout << " x " << links[i]->getMessage(); + } + } + } + } + if (Constants::DEBUG >= 5) { + cout << " = " << msg; + } + return msg; +} + + + +Params +BpSolver::getJointByConditioning (const VarIds& jointVarIds) const +{ + VarNodes jointVars; + for (unsigned i = 0; i < jointVarIds.size(); i++) { + assert (factorGraph_->getVarNode (jointVarIds[i])); + jointVars.push_back (factorGraph_->getVarNode (jointVarIds[i])); + } + + FactorGraph* fg = new FactorGraph (*factorGraph_); + BpSolver solver (*fg); + solver.runSolver(); + Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]); + + VarIds observedVids = {jointVars[0]->varId()}; + + for (unsigned i = 1; i < jointVarIds.size(); i++) { + assert (jointVars[i]->hasEvidence() == false); + Params newBeliefs; + Vars observedVars; + for (unsigned j = 0; j < observedVids.size(); j++) { + observedVars.push_back (fg->getVarNode (observedVids[j])); + } + StatesIndexer idx (observedVars, false); + while (idx.valid()) { + for (unsigned j = 0; j < observedVars.size(); j++) { + observedVars[j]->setEvidence (idx[j]); + } + ++ idx; + BpSolver solver (*fg); + solver.runSolver(); + Params beliefs = solver.getPosterioriOf (jointVarIds[i]); + for (unsigned k = 0; k < beliefs.size(); k++) { + newBeliefs.push_back (beliefs[k]); + } + } + + int count = -1; + for (unsigned j = 0; j < newBeliefs.size(); j++) { + if (j % jointVars[i]->range() == 0) { + count ++; + } + newBeliefs[j] *= prevBeliefs[count]; + } + prevBeliefs = newBeliefs; + observedVids.push_back (jointVars[i]->varId()); + } + return prevBeliefs; +} + + + +void +BpSolver::printLinkInformation (void) const +{ + for (unsigned i = 0; i < links_.size(); i++) { + SpLink* l = links_[i]; + cout << l->toString() << ":" << endl; + cout << " curr msg = " ; + cout << l->getMessage() << endl; + cout << " next msg = " ; + cout << l->getNextMessage() << endl; + cout << " residual = " << l->getResidual() << endl; + } +} + diff --git a/packages/CLPBN/clpbn/bp/BpSolver.h b/packages/CLPBN/clpbn/bp/BpSolver.h new file mode 100644 index 000000000..13fc6243d --- /dev/null +++ b/packages/CLPBN/clpbn/bp/BpSolver.h @@ -0,0 +1,186 @@ +#ifndef HORUS_BpSolver_H +#define HORUS_BpSolver_H + +#include +#include +#include + +#include "Solver.h" +#include "Factor.h" +#include "FactorGraph.h" +#include "Util.h" + +using namespace std; + + +class SpLink +{ + public: + SpLink (FactorNode* fn, VarNode* vn) + { + fac_ = fn; + var_ = vn; + v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range())); + v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range())); + currMsg_ = &v1_; + nextMsg_ = &v2_; + msgSended_ = false; + residual_ = 0.0; + } + + virtual ~SpLink (void) { }; + + FactorNode* getFactor (void) const { return fac_; } + + VarNode* getVariable (void) const { return var_; } + + const Params& getMessage (void) const { return *currMsg_; } + + Params& getNextMessage (void) { return *nextMsg_; } + + bool messageWasSended (void) const { return msgSended_; } + + double getResidual (void) const { return residual_; } + + void clearResidual (void) { residual_ = 0.0; } + + void updateResidual (void) + { + residual_ = LogAware::getMaxNorm (v1_,v2_); + } + + virtual void updateMessage (void) + { + swap (currMsg_, nextMsg_); + msgSended_ = true; + } + + string toString (void) const + { + stringstream ss; + ss << fac_->getLabel(); + ss << " -- " ; + ss << var_->label(); + return ss.str(); + } + + protected: + FactorNode* fac_; + VarNode* var_; + Params v1_; + Params v2_; + Params* currMsg_; + Params* nextMsg_; + bool msgSended_; + double residual_; +}; + +typedef vector SpLinkSet; + + +class SPNodeInfo +{ + public: + void addSpLink (SpLink* link) { links_.push_back (link); } + const SpLinkSet& getLinks (void) { return links_; } + private: + SpLinkSet links_; +}; + + +class BpSolver : public Solver +{ + public: + BpSolver (const FactorGraph&); + + virtual ~BpSolver (void); + + void runSolver (void); + + virtual Params getPosterioriOf (VarId); + + virtual Params getJointDistributionOf (const VarIds&); + + protected: + virtual void initializeSolver (void); + + virtual void createLinks (void); + + virtual void maxResidualSchedule (void); + + virtual void calculateFactor2VariableMsg (SpLink*) const; + + virtual Params getVar2FactorMsg (const SpLink*) const; + + virtual Params getJointByConditioning (const VarIds&) const; + + virtual void printLinkInformation (void) const; + + SPNodeInfo* ninf (const VarNode* var) const + { + return varsI_[var->getIndex()]; + } + + SPNodeInfo* ninf (const FactorNode* fac) const + { + return facsI_[fac->getIndex()]; + } + + void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true) + { + if (Constants::DEBUG >= 3) { + cout << "calculating & updating " << link->toString() << endl; + } + calculateFactor2VariableMsg (link); + if (calcResidual) { + link->updateResidual(); + } + link->updateMessage(); + } + + void calculateMessage (SpLink* link, bool calcResidual = true) + { + if (Constants::DEBUG >= 3) { + cout << "calculating " << link->toString() << endl; + } + calculateFactor2VariableMsg (link); + if (calcResidual) { + link->updateResidual(); + } + } + + void updateMessage (SpLink* link) + { + link->updateMessage(); + if (Constants::DEBUG >= 3) { + cout << "updating " << link->toString() << endl; + } + } + + struct CompareResidual + { + inline bool operator() (const SpLink* link1, const SpLink* link2) + { + return link1->getResidual() > link2->getResidual(); + } + }; + + SpLinkSet links_; + unsigned nIters_; + vector varsI_; + vector facsI_; + const FactorGraph* factorGraph_; + + typedef multiset SortedOrder; + SortedOrder sortedOrder_; + + typedef unordered_map SpLinkMap; + SpLinkMap linkMap_; + + private: + void runLoopySolver (void); + bool converged (void); +}; + +#endif // HORUS_BpSolver_H + diff --git a/packages/CLPBN/clpbn/bp/Var.cpp b/packages/CLPBN/clpbn/bp/Var.cpp new file mode 100644 index 000000000..c359e0747 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/Var.cpp @@ -0,0 +1,102 @@ +#include +#include + +#include "Var.h" + +using namespace std; + + +unordered_map Var::varsInfo_; + + +Var::Var (const Var* v) +{ + varId_ = v->varId(); + range_ = v->range(); + evidence_ = v->getEvidence(); + index_ = std::numeric_limits::max(); +} + + + +Var::Var (VarId varId, unsigned range, int evidence) +{ + assert (range != 0); + assert (evidence < (int) range); + varId_ = varId; + range_ = range; + evidence_ = evidence; + index_ = std::numeric_limits::max(); +} + + + +bool +Var::isValidState (int stateIndex) +{ + return stateIndex >= 0 && stateIndex < (int) range_; +} + + + +bool +Var::isValidState (const string& stateName) +{ + States states = Var::getVarInformation (varId_).states; + return Util::contains (states, stateName); +} + + + +void +Var::setEvidence (int ev) +{ + assert (ev < (int) range_); + evidence_ = ev; +} + + + +void +Var::setEvidence (const string& ev) +{ + States states = Var::getVarInformation (varId_).states; + for (unsigned i = 0; i < states.size(); i++) { + if (states[i] == ev) { + evidence_ = i; + return; + } + } + assert (false); +} + + + +string +Var::label (void) const +{ + if (Var::variablesHaveInformation()) { + return Var::getVarInformation (varId_).label; + } + stringstream ss; + ss << "x" << varId_; + return ss.str(); +} + + + +States +Var::states (void) const +{ + if (Var::variablesHaveInformation()) { + return Var::getVarInformation (varId_).states; + } + States states; + for (unsigned i = 0; i < range_; i++) { + stringstream ss; + ss << i ; + states.push_back (ss.str()); + } + return states; +} + diff --git a/packages/CLPBN/clpbn/bp/Var.h b/packages/CLPBN/clpbn/bp/Var.h new file mode 100644 index 000000000..e3ae62e19 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/Var.h @@ -0,0 +1,108 @@ +#ifndef HORUS_Var_H +#define HORUS_Var_H + +#include + +#include + +#include "Util.h" +#include "Horus.h" + + +using namespace std; + + +struct VarInfo +{ + VarInfo (string l, const States& sts) : label(l), states(sts) { } + string label; + States states; +}; + + + +class Var +{ + public: + Var (const Var*); + + Var (VarId, unsigned, int = Constants::NO_EVIDENCE); + + virtual ~Var (void) { }; + + unsigned varId (void) const { return varId_; } + + unsigned range (void) const { return range_; } + + int getEvidence (void) const { return evidence_; } + + unsigned getIndex (void) const { return index_; } + + void setIndex (unsigned idx) { index_ = idx; } + + operator unsigned () const { return index_; } + + bool hasEvidence (void) const + { + return evidence_ != Constants::NO_EVIDENCE; + } + + bool operator== (const Var& var) const + { + assert (!(varId_ == var.varId() && range_ != var.range())); + return varId_ == var.varId(); + } + + bool operator!= (const Var& var) const + { + assert (!(varId_ == var.varId() && range_ != var.range())); + return varId_ != var.varId(); + } + + bool isValidState (int); + + bool isValidState (const string&); + + void setEvidence (int); + + void setEvidence (const string&); + + string label (void) const; + + States states (void) const; + + static void addVariableInformation ( + VarId vid, string label, const States& states) + { + assert (Util::contains (varsInfo_, vid) == false); + varsInfo_.insert (make_pair (vid, VarInfo (label, states))); + } + + static VarInfo getVarInformation (VarId vid) + { + assert (Util::contains (varsInfo_, vid)); + return varsInfo_.find (vid)->second; + } + + static bool variablesHaveInformation (void) + { + return varsInfo_.size() != 0; + } + + static void clearVariablesInformation (void) + { + varsInfo_.clear(); + } + + private: + VarId varId_; + unsigned range_; + int evidence_; + unsigned index_; + + static unordered_map varsInfo_; + +}; + +#endif // BP_Var_H +