Merge branch 'master' of github.com:tacgomes/yap6.3

This commit is contained in:
Tiago Gomes 2012-06-04 14:57:42 +01:00
commit 31fde92a36
30 changed files with 1348 additions and 1111 deletions

View File

@ -39,18 +39,16 @@ warning :-
set_solver(ve) :- set_pfl_flag(solver,ve). set_solver(ve) :- set_pfl_flag(solver,ve).
set_solver(jt) :- set_pfl_flag(solver,jt). set_solver(jt) :- set_pfl_flag(solver,jt).
set_solver(gibbs) :- set_pfl_flag(solver,gibbs). set_solver(gibbs) :- set_pfl_flag(solver,gibbs).
set_solver(fove) :- set_pfl_flag(solver,fove). set_solver(fove) :- set_pfl_flag(solver,fove), set_horus_flag(lifted_solver, fove).
set_solver(hve) :- set_pfl_flag(solver,bp), cpp_set_horus_flag(inf_alg, ve). set_solver(lbp) :- set_pfl_flag(solver,fove), set_horus_flag(lifted_solver, lbp).
set_solver(bp) :- set_pfl_flag(solver,bp), cpp_set_horus_flag(inf_alg, bp). set_solver(hve) :- set_pfl_flag(solver,bp), set_horus_flag(ground_solver, ve).
set_solver(cbp) :- set_pfl_flag(solver,bp), cpp_set_horus_flag(inf_alg, cbp). set_solver(bp) :- set_pfl_flag(solver,bp), set_horus_flag(ground_solver, bp).
set_solver(cbp) :- set_pfl_flag(solver,bp), set_horus_flag(ground_solver, cbp).
set_solver(S) :- throw(error('unknow solver ', S)). set_solver(S) :- throw(error('unknow solver ', S)).
set_horus_flag(K,V) :- cpp_set_horus_flag(K,V). set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).
%:- cpp_set_horus_flag(inf_alg, ve).
%:- cpp_set_horus_flag(inf_alg, bp).
%: -cpp_set_horus_flag(inf_alg, cbp).
:- cpp_set_horus_flag(schedule, seq_fixed). :- cpp_set_horus_flag(schedule, seq_fixed).
%:- cpp_set_horus_flag(schedule, seq_random). %:- cpp_set_horus_flag(schedule, seq_random).

View File

@ -9,12 +9,10 @@
#include "Util.h" #include "Util.h"
FactorGraph* FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds) BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{ {
assert (fg_.isFromBayesNetwork()); assert (fg_.bayesianFactors());
Scheduling scheduling; Scheduling scheduling;
for (size_t i = 0; i < queryIds.size(); i++) { for (size_t i = 0; i < queryIds.size(); i++) {
assert (dag_.getNode (queryIds[i])); assert (dag_.getNode (queryIds[i]));

View File

@ -14,7 +14,6 @@
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg) BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{ {
fg_ = &fg;
runned_ = false; runned_ = false;
} }
@ -39,11 +38,9 @@ Params
BpSolver::solveQuery (VarIds queryVids) BpSolver::solveQuery (VarIds queryVids)
{ {
assert (queryVids.empty() == false); assert (queryVids.empty() == false);
if (queryVids.size() == 1) { return queryVids.size() == 1
return getPosterioriOf (queryVids[0]); ? getPosterioriOf (queryVids[0])
} else { : getJointDistributionOf (queryVids);
return getJointDistributionOf (queryVids);
}
} }
@ -61,8 +58,8 @@ BpSolver::printSolverFlags (void) const
case Sch::PARALLEL: ss << "parallel"; break; case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break; case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
} }
ss << ",max_iter=" << Util::toString (BpOptions::maxIter); ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
ss << ",accuracy=" << Util::toString (BpOptions::accuracy); ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;
@ -76,24 +73,24 @@ BpSolver::getPosterioriOf (VarId vid)
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
} }
assert (fg_->getVarNode (vid)); assert (fg.getVarNode (vid));
VarNode* var = fg_->getVarNode (vid); VarNode* var = fg.getVarNode (vid);
Params probs; Params probs;
if (var->hasEvidence()) { if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence()); probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence(); probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->range(), LogAware::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks(); const BpLinks& links = ninf(var)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
probs += links[i]->getMessage(); probs += links[i]->message();
} }
LogAware::normalize (probs); LogAware::normalize (probs);
Util::exp (probs); Util::exp (probs);
} else { } else {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
probs *= links[i]->getMessage(); probs *= links[i]->message();
} }
LogAware::normalize (probs); LogAware::normalize (probs);
} }
@ -109,7 +106,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
} }
VarNode* vn = fg_->getVarNode (jointVarIds[0]); VarNode* vn = fg.getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors(); const FacNodes& facNodes = vn->neighbors();
size_t idx = facNodes.size(); size_t idx = facNodes.size();
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
@ -122,11 +119,11 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
return getJointByConditioning (jointVarIds); return getJointByConditioning (jointVarIds);
} else { } else {
Factor res (facNodes[idx]->factor()); Factor res (facNodes[idx]->factor());
const SpLinkSet& links = ninf(facNodes[idx])->getLinks(); const BpLinks& links = ninf(facNodes[idx])->getLinks();
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->getVariable()->varId()}, Factor msg ({links[i]->varNode()->varId()},
{links[i]->getVariable()->range()}, {links[i]->varNode()->range()},
getVar2FactorMsg (links[i])); getVarToFactorMsg (links[i]));
res.multiply (msg); res.multiply (msg);
} }
res.sumOutAllExcept (jointVarIds); res.sumOutAllExcept (jointVarIds);
@ -154,7 +151,7 @@ BpSolver::runSolver (void)
} }
switch (BpOptions::schedule) { switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM: case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end()); std::random_shuffle (links_.begin(), links_.end());
// no break // no break
case BpOptions::Schedule::SEQ_FIXED: case BpOptions::Schedule::SEQ_FIXED:
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
@ -192,11 +189,11 @@ BpSolver::runSolver (void)
void void
BpSolver::createLinks (void) BpSolver::createLinks (void)
{ {
const FacNodes& facNodes = fg_->facNodes(); const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
const VarNodes& neighbors = facNodes[i]->neighbors(); const VarNodes& neighbors = facNodes[i]->neighbors();
for (size_t j = 0; j < neighbors.size(); j++) { for (size_t j = 0; j < neighbors.size(); j++) {
links_.push_back (new SpLink (facNodes[i], neighbors[j])); links_.push_back (new BpLink (facNodes[i], neighbors[j]));
} }
} }
} }
@ -221,13 +218,13 @@ BpSolver::maxResidualSchedule (void)
for (SortedOrder::iterator it = sortedOrder_.begin(); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) { it != sortedOrder_.end(); ++it) {
cout << " " << setw (30) << left << (*it)->toString(); cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->getResidual() << endl; cout << "residual = " << (*it)->residual() << endl;
} }
} }
SortedOrder::iterator it = sortedOrder_.begin(); SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it; BpLink* link = *it;
if (link->getResidual() < BpOptions::accuracy) { if (link->residual() < BpOptions::accuracy) {
return; return;
} }
updateMessage (link); updateMessage (link);
@ -236,14 +233,14 @@ BpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link); linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin // update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->getVariable()->neighbors(); const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) { for (size_t i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->getFactor()) { if (factorNeighbors[i] != link->facNode()) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
for (size_t j = 0; j < links.size(); j++) { for (size_t j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) { if (links[j]->varNode() != link->varNode()) {
calculateMessage (links[j]); calculateMessage (links[j]);
SpLinkMap::iterator iter = linkMap_.find (links[j]); BpLinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second); sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]); iter->second = sortedOrder_.insert (links[j]);
} }
@ -259,54 +256,45 @@ BpSolver::maxResidualSchedule (void)
void void
BpSolver::calculateFactor2VariableMsg (SpLink* link) BpSolver::calcFactorToVarMsg (BpLink* link)
{ {
FacNode* src = link->getFactor(); FacNode* src = link->facNode();
const VarNode* dst = link->getVariable(); const VarNode* dst = link->varNode();
const SpLinkSet& links = ninf(src)->getLinks(); const BpLinks& links = ninf(src)->getLinks();
// calculate the product of messages that were sent // calculate the product of messages that were sent
// to factor `src', except from var `dst' // to factor `src', except from var `dst'
unsigned msgSize = 1; unsigned reps = 1;
for (size_t i = 0; i < links.size(); i++) { unsigned msgSize = Util::sizeExpected (src->factor().ranges());
msgSize *= links[i]->getVariable()->range();
}
unsigned repetitions = 1;
Params msgProduct (msgSize, LogAware::multIdenty()); Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->getVariable() != dst) { if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label(); cout << " message from " << links[i]->varNode()->label();
cout << ": " ; cout << ": " ;
} }
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions); Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
repetitions *= links[i]->getVariable()->range(); reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << endl; cout << endl;
} }
} else {
unsigned range = links[i]->getVariable()->range();
Util::add (msgProduct, Params (range, 0.0), repetitions);
repetitions *= range;
} }
reps *= links[i]->varNode()->range();
} }
} else { } else {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->getVariable() != dst) { if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label(); cout << " message from " << links[i]->varNode()->label();
cout << ": " ; cout << ": " ;
} }
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions); Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
repetitions *= links[i]->getVariable()->range(); reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << endl; cout << endl;
} }
} else {
unsigned range = links[i]->getVariable()->range();
Util::multiply (msgProduct, Params (range, 1.0), repetitions);
repetitions *= range;
} }
reps *= links[i]->varNode()->range();
} }
} }
Factor result (src->factor().arguments(), Factor result (src->factor().arguments(),
@ -321,21 +309,20 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl; cout << " marginalized: " << result.params() << endl;
} }
link->getNextMessage() = result.params(); link->nextMessage() = result.params();
LogAware::normalize (link->getNextMessage()); LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->getMessage() << endl; cout << " curr msg: " << link->message() << endl;
cout << " next msg: " << link->getNextMessage() << endl; cout << " next msg: " << link->nextMessage() << endl;
} }
} }
Params Params
BpSolver::getVar2FactorMsg (const SpLink* link) const BpSolver::getVarToFactorMsg (const BpLink* link) const
{ {
const VarNode* src = link->getVariable(); const VarNode* src = link->varNode();
const FacNode* dst = link->getFactor();
Params msg; Params msg;
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
@ -346,25 +333,24 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << msg; cout << msg;
} }
const SpLinkSet& links = ninf (src)->getLinks(); BpLinks::const_iterator it;
const BpLinks& links = ninf (src)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
SpLinkSet::const_iterator it; for (it = links.begin(); it != links.end(); ++it) {
for (it = links.begin(); it != links.end(); ++ it) { msg += (*it)->message();
msg += (*it)->getMessage();
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->getMessage(); cout << " x " << (*it)->message();
} }
} }
msg -= link->getMessage(); msg -= link->message();
} else { } else {
for (size_t i = 0; i < links.size(); i++) { for (it = links.begin(); it != links.end(); ++it) {
if (links[i]->getFactor() != dst) { msg *= (*it)->message();
msg *= links[i]->getMessage(); if (Constants::SHOW_BP_CALCS) {
if (Constants::SHOW_BP_CALCS) { cout << " x " << (*it)->message();
cout << " x " << links[i]->getMessage();
}
} }
} }
msg /= link->message();
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg; cout << " = " << msg;
@ -379,12 +365,12 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{ {
VarNodes jointVars; VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) { for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg_->getVarNode (jointVarIds[i])); assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg_->getVarNode (jointVarIds[i])); jointVars.push_back (fg.getVarNode (jointVarIds[i]));
} }
FactorGraph* fg = new FactorGraph (*fg_); FactorGraph* tempFg = new FactorGraph (fg);
BpSolver solver (*fg); BpSolver solver (*tempFg);
solver.runSolver(); solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]); Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
@ -396,7 +382,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
Vars observedVars; Vars observedVars;
Ranges observedRanges; Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) { for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg->getVarNode (observedVids[j])); observedVars.push_back (tempFg->getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range()); observedRanges.push_back (observedVars.back()->range());
} }
Indexer indexer (observedRanges, false); Indexer indexer (observedRanges, false);
@ -404,7 +390,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
for (size_t j = 0; j < observedVars.size(); j++) { for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]); observedVars[j]->setEvidence (indexer[j]);
} }
BpSolver solver (*fg); BpSolver solver (*tempFg);
solver.runSolver(); solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]); Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) { for (size_t k = 0; k < beliefs.size(); k++) {
@ -431,22 +417,22 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
void void
BpSolver::initializeSolver (void) BpSolver::initializeSolver (void)
{ {
const VarNodes& varNodes = fg_->varNodes(); const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size()); varsI_.reserve (varNodes.size());
for (size_t i = 0; i < varNodes.size(); i++) { for (size_t i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo()); varsI_.push_back (new SPNodeInfo());
} }
const FacNodes& facNodes = fg_->facNodes(); const FacNodes& facNodes = fg.facNodes();
facsI_.reserve (facNodes.size()); facsI_.reserve (facNodes.size());
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo()); facsI_.push_back (new SPNodeInfo());
} }
createLinks(); createLinks();
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->getFactor(); FacNode* src = links_[i]->facNode();
VarNode* dst = links_[i]->getVariable(); VarNode* dst = links_[i]->varNode();
ninf (dst)->addSpLink (links_[i]); ninf (dst)->addBpLink (links_[i]);
ninf (src)->addSpLink (links_[i]); ninf (src)->addBpLink (links_[i]);
} }
} }
@ -472,7 +458,7 @@ BpSolver::converged (void)
} }
bool converged = true; bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) { if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual(); double maxResidual = (*(sortedOrder_.begin()))->residual();
if (maxResidual > BpOptions::accuracy) { if (maxResidual > BpOptions::accuracy) {
converged = false; converged = false;
} else { } else {
@ -480,7 +466,7 @@ BpSolver::converged (void)
} }
} else { } else {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual(); double residual = links_[i]->residual();
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl; cout << links_[i]->toString() + " residual = " << residual << endl;
} }
@ -504,13 +490,13 @@ void
BpSolver::printLinkInformation (void) const BpSolver::printLinkInformation (void) const
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
SpLink* l = links_[i]; BpLink* l = links_[i];
cout << l->toString() << ":" << endl; cout << l->toString() << ":" << endl;
cout << " curr msg = " ; cout << " curr msg = " ;
cout << l->getMessage() << endl; cout << l->message() << endl;
cout << " next msg = " ; cout << " next msg = " ;
cout << l->getNextMessage() << endl; cout << l->nextMessage() << endl;
cout << " residual = " << l->getResidual() << endl; cout << " residual = " << l->residual() << endl;
} }
} }

View File

@ -13,10 +13,10 @@
using namespace std; using namespace std;
class SpLink class BpLink
{ {
public: public:
SpLink (FacNode* fn, VarNode* vn) BpLink (FacNode* fn, VarNode* vn)
{ {
fac_ = fn; fac_ = fn;
var_ = vn; var_ = vn;
@ -24,23 +24,20 @@ class SpLink
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range())); v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_; currMsg_ = &v1_;
nextMsg_ = &v2_; nextMsg_ = &v2_;
msgSended_ = false;
residual_ = 0.0; residual_ = 0.0;
} }
virtual ~SpLink (void) { }; virtual ~BpLink (void) { };
FacNode* getFactor (void) const { return fac_; } FacNode* facNode (void) const { return fac_; }
VarNode* getVariable (void) const { return var_; } VarNode* varNode (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; } const Params& message (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; } Params& nextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; } double residual (void) const { return residual_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; } void clearResidual (void) { residual_ = 0.0; }
@ -52,7 +49,6 @@ class SpLink
virtual void updateMessage (void) virtual void updateMessage (void)
{ {
swap (currMsg_, nextMsg_); swap (currMsg_, nextMsg_);
msgSended_ = true;
} }
string toString (void) const string toString (void) const
@ -71,20 +67,19 @@ class SpLink
Params v2_; Params v2_;
Params* currMsg_; Params* currMsg_;
Params* nextMsg_; Params* nextMsg_;
bool msgSended_;
double residual_; double residual_;
}; };
typedef vector<SpLink*> SpLinkSet; typedef vector<BpLink*> BpLinks;
class SPNodeInfo class SPNodeInfo
{ {
public: public:
void addSpLink (SpLink* link) { links_.push_back (link); } void addBpLink (BpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; } const BpLinks& getLinks (void) { return links_; }
private: private:
SpLinkSet links_; BpLinks links_;
}; };
@ -110,9 +105,9 @@ class BpSolver : public Solver
virtual void maxResidualSchedule (void); virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*); virtual void calcFactorToVarMsg (BpLink*);
virtual Params getVar2FactorMsg (const SpLink*) const; virtual Params getVarToFactorMsg (const BpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const; virtual Params getJointByConditioning (const VarIds&) const;
@ -126,30 +121,30 @@ class BpSolver : public Solver
return facsI_[fac->getIndex()]; return facsI_[fac->getIndex()];
} }
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true) void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
{ {
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl; cout << "calculating & updating " << link->toString() << endl;
} }
calculateFactor2VariableMsg (link); calcFactorToVarMsg (link);
if (calcResidual) { if (calcResidual) {
link->updateResidual(); link->updateResidual();
} }
link->updateMessage(); link->updateMessage();
} }
void calculateMessage (SpLink* link, bool calcResidual = true) void calculateMessage (BpLink* link, bool calcResidual = true)
{ {
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl; cout << "calculating " << link->toString() << endl;
} }
calculateFactor2VariableMsg (link); calcFactorToVarMsg (link);
if (calcResidual) { if (calcResidual) {
link->updateResidual(); link->updateResidual();
} }
} }
void updateMessage (SpLink* link) void updateMessage (BpLink* link)
{ {
link->updateMessage(); link->updateMessage();
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
@ -159,24 +154,23 @@ class BpSolver : public Solver
struct CompareResidual struct CompareResidual
{ {
inline bool operator() (const SpLink* link1, const SpLink* link2) inline bool operator() (const BpLink* link1, const BpLink* link2)
{ {
return link1->getResidual() > link2->getResidual(); return link1->residual() > link2->residual();
} }
}; };
SpLinkSet links_; BpLinks links_;
unsigned nIters_; unsigned nIters_;
vector<SPNodeInfo*> varsI_; vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_; vector<SPNodeInfo*> facsI_;
bool runned_; bool runned_;
const FactorGraph* fg_;
typedef multiset<SpLink*, CompareResidual> SortedOrder; typedef multiset<BpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_; SortedOrder sortedOrder_;
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap; typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
SpLinkMap linkMap_; BpLinkMap linkMap_;
private: private:
void initializeSolver (void); void initializeSolver (void);

View File

@ -1,318 +0,0 @@
#include "CFactorGraph.h"
#include "Factor.h"
bool CFactorGraph::checkForIdenticalFactors = true;
CFactorGraph::CFactorGraph (const FactorGraph& fg)
: freeColor_(0), groundFg_(&fg)
{
findIdenticalFactors();
setInitialColors();
createGroups();
}
CFactorGraph::~CFactorGraph (void)
{
for (size_t i = 0; i < varClusters_.size(); i++) {
delete varClusters_[i];
}
for (size_t i = 0; i < facClusters_.size(); i++) {
delete facClusters_[i];
}
}
void
CFactorGraph::findIdenticalFactors()
{
const FacNodes& facNodes = groundFg_->facNodes();
if (checkForIdenticalFactors == false ||
facNodes.size() == 1) {
return;
}
for (size_t i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
unsigned groupCount = 1;
for (size_t i = 0; i < facNodes.size() - 1; i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (size_t j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
}
}
void
CFactorGraph::setInitialColors (void)
{
varColors_.resize (groundFg_->nrVarNodes());
facColors_.resize (groundFg_->nrFacNodes());
// create the initial variable colors
VarColorMap colorMap;
const VarNodes& varNodes = groundFg_->varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
unsigned range = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (range);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
range, Colors (range + 1, -1))).first;
}
unsigned idx = varNodes[i]->hasEvidence()
? varNodes[i]->getEvidence()
: range;
Colors& stateColors = it->second;
if (stateColors[idx] == -1) {
stateColors[idx] = getFreeColor();
}
setColor (varNodes[i], stateColors[idx]);
}
const FacNodes& facNodes = groundFg_->facNodes();
// create the initial factor colors
DistColorMap distColors;
for (size_t i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getFreeColor())).first;
}
setColor (facNodes[i], it->second);
}
}
void
CFactorGraph::createGroups (void)
{
VarSignMap varGroups;
FacSignMap facGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const VarNodes& varNodes = groundFg_->varNodes();
const FacNodes& facNodes = groundFg_->facNodes();
while (groupsHaveChanged || nIters == 1) {
nIters ++;
// set a new color to the variables with the same signature
size_t prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (size_t i = 0; i < varNodes.size(); i++) {
const VarSignature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
Color newColor = getFreeColor();
VarNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
size_t prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (size_t i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
Color newColor = getFreeColor();
FacNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != facGroups.size();
}
// printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
}
void
CFactorGraph::createClusters (
const VarSignMap& varGroups,
const FacSignMap& facGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (size_t i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
}
varClusters_.push_back (vc);
}
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (size_t i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
varClusters.push_back (vid2VarCluster_.find (vid)->second);
}
facClusters_.push_back (new FacCluster (it->second, varClusters));
}
}
VarSignature
CFactorGraph::getSignature (const VarNode* varNode)
{
const FacNodes& neighs = varNode->neighbors();
VarSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (make_pair (
getColor (neighs[i]),
neighs[i]->factor().indexOf (varNode->varId())));
}
std::sort (sign.begin(), sign.end());
sign.push_back (make_pair (getColor (varNode), 0));
return sign;
}
FacSignature
CFactorGraph::getSignature (const FacNode* facNode)
{
const VarNodes& neighs = facNode->neighbors();
FacSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (getColor (neighs[i]));
}
sign.push_back (getColor (facNode));
return sign;
}
FactorGraph*
CFactorGraph::getGroundFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) {
VarNode* newVar = new VarNode (varClusters_[i]->first());
varClusters_[i]->setRepresentative (newVar);
fg->addVarNode (newVar);
}
for (size_t i = 0; i < facClusters_.size(); i++) {
Vars vars;
const VarClusters& clusters = facClusters_[i]->varClusters();
for (size_t j = 0; j < clusters.size(); j++) {
vars.push_back (clusters[j]->representative());
}
const Factor& groundFac = facClusters_[i]->first()->factor();
FacNode* fn = new FacNode (Factor (
vars, groundFac.params(), groundFac.distId()));
facClusters_[i]->setRepresentative (fn);
fg->addFacNode (fn);
for (size_t j = 0; j < vars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
}
}
return fg;
}
unsigned
CFactorGraph::getEdgeCount (
const FacCluster* fc,
const VarCluster* vc,
size_t index) const
{
unsigned count = 0;
VarId reprVid = vc->representative()->varId();
VarNode* groundVar = groundFg_->getVarNode (reprVid);
const FacNodes& neighs = groundVar->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
FacNodes::const_iterator it;
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
if (it != fc->members().end() &&
(*it)->factor().indexOf (reprVid) == index) {
count ++;
}
}
return count;
}
void
CFactorGraph::printGroups (
const VarSignMap& varGroups,
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->label() << " " ;
}
count ++;
cout << endl;
}
}
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
}
count ++;
cout << endl;
}
}
}

View File

@ -1,176 +0,0 @@
#ifndef HORUS_CFACTORGRAPH_H
#define HORUS_CFACTORGRAPH_H
#include <unordered_map>
#include "FactorGraph.h"
#include "Factor.h"
#include "Util.h"
#include "Horus.h"
class VarCluster;
class FacCluster;
class VarSignatureHash;
class FacSignatureHash;
typedef long Color;
typedef vector<Color> Colors;
typedef vector<std::pair<Color,unsigned>> VarSignature;
typedef vector<Color> FacSignature;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<unsigned, Colors> VarColorMap;
typedef unordered_map<VarSignature, VarNodes, VarSignatureHash> VarSignMap;
typedef unordered_map<FacSignature, FacNodes, FacSignatureHash> FacSignMap;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
struct VarSignatureHash
{
size_t operator() (const VarSignature &sig) const
{
size_t val = hash<size_t>()(sig.size());
for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i].first);
val ^= hash<size_t>()(sig[i].second);
}
return val;
}
};
struct FacSignatureHash
{
size_t operator() (const FacSignature &sig) const
{
size_t val = hash<size_t>()(sig.size());
for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i]);
}
return val;
}
};
class VarCluster
{
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first (void) const { return members_.front(); }
const VarNodes& members (void) const { return members_; }
VarNode* representative (void) const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
};
class FacCluster
{
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first (void) const { return members_.front(); }
const FacNodes& members (void) const { return members_; }
VarClusters& varClusters (void) { return varClusters_; }
FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
private:
FacNodes members_;
VarClusters varClusters_;
FacNode* repr_;
};
class CFactorGraph
{
public:
CFactorGraph (const FactorGraph&);
~CFactorGraph (void);
const VarClusters& varClusters (void) { return varClusters_; }
const FacClusters& facClusters (void) { return facClusters_; }
VarNode* getEquivalent (VarId vid)
{
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative();
}
FactorGraph* getGroundFactorGraph (void);
unsigned getEdgeCount (const FacCluster*,
const VarCluster*, size_t index) const;
static bool checkForIdenticalFactors;
private:
Color getFreeColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
Color getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
Color getColor (const FacNode* fn) const {
return facColors_[fn->getIndex()];
}
void setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
void setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
void findIdenticalFactors (void);
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
VarSignature getSignature (const VarNode*);
FacSignature getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_;
Colors varColors_;
Colors facColors_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
};
#endif // HORUS_CFACTORGRAPH_H

View File

@ -1,22 +1,32 @@
#include "CbpSolver.h" #include "CbpSolver.h"
#include "WeightedBpSolver.h"
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg) bool CbpSolver::checkForIdenticalFactors = true;
CbpSolver::CbpSolver (const FactorGraph& fg)
: Solver (fg), freeColor_(0)
{ {
cfg_ = new CFactorGraph (fg); findIdenticalFactors();
fg_ = cfg_->getGroundFactorGraph(); setInitialColors();
createGroups();
compressedFg_ = getCompressedFactorGraph();
solver_ = new WeightedBpSolver (*compressedFg_, getWeights());
} }
CbpSolver::~CbpSolver (void) CbpSolver::~CbpSolver (void)
{ {
delete cfg_; delete solver_;
delete fg_; delete compressedFg_;
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < varClusters_.size(); i++) {
delete links_[i]; delete varClusters_[i];
}
for (size_t i = 0; i < facClusters_.size(); i++) {
delete facClusters_[i];
} }
links_.clear();
} }
@ -38,7 +48,7 @@ CbpSolver::printSolverFlags (void) const
ss << ",accuracy=" << BpOptions::accuracy; ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" << ss << ",chkif=" <<
Util::toString (CFactorGraph::checkForIdenticalFactors); Util::toString (CbpSolver::checkForIdenticalFactors);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;
} }
@ -46,305 +56,345 @@ CbpSolver::printSolverFlags (void) const
Params Params
CbpSolver::getPosterioriOf (VarId vid) CbpSolver::solveQuery (VarIds queryVids)
{ {
if (runned_ == false) { assert (queryVids.empty() == false);
runSolver(); Params res;
} if (queryVids.size() == 1) {
assert (cfg_->getEquivalent (vid)); res = solver_->getPosterioriOf (getRepresentative (queryVids[0]));
VarNode* var = cfg_->getEquivalent (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->range(), LogAware::multIdenty()); VarNode* vn = fg.getVarNode (queryVids[0]);
const SpLinkSet& links = ninf(var)->getLinks(); const FacNodes& facNodes = vn->neighbors();
if (Globals::logDomain) { size_t idx = facNodes.size();
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]); if (facNodes[i]->factor().contains (queryVids)) {
probs += l->poweredMessage(); idx = i;
break;
} }
LogAware::normalize (probs); cout << endl;
Util::exp (probs);
} else {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
probs *= l->poweredMessage();
}
LogAware::normalize (probs);
} }
if (idx == facNodes.size()) {
cerr << "error: only joint distributions on variables of some " ;
cerr << "clique are supported with the current solver" ;
cerr << endl;
exit (1);
}
VarIds representatives;
for (size_t i = 0; i < queryVids.size(); i++) {
representatives.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getJointDistributionOf (representatives);
} }
return probs; return res;
}
Params
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{
VarIds eqVarIds;
for (size_t i = 0; i < jointVids.size(); i++) {
VarNode* vn = cfg_->getEquivalent (jointVids[i]);
eqVarIds.push_back (vn->varId());
}
return BpSolver::getJointDistributionOf (eqVarIds);
} }
void void
CbpSolver::createLinks (void) CbpSolver::findIdenticalFactors()
{ {
if (Globals::verbosity > 0) { const FacNodes& facNodes = fg.facNodes();
cout << "compressed factor graph contains " ; if (checkForIdenticalFactors == false ||
cout << fg_->nrVarNodes() << " variables and " ; facNodes.size() == 1) {
cout << fg_->nrFacNodes() << " factors " << endl;
cout << endl;
}
const FacClusters& fcs = cfg_->facClusters();
for (size_t i = 0; i < fcs.size(); i++) {
const VarClusters& vcs = fcs[i]->varClusters();
for (size_t j = 0; j < vcs.size(); j++) {
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
if (Globals::verbosity > 1) {
cout << "creating link " ;
cout << fcs[i]->representative()->getLabel();
cout << " -- " ;
cout << vcs[j]->representative()->label();
cout << " idx=" << j << ", count=" << count << endl;
}
links_.push_back (new CbpSolverLink (
fcs[i]->representative(), vcs[j]->representative(), j, count));
}
}
if (Globals::verbosity > 1) {
cout << endl;
}
}
void
CbpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
if (Globals::verbosity >= 1) {
cout << "calculating " << links_[i]->toString() << endl;
}
}
return; return;
} }
for (size_t i = 0; i < facNodes.size(); i++) {
for (size_t c = 0; c < links_.size(); c++) { facNodes[i]->factor().setDistId (Util::maxUnsigned());
if (Globals::verbosity > 1) { }
cout << endl << "current residuals:" << endl; unsigned groupCount = 1;
for (SortedOrder::iterator it = sortedOrder_.begin(); for (size_t i = 0; i < facNodes.size() - 1; i++) {
it != sortedOrder_.end(); ++it) { Factor& f1 = facNodes[i]->factor();
cout << " " << setw (30) << left << (*it)->toString(); if (f1.distId() != Util::maxUnsigned()) {
cout << "residual = " << (*it)->getResidual() << endl; continue;
} }
} f1.setDistId (groupCount);
for (size_t j = i + 1; j < facNodes.size(); j++) {
SortedOrder::iterator it = sortedOrder_.begin(); Factor& f2 = facNodes[j]->factor();
SpLink* link = *it; if (f2.distId() != Util::maxUnsigned()) {
if (Globals::verbosity >= 1) { continue;
cout << "updating " << (*sortedOrder_.begin())->toString() << endl; }
} if (f1.size() == f2.size() &&
if (link->getResidual() < BpOptions::accuracy) { f1.ranges() == f2.ranges() &&
return; f1.params() == f2.params()) {
} f2.setDistId (groupCount);
link->updateMessage();
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (size_t j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[j]->toString() << endl;
}
calculateMessage (links[j]);
SpLinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]);
}
}
}
// in counting bp, the message that a variable X sends to
// to a factor F depends on the message that F sent to the X
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
for (size_t i = 0; i < links.size(); i++) {
if (links[i]->getVariable() != link->getVariable()) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[i]->toString() << endl;
}
calculateMessage (links[i]);
SpLinkMap::iterator iter = linkMap_.find (links[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[i]);
} }
} }
groupCount ++;
} }
} }
void void
CbpSolver::calculateFactor2VariableMsg (SpLink* _link) CbpSolver::setInitialColors (void)
{ {
CbpSolverLink* link = static_cast<CbpSolverLink*> (_link); varColors_.resize (fg.nrVarNodes());
FacNode* src = link->getFactor(); facColors_.resize (fg.nrFacNodes());
const VarNode* dst = link->getVariable(); // create the initial variable colors
const SpLinkSet& links = ninf(src)->getLinks(); VarColorMap colorMap;
// calculate the product of messages that were sent const VarNodes& varNodes = fg.varNodes();
// to factor `src', except from var `dst' for (size_t i = 0; i < varNodes.size(); i++) {
unsigned msgSize = 1; unsigned range = varNodes[i]->range();
for (size_t i = 0; i < links.size(); i++) { VarColorMap::iterator it = colorMap.find (range);
msgSize *= links[i]->getVariable()->range(); if (it == colorMap.end()) {
} it = colorMap.insert (make_pair (
unsigned repetitions = 1; range, Colors (range + 1, -1))).first;
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " ;
}
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
} else {
unsigned range = links[i]->getVariable()->range();
Util::add (msgProduct, Params (range, 0.0), repetitions);
repetitions *= range;
}
} }
} else { unsigned idx = varNodes[i]->hasEvidence()
for (size_t i = links.size(); i-- > 0; ) { ? varNodes[i]->getEvidence()
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]); : range;
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) { Colors& stateColors = it->second;
if (Constants::SHOW_BP_CALCS) { if (stateColors[idx] == -1) {
cout << " message from " << links[i]->getVariable()->label(); stateColors[idx] = getNewColor();
cout << ": " ;
}
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
} else {
unsigned range = links[i]->getVariable()->range();
Util::multiply (msgProduct, Params (range, 1.0), repetitions);
repetitions *= range;
}
} }
setColor (varNodes[i], stateColors[idx]);
} }
Factor result (src->factor().arguments(), const FacNodes& facNodes = fg.facNodes();
src->factor().ranges(), msgProduct); // create the initial factor colors
assert (msgProduct.size() == src->factor().size()); DistColorMap distColors;
if (Globals::logDomain) { for (size_t i = 0; i < facNodes.size(); i++) {
for (size_t i = 0; i < result.size(); i++) { unsigned distId = facNodes[i]->factor().distId();
result[i] += src->factor()[i]; DistColorMap::iterator it = distColors.find (distId);
} if (it == distColors.end()) {
} else { it = distColors.insert (make_pair (distId, getNewColor())).first;
for (size_t i = 0; i < result.size(); i++) {
result[i] *= src->factor()[i];
} }
setColor (facNodes[i], it->second);
} }
if (Constants::SHOW_BP_CALCS) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExceptIndex (link->index());
if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl;
}
link->getNextMessage() = result.params();
LogAware::normalize (link->getNextMessage());
if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << link->getNextMessage() << endl;
}
}
Params
CbpSolver::getVar2FactorMsg (const SpLink* _link) const
{
const CbpSolverLink* link = static_cast<const CbpSolverLink*> (_link);
const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor();
Params msg;
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()];
if (Constants::SHOW_BP_CALCS) {
msg[src->getEvidence()] = value;
cout << msg << "^" << link->nrEdges() << "-1" ;
}
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
} else {
msg = link->getMessage();
if (Constants::SHOW_BP_CALCS) {
cout << msg << "^" << link->nrEdges() << "-1" ;
}
LogAware::pow (msg, link->nrEdges() - 1);
}
const SpLinkSet& links = ninf(src)->getLinks();
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
msg += cl->poweredMessage();
}
}
} else {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
msg *= cl->poweredMessage();
if (Constants::SHOW_BP_CALCS) {
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
}
}
}
}
if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg;
}
return msg;
} }
void void
CbpSolver::printLinkInformation (void) const CbpSolver::createGroups (void)
{ {
for (size_t i = 0; i < links_.size(); i++) { VarSignMap varGroups;
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links_[i]); FacSignMap facGroups;
cout << cl->toString() << ":" << endl; unsigned nIters = 0;
cout << " curr msg = " << cl->getMessage() << endl; bool groupsHaveChanged = true;
cout << " next msg = " << cl->getNextMessage() << endl; const VarNodes& varNodes = fg.varNodes();
cout << " index = " << cl->index() << endl; const FacNodes& facNodes = fg.facNodes();
cout << " nr edges = " << cl->nrEdges() << endl;
cout << " powered = " << cl->poweredMessage() << endl; while (groupsHaveChanged || nIters == 1) {
cout << " residual = " << cl->getResidual() << endl; nIters ++;
// set a new color to the variables with the same signature
size_t prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (size_t i = 0; i < varNodes.size(); i++) {
const VarSignature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
Color newColor = getNewColor();
VarNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
size_t prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (size_t i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
Color newColor = getNewColor();
FacNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != facGroups.size();
}
// printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
}
void
CbpSolver::createClusters (
const VarSignMap& varGroups,
const FacSignMap& facGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (size_t i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
}
varClusters_.push_back (vc);
}
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (size_t i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
varClusters.push_back (vid2VarCluster_.find (vid)->second);
}
facClusters_.push_back (new FacCluster (it->second, varClusters));
}
}
VarSignature
CbpSolver::getSignature (const VarNode* varNode)
{
const FacNodes& neighs = varNode->neighbors();
VarSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (make_pair (
getColor (neighs[i]),
neighs[i]->factor().indexOf (varNode->varId())));
}
std::sort (sign.begin(), sign.end());
sign.push_back (make_pair (getColor (varNode), 0));
return sign;
}
FacSignature
CbpSolver::getSignature (const FacNode* facNode)
{
const VarNodes& neighs = facNode->neighbors();
FacSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (getColor (neighs[i]));
}
sign.push_back (getColor (facNode));
return sign;
}
FactorGraph*
CbpSolver::getCompressedFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) {
VarNode* newVar = new VarNode (varClusters_[i]->first());
varClusters_[i]->setRepresentative (newVar);
fg->addVarNode (newVar);
}
for (size_t i = 0; i < facClusters_.size(); i++) {
Vars vars;
const VarClusters& clusters = facClusters_[i]->varClusters();
for (size_t j = 0; j < clusters.size(); j++) {
vars.push_back (clusters[j]->representative());
}
const Factor& groundFac = facClusters_[i]->first()->factor();
FacNode* fn = new FacNode (Factor (
vars, groundFac.params(), groundFac.distId()));
facClusters_[i]->setRepresentative (fn);
fg->addFacNode (fn);
for (size_t j = 0; j < vars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
}
}
return fg;
}
vector<vector<unsigned>>
CbpSolver::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (facClusters_.size());
for (size_t i = 0; i < facClusters_.size(); i++) {
const VarClusters& neighs = facClusters_[i]->varClusters();
weights.push_back ({ });
weights.back().reserve (neighs.size());
for (size_t j = 0; j < neighs.size(); j++) {
weights.back().push_back (getWeight (
facClusters_[i], neighs[j], j));
}
}
return weights;
}
unsigned
CbpSolver::getWeight (
const FacCluster* fc,
const VarCluster* vc,
size_t index) const
{
unsigned weight = 0;
VarId reprVid = vc->representative()->varId();
VarNode* groundVar = fg.getVarNode (reprVid);
const FacNodes& neighs = groundVar->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
FacNodes::const_iterator it;
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
if (it != fc->members().end() &&
(*it)->factor().indexOf (reprVid) == index) {
weight ++;
}
}
return weight;
}
void
CbpSolver::printGroups (
const VarSignMap& varGroups,
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->label() << " " ;
}
count ++;
cout << endl;
}
}
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
}
count ++;
cout << endl;
}
} }
} }

View File

@ -1,67 +1,183 @@
#ifndef HORUS_CBP_H #ifndef HORUS_CBPSOLVER_H
#define HORUS_CBP_H #define HORUS_CBPSOLVER_H
#include "BpSolver.h" #include <unordered_map>
#include "CFactorGraph.h"
class Factor; #include "Solver.h"
#include "FactorGraph.h"
#include "Util.h"
#include "Horus.h"
class CbpSolverLink : public SpLink class VarCluster;
class FacCluster;
class VarSignHash;
class FacSignHash;
class WeightedBpSolver;
typedef long Color;
typedef vector<Color> Colors;
typedef vector<std::pair<Color,unsigned>> VarSignature;
typedef vector<Color> FacSignature;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<unsigned, Colors> VarColorMap;
typedef unordered_map<VarSignature, VarNodes, VarSignHash> VarSignMap;
typedef unordered_map<FacSignature, FacNodes, FacSignHash> FacSignMap;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
struct VarSignHash
{ {
public: size_t operator() (const VarSignature &sig) const
CbpSolverLink (FacNode* fn, VarNode* vn, size_t idx, unsigned count) {
: SpLink (fn, vn), index_(idx), nrEdges_(count), size_t val = hash<size_t>()(sig.size());
pwdMsg_(vn->range(), LogAware::one()) { } for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i].first);
size_t index (void) const { return index_; } val ^= hash<size_t>()(sig[i].second);
unsigned nrEdges (void) const { return nrEdges_; }
const Params& poweredMessage (void) const { return pwdMsg_; }
void updateMessage (void)
{
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
msgSended_ = true;
LogAware::pow (pwdMsg_, nrEdges_);
} }
return val;
private: }
size_t index_;
unsigned nrEdges_;
Params pwdMsg_;
}; };
struct FacSignHash
class CbpSolver : public BpSolver
{ {
public: size_t operator() (const FacSignature &sig) const
{
size_t val = hash<size_t>()(sig.size());
for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i]);
}
return val;
}
};
class VarCluster
{
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first (void) const { return members_.front(); }
const VarNodes& members (void) const { return members_; }
VarNode* representative (void) const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
};
class FacCluster
{
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first (void) const { return members_.front(); }
const FacNodes& members (void) const { return members_; }
VarClusters& varClusters (void) { return varClusters_; }
FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
private:
FacNodes members_;
VarClusters varClusters_;
FacNode* repr_;
};
class CbpSolver : public Solver
{
public:
CbpSolver (const FactorGraph& fg); CbpSolver (const FactorGraph& fg);
~CbpSolver (void); ~CbpSolver (void);
void printSolverFlags (void) const; void printSolverFlags (void) const;
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&); Params solveQuery (VarIds);
static bool checkForIdenticalFactors;
private:
Color getNewColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
private: Color getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
void createLinks (void); Color getColor (const FacNode* fn) const
{
return facColors_[fn->getIndex()];
}
void maxResidualSchedule (void); void setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
void calculateFactor2VariableMsg (SpLink*); void setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
Params getVar2FactorMsg (const SpLink*) const; void findIdenticalFactors (void);
void printLinkInformation (void) const; void setInitialColors (void);
CFactorGraph* cfg_; void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
VarSignature getSignature (const VarNode*);
FacSignature getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
VarId getRepresentative (VarId vid)
{
assert (Util::contains (vid2VarCluster_, vid));
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
FactorGraph* getCompressedFactorGraph (void);
vector<vector<unsigned>> getWeights (void) const;
unsigned getWeight (const FacCluster*,
const VarCluster*, size_t index) const;
Color freeColor_;
Colors varColors_;
Colors facColors_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* compressedFg_;
WeightedBpSolver* solver_;
}; };
#endif // HORUS_CBP_H #endif // HORUS_CBPSOLVER_H

View File

@ -139,12 +139,21 @@ ElimGraph::exportToGraphViz (
VarIds VarIds
ElimGraph::getEliminationOrder ( ElimGraph::getEliminationOrder (
const vector<Factor*> factors, const Factors& factors,
VarIds excludedVids) VarIds excludedVids)
{ {
if (elimHeuristic == ElimHeuristic::SEQUENTIAL) {
VarIds allVids;
Factors::const_iterator first = factors.begin();
Factors::const_iterator end = factors.end();
for (; first != end; ++first) {
Util::addToVector (allVids, (*first)->arguments());
}
TinySet<VarId> elimOrder (allVids);
elimOrder -= TinySet<VarId> (excludedVids);
return elimOrder.elements();
}
ElimGraph graph (factors); ElimGraph graph (factors);
// graph.print();
// graph.exportToGraphViz ("_egg.dot");
return graph.getEliminatingOrder (excludedVids); return graph.getEliminatingOrder (excludedVids);
} }

View File

@ -12,6 +12,7 @@ using namespace std;
enum ElimHeuristic enum ElimHeuristic
{ {
SEQUENTIAL,
MIN_NEIGHBORS, MIN_NEIGHBORS,
MIN_WEIGHT, MIN_WEIGHT,
MIN_FILL, MIN_FILL,
@ -45,7 +46,7 @@ class EgNode : public Var
class ElimGraph class ElimGraph
{ {
public: public:
ElimGraph (const vector<Factor*>&); // TODO ElimGraph (const Factors&);
~ElimGraph (void); ~ElimGraph (void);
@ -56,7 +57,7 @@ class ElimGraph
void exportToGraphViz (const char*, bool = true, void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const; const VarIds& = VarIds()) const;
static VarIds getEliminationOrder (const vector<Factor*>, VarIds); static VarIds getEliminationOrder (const Factors&, VarIds);
static ElimHeuristic elimHeuristic; static ElimHeuristic elimHeuristic;

View File

@ -187,7 +187,7 @@ class TFactor
bool contains (const vector<T>& args) const bool contains (const vector<T>& args) const
{ {
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
if (contains (args[i]) == false) { if (contains (args[i]) == false) {
return false; return false;
} }

View File

@ -28,7 +28,7 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
addEdge (varNodes_[neighs[j]->getIndex()], facNode); addEdge (varNodes_[neighs[j]->getIndex()], facNode);
} }
} }
fromBayesNet_ = fg.isFromBayesNetwork(); bayesFactors_ = fg.bayesianFactors();
} }
@ -239,7 +239,7 @@ FactorGraph::isTree (void) const
DAGraph& DAGraph&
FactorGraph::getStructure (void) FactorGraph::getStructure (void)
{ {
assert (fromBayesNet_); assert (bayesFactors_);
if (structure_.empty()) { if (structure_.empty()) {
for (size_t i = 0; i < varNodes_.size(); i++) { for (size_t i = 0; i < varNodes_.size(); i++) {
structure_.addNode (new DAGraphNode (varNodes_[i])); structure_.addNode (new DAGraphNode (varNodes_[i]));

View File

@ -65,7 +65,7 @@ class FacNode
class FactorGraph class FactorGraph
{ {
public: public:
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { } FactorGraph (void) : bayesFactors_(false) { }
FactorGraph (const FactorGraph&); FactorGraph (const FactorGraph&);
@ -74,8 +74,10 @@ class FactorGraph
const VarNodes& varNodes (void) const { return varNodes_; } const VarNodes& varNodes (void) const { return varNodes_; }
const FacNodes& facNodes (void) const { return facNodes_; } const FacNodes& facNodes (void) const { return facNodes_; }
void setFactorsAsBayesian (void) { bayesFactors_ = true; }
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; } bool bayesianFactors (void) const { return bayesFactors_ ; }
size_t nrVarNodes (void) const { return varNodes_.size(); } size_t nrVarNodes (void) const { return varNodes_.size(); }
@ -128,7 +130,7 @@ class FactorGraph
FacNodes facNodes_; FacNodes facNodes_;
DAGraph structure_; DAGraph structure_;
bool fromBayesNet_; bool bayesFactors_;
typedef unordered_map<unsigned, VarNode*> VarMap; typedef unordered_map<unsigned, VarNode*> VarMap;
VarMap varMap_; VarMap varMap_;

View File

@ -630,16 +630,9 @@ GroundOperator::getAffectedFormulas (void)
Params Params
FoveSolver::getPosterioriOf (const Ground& query) FoveSolver::solveQuery (const Grounds& query)
{
return getJointDistributionOf ({query});
}
Params
FoveSolver::getJointDistributionOf (const Grounds& query)
{ {
assert (query.empty() == false);
runSolver (query); runSolver (query);
(*pfList_.begin())->normalize(); (*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params(); Params params = (*pfList_.begin())->params();
@ -970,7 +963,8 @@ FoveSolver::absorve (
if (commCt->empty() == false) { if (commCt->empty() == false) {
if (formulas.size() > 1) { if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i); LogVarSet excl = g->exclusiveLogVars (i);
Parfactors countNormPfs = countNormalize (g, excl); Parfactor tempPf (g, commCt);
Parfactors countNormPfs = countNormalize (&tempPf, excl);
for (size_t j = 0; j < countNormPfs.size(); j++) { for (size_t j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence ( countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence()); formulas[i], obsFormula.evidence());

View File

@ -135,9 +135,7 @@ class FoveSolver
public: public:
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { } FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
Params getPosterioriOf (const Ground&); Params solveQuery (const Grounds&);
Params getJointDistributionOf (const Grounds&);
void printSolverFlags (void) const; void printSolverFlags (void) const;

View File

@ -28,11 +28,18 @@ typedef vector<unsigned> Ranges;
typedef unsigned long long ullong; typedef unsigned long long ullong;
enum InfAlgorithms enum LiftedSolvers
{ {
VE, // variable elimination FOVE, // first order variable elimination
BP, // belief propagation LBP, // lifted belief propagation
CBP // counting belief propagation };
enum GroundSolvers
{
VE, // variable elimination
BP, // belief propagation
CBP // counting belief propagation
}; };
@ -43,7 +50,8 @@ extern bool logDomain;
// level of debug information // level of debug information
extern unsigned verbosity; extern unsigned verbosity;
extern InfAlgorithms infAlgorithm; extern LiftedSolvers liftedSolver;
extern GroundSolvers groundSolver;
}; };

View File

@ -161,14 +161,14 @@ void
runSolver (const FactorGraph& fg, const VarIds& queryIds) runSolver (const FactorGraph& fg, const VarIds& queryIds)
{ {
Solver* solver = 0; Solver* solver = 0;
switch (Globals::infAlgorithm) { switch (Globals::groundSolver) {
case InfAlgorithms::VE: case GroundSolvers::VE:
solver = new VarElimSolver (fg); solver = new VarElimSolver (fg);
break; break;
case InfAlgorithms::BP: case GroundSolvers::BP:
solver = new BpSolver (fg); solver = new BpSolver (fg);
break; break;
case InfAlgorithms::CBP: case GroundSolvers::CBP:
solver = new CbpSolver (fg); solver = new CbpSolver (fg);
break; break;
default: default:
@ -178,7 +178,7 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
solver->printSolverFlags(); solver->printSolverFlags();
cout << endl; cout << endl;
} }
if (queryIds.size() == 0) { if (queryIds.empty()) {
solver->printAllPosterioris(); solver->printAllPosterioris();
} else { } else {
solver->printAnswer (queryIds); solver->printAnswer (queryIds);

View File

@ -11,6 +11,7 @@
#include "FactorGraph.h" #include "FactorGraph.h"
#include "FoveSolver.h" #include "FoveSolver.h"
#include "VarElimSolver.h" #include "VarElimSolver.h"
#include "LiftedBpSolver.h"
#include "BpSolver.h" #include "BpSolver.h"
#include "CbpSolver.h" #include "CbpSolver.h"
#include "ElimGraph.h" #include "ElimGraph.h"
@ -218,8 +219,10 @@ int
createGroundNetwork (void) createGroundNetwork (void)
{ {
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1))); string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
bool fromBayesNet = factorsType == "bayes"; FactorGraph* fg = new FactorGraph();
FactorGraph* fg = new FactorGraph (fromBayesNet); if (factorsType == "bayes") {
fg->setFactorsAsBayesian();
}
YAP_Term factorList = YAP_ARG2; YAP_Term factorList = YAP_ARG2;
while (factorList != YAP_TermNil()) { while (factorList != YAP_TermNil()) {
YAP_Term factor = YAP_HeadOfTerm (factorList); YAP_Term factor = YAP_HeadOfTerm (factorList);
@ -308,15 +311,22 @@ runLiftedSolver (void)
} }
jointList = YAP_TailOfTerm (jointList); jointList = YAP_TailOfTerm (jointList);
} }
FoveSolver solver (pfListCopy); if (Globals::liftedSolver == LiftedSolvers::FOVE) {
if (Globals::verbosity > 0 && taskList == YAP_ARG2) { FoveSolver solver (pfListCopy);
solver.printSolverFlags(); if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
cout << endl; solver.printSolverFlags();
} cout << endl;
if (queryVars.size() == 1) { }
results.push_back (solver.getPosterioriOf (queryVars[0])); results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolvers::LBP) {
LiftedBpSolver solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else { } else {
results.push_back (solver.getJointDistributionOf (queryVars)); assert (false);
} }
taskList = YAP_TailOfTerm (taskList); taskList = YAP_TailOfTerm (taskList);
} }
@ -352,7 +362,7 @@ runGroundSolver (void)
} }
vector<Params> results; vector<Params> results;
if (Globals::infAlgorithm == InfAlgorithms::VE) { if (Globals::groundSolver == GroundSolvers::VE) {
runVeSolver (fg, tasks, results); runVeSolver (fg, tasks, results);
} else { } else {
runBpSolver (fg, tasks, results); runBpSolver (fg, tasks, results);
@ -384,7 +394,7 @@ void runVeSolver (
results.reserve (tasks.size()); results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) { for (size_t i = 0; i < tasks.size(); i++) {
FactorGraph* mfg = fg; FactorGraph* mfg = fg;
if (fg->isFromBayesNetwork()) { if (fg->bayesianFactors()) {
// mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]); // mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
} }
// VarElimSolver solver (*mfg); // VarElimSolver solver (*mfg);
@ -394,7 +404,7 @@ void runVeSolver (
cout << endl; cout << endl;
} }
results.push_back (solver.solveQuery (tasks[i])); results.push_back (solver.solveQuery (tasks[i]));
if (fg->isFromBayesNetwork()) { if (fg->bayesianFactors()) {
// delete mfg; // delete mfg;
} }
} }
@ -413,14 +423,14 @@ void runBpSolver (
} }
Solver* solver = 0; Solver* solver = 0;
FactorGraph* mfg = fg; FactorGraph* mfg = fg;
if (fg->isFromBayesNetwork()) { if (fg->bayesianFactors()) {
//mfg = BayesBall::getMinimalFactorGraph ( //mfg = BayesBall::getMinimalFactorGraph (
// *fg, VarIds (vids.begin(),vids.end())); // *fg, VarIds (vids.begin(),vids.end()));
} }
if (Globals::infAlgorithm == InfAlgorithms::BP) { if (Globals::groundSolver == GroundSolvers::BP) {
solver = new BpSolver (*fg); // FIXME solver = new BpSolver (*fg); // FIXME
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) { } else if (Globals::groundSolver == GroundSolvers::CBP) {
CFactorGraph::checkForIdenticalFactors = false; CbpSolver::checkForIdenticalFactors = false;
solver = new CbpSolver (*fg); // FIXME solver = new CbpSolver (*fg); // FIXME
} else { } else {
cerr << "error: unknow solver" << endl; cerr << "error: unknow solver" << endl;
@ -434,7 +444,7 @@ void runBpSolver (
for (size_t i = 0; i < tasks.size(); i++) { for (size_t i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i])); results.push_back (solver->solveQuery (tasks[i]));
} }
if (fg->isFromBayesNetwork()) { if (fg->bayesianFactors()) {
//delete mfg; //delete mfg;
} }
delete solver; delete solver;

View File

@ -2,7 +2,7 @@
#define HORUS_INDEXER_H #define HORUS_INDEXER_H
#include <algorithm> #include <algorithm>
#include <functional> #include <numeric>
#include <sstream> #include <sstream>
#include <iomanip> #include <iomanip>
@ -14,10 +14,9 @@ class Indexer
{ {
public: public:
Indexer (const Ranges& ranges, bool calcOffsets = true) Indexer (const Ranges& ranges, bool calcOffsets = true)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges) : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
size_(Util::sizeExpected (ranges))
{ {
size_ = std::accumulate (ranges.begin(), ranges.end(), 1,
std::multiplies<unsigned>());
if (calcOffsets) { if (calcOffsets) {
calculateOffsets(); calculateOffsets();
} }

View File

@ -0,0 +1,148 @@
#include "LiftedBpSolver.h"
#include "WeightedBpSolver.h"
#include "FactorGraph.h"
#include "FoveSolver.h"
LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList)
: pfList_(pfList)
{
refineParfactors();
solver_ = new WeightedBpSolver (*getFactorGraph(), getWeights());
}
Params
LiftedBpSolver::solveQuery (const Grounds& query)
{
assert (query.empty() == false);
Params res;
vector<PrvGroup> groups = getQueryGroups (query);
if (query.size() == 1) {
res = solver_->getPosterioriOf (groups[0]);
} else {
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
res = solver_->getJointDistributionOf (queryVids);
}
return res;
}
void
LiftedBpSolver::printSolverFlags (void) const
{
stringstream ss;
ss << "lifted bp [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << BpOptions::maxIter;
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}
void
LiftedBpSolver::refineParfactors (void)
{
while (iterate() == false);
if (Globals::verbosity > 2) {
Util::printHeader ("AFTER REFINEMENT");
pfList_.print();
}
}
bool
LiftedBpSolver::iterate (void)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& args = (*it)->arguments();
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = FoveSolver::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it);
pfList_.add (pfs);
return false;
}
}
++ it;
}
return true;
}
vector<PrvGroup>
LiftedBpSolver::getQueryGroups (const Grounds& query)
{
vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
if ((*it)->containsGround (query[i])) {
queryGroups.push_back ((*it)->findGroup (query[i]));
break;
}
}
}
assert (queryGroups.size() == query.size());
return queryGroups;
}
FactorGraph*
LiftedBpSolver::getFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups();
VarIds varIds;
for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]);
}
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
}
return fg;
}
vector<vector<unsigned>>
LiftedBpSolver::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (pfList_.size());
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
const ProbFormulas& args = (*it)->arguments();
weights.push_back ({ });
weights.back().reserve (args.size());
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
}
}
return weights;
}

View File

@ -0,0 +1,34 @@
#ifndef HORUS_LIFTEDBPSOLVER_H
#define HORUS_LIFTEDBPSOLVER_H
#include "ParfactorList.h"
class FactorGraph;
class WeightedBpSolver;
class LiftedBpSolver
{
public:
LiftedBpSolver (const ParfactorList& pfList);
Params solveQuery (const Grounds&);
void printSolverFlags (void) const;
private:
void refineParfactors (void);
bool iterate (void);
vector<PrvGroup> getQueryGroups (const Grounds&);
FactorGraph* getFactorGraph (void);
vector<vector<unsigned>> getWeights (void) const;
ParfactorList pfList_;
WeightedBpSolver* solver_;
};
#endif // HORUS_LIFTEDBPSOLVER_H

View File

@ -50,7 +50,6 @@ HEADERS = \
$(srcdir)/ElimGraph.h \ $(srcdir)/ElimGraph.h \
$(srcdir)/FactorGraph.h \ $(srcdir)/FactorGraph.h \
$(srcdir)/Factor.h \ $(srcdir)/Factor.h \
$(srcdir)/CFactorGraph.h \
$(srcdir)/ConstraintTree.h \ $(srcdir)/ConstraintTree.h \
$(srcdir)/Solver.h \ $(srcdir)/Solver.h \
$(srcdir)/VarElimSolver.h \ $(srcdir)/VarElimSolver.h \
@ -65,6 +64,8 @@ HEADERS = \
$(srcdir)/ParfactorList.h \ $(srcdir)/ParfactorList.h \
$(srcdir)/LiftedUtils.h \ $(srcdir)/LiftedUtils.h \
$(srcdir)/TinySet.h \ $(srcdir)/TinySet.h \
$(srcdir)/LiftedBpSolver.h \
$(srcdir)/WeightedBpSolver.h \
$(srcdir)/Util.h \ $(srcdir)/Util.h \
$(srcdir)/Horus.h $(srcdir)/Horus.h
@ -74,7 +75,6 @@ CPP_SOURCES = \
$(srcdir)/ElimGraph.cpp \ $(srcdir)/ElimGraph.cpp \
$(srcdir)/FactorGraph.cpp \ $(srcdir)/FactorGraph.cpp \
$(srcdir)/Factor.cpp \ $(srcdir)/Factor.cpp \
$(srcdir)/CFactorGraph.cpp \
$(srcdir)/ConstraintTree.cpp \ $(srcdir)/ConstraintTree.cpp \
$(srcdir)/Var.cpp \ $(srcdir)/Var.cpp \
$(srcdir)/Solver.cpp \ $(srcdir)/Solver.cpp \
@ -88,6 +88,8 @@ CPP_SOURCES = \
$(srcdir)/ParfactorList.cpp \ $(srcdir)/ParfactorList.cpp \
$(srcdir)/LiftedUtils.cpp \ $(srcdir)/LiftedUtils.cpp \
$(srcdir)/Util.cpp \ $(srcdir)/Util.cpp \
$(srcdir)/LiftedBpSolver.cpp \
$(srcdir)/WeightedBpSolver.cpp \
$(srcdir)/HorusYap.cpp \ $(srcdir)/HorusYap.cpp \
$(srcdir)/HorusCli.cpp $(srcdir)/HorusCli.cpp
@ -97,7 +99,6 @@ OBJS = \
ElimGraph.o \ ElimGraph.o \
FactorGraph.o \ FactorGraph.o \
Factor.o \ Factor.o \
CFactorGraph.o \
ConstraintTree.o \ ConstraintTree.o \
Var.o \ Var.o \
Solver.o \ Solver.o \
@ -111,6 +112,8 @@ OBJS = \
ParfactorList.o \ ParfactorList.o \
LiftedUtils.o \ LiftedUtils.o \
Util.o \ Util.o \
LiftedBpSolver.o \
WeightedBpSolver.o \
HorusYap.o HorusYap.o
HCLI_OBJS = \ HCLI_OBJS = \
@ -119,7 +122,6 @@ HCLI_OBJS = \
ElimGraph.o \ ElimGraph.o \
FactorGraph.o \ FactorGraph.o \
Factor.o \ Factor.o \
CFactorGraph.o \
ConstraintTree.o \ ConstraintTree.o \
Var.o \ Var.o \
Solver.o \ Solver.o \
@ -131,6 +133,7 @@ HCLI_OBJS = \
ProbFormula.o \ ProbFormula.o \
Histogram.o \ Histogram.o \
ParfactorList.o \ ParfactorList.o \
WeightedBpSolver.o \
LiftedUtils.o \ LiftedUtils.o \
Util.o \ Util.o \
HorusCli.o HorusCli.o

View File

@ -14,14 +14,16 @@ Solver::printAnswer (const VarIds& vids)
unobservedVids.push_back (vids[i]); unobservedVids.push_back (vids[i]);
} }
} }
Params res = solveQuery (unobservedVids); if (unobservedVids.empty() == false) {
vector<string> stateLines = Util::getStateLines (unobservedVars); Params res = solveQuery (unobservedVids);
for (size_t i = 0; i < res.size(); i++) { vector<string> stateLines = Util::getStateLines (unobservedVars);
cout << "P(" << stateLines[i] << ") = " ; for (size_t i = 0; i < res.size(); i++) {
cout << std::setprecision (Constants::PRECISION) << res[i]; cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl;
}
cout << endl; cout << endl;
} }
cout << endl;
} }
@ -29,14 +31,10 @@ Solver::printAnswer (const VarIds& vids)
void void
Solver::printAllPosterioris (void) Solver::printAllPosterioris (void)
{ {
VarIds vids; VarNodes vars = fg.varNodes();
const VarNodes& vars = fg.varNodes(); std::sort (vars.begin(), vars.end(), sortByVarId());
for (size_t i = 0; i < vars.size(); i++) { for (size_t i = 0; i < vars.size(); i++) {
vids.push_back (vars[i]->varId()); printAnswer ({vars[i]->varId()});
}
std::sort (vids.begin(), vids.end());
for (size_t i = 0; i < vids.size(); i++) {
printAnswer ({vids[i]});
} }
} }

View File

@ -1,8 +1,5 @@
- Check if evidence remains in the compressed factor graph
- Consider using hashs instead of vectors of colors to calculate the groups in
counting bp
- Find a way to decrease the time required to find an - Find a way to decrease the time required to find an
elimination order for variable elimination elimination order for variable elimination
- Add a sequential elimination heuristic - Consider using hashs instead of vectors of colors to calculate the groups in
counting bp

View File

@ -12,7 +12,10 @@ class TinySet
{ {
public: public:
TinySet (const TinySet& s) typedef typename vector<T>::iterator iterator;
typedef typename vector<T>::const_iterator const_iterator;
TinySet (const TinySet& s)
: vec_(s.vec_), cmp_(s.cmp_) { } : vec_(s.vec_), cmp_(s.cmp_) { }
TinySet (const Compare& cmp = Compare()) TinySet (const Compare& cmp = Compare())
@ -25,11 +28,10 @@ class TinySet
: vec_(elements), cmp_(cmp) : vec_(elements), cmp_(cmp)
{ {
std::sort (begin(), end(), cmp_); std::sort (begin(), end(), cmp_);
iterator it = unique_cmp (begin(), end());
vec_.resize (it - begin());
} }
typedef typename vector<T>::iterator iterator;
typedef typename vector<T>::const_iterator const_iterator;
iterator insert (const T& t) iterator insert (const T& t)
{ {
iterator it = std::lower_bound (begin(), end(), t, cmp_); iterator it = std::lower_bound (begin(), end(), t, cmp_);
@ -224,11 +226,25 @@ class TinySet
} }
private: private:
iterator unique_cmp (iterator first, iterator last)
{
if (first == last) {
return last;
}
iterator result = first;
while (++first != last) {
if (cmp_(*result, *first)) {
*(++result) = *first;
}
}
return ++result;
}
bool consistent (void) const bool consistent (void) const
{ {
typename vector<T>::size_type i; typename vector<T>::size_type i;
for (i = 0; i < vec_.size() - 1; i++) { for (i = 0; i < vec_.size() - 1; i++) {
if (cmp_(vec_[i], vec_[i + 1]) == false) { if ( ! cmp_(vec_[i], vec_[i + 1])) {
return false; return false;
} }
} }

View File

@ -13,9 +13,11 @@ bool logDomain = false;
unsigned verbosity = 0; unsigned verbosity = 0;
InfAlgorithms infAlgorithm = InfAlgorithms::VE; LiftedSolvers liftedSolver = LiftedSolvers::FOVE;
};
GroundSolvers groundSolver = GroundSolvers::VE;
};
@ -71,7 +73,6 @@ stringToDouble (string str)
double double
factorial (unsigned num) factorial (unsigned num)
{ {
@ -128,8 +129,8 @@ nrCombinations (unsigned n, unsigned k)
size_t size_t
sizeExpected (const Ranges& ranges) sizeExpected (const Ranges& ranges)
{ {
return std::accumulate ( return std::accumulate (ranges.begin(),
ranges.begin(), ranges.end(), 1, multiplies<unsigned>()); ranges.end(), 1, multiplies<unsigned>());
} }
@ -208,20 +209,32 @@ setHorusFlag (string key, string value)
stringstream ss; stringstream ss;
ss << value; ss << value;
ss >> Globals::verbosity; ss >> Globals::verbosity;
} else if (key == "inf_alg") { } else if (key == "lifted_solver") {
if ( value == "fove") {
Globals::liftedSolver = LiftedSolvers::FOVE;
} else if (value == "lbp") {
Globals::liftedSolver = LiftedSolvers::LBP;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
returnVal = false;
}
} else if (key == "ground_solver") {
if ( value == "ve") { if ( value == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE; Globals::groundSolver = GroundSolvers::VE;
} else if (value == "bp") { } else if (value == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP; Globals::groundSolver = GroundSolvers::BP;
} else if (value == "cbp") { } else if (value == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP; Globals::groundSolver = GroundSolvers::CBP;
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
returnVal = false; returnVal = false;
} }
} else if (key == "elim_heuristic") { } else if (key == "elim_heuristic") {
if ( value == "min_neighbors") { if ( value == "sequential") {
ElimGraph::elimHeuristic = ElimHeuristic::SEQUENTIAL;
} else if (value == "min_neighbors") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS; ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
} else if (value == "min_weight") { } else if (value == "min_weight") {
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT; ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT;
@ -323,23 +336,15 @@ namespace LogAware {
void void
normalize (Params& v) normalize (Params& v)
{ {
double sum = LogAware::addIdenty();
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < v.size(); i++) { double sum = std::accumulate (v.begin(), v.end(),
sum = Util::logSum (sum, v[i]); LogAware::addIdenty(), Util::logSum);
}
assert (sum != -numeric_limits<double>::infinity()); assert (sum != -numeric_limits<double>::infinity());
for (size_t i = 0; i < v.size(); i++) { v -= sum;
v[i] -= sum;
}
} else { } else {
for (size_t i = 0; i < v.size(); i++) { double sum = std::accumulate (v.begin(), v.end(), 0.0);
sum += v[i];
}
assert (sum != 0.0); assert (sum != 0.0);
for (size_t i = 0; i < v.size(); i++) { v /= sum;
v[i] /= sum;
}
} }
} }
@ -351,13 +356,11 @@ getL1Distance (const Params& v1, const Params& v2)
assert (v1.size() == v2.size()); assert (v1.size() == v2.size());
double dist = 0.0; double dist = 0.0;
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < v1.size(); i++) { dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
dist += abs (exp(v1[i]) - exp(v2[i])); std::plus<double>(), FuncObject::abs_diff_exp<double>());
}
} else { } else {
for (size_t i = 0; i < v1.size(); i++) { dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
dist += abs (v1[i] - v2[i]); std::plus<double>(), FuncObject::abs_diff<double>());
}
} }
return dist; return dist;
} }
@ -370,19 +373,11 @@ getMaxNorm (const Params& v1, const Params& v2)
assert (v1.size() == v2.size()); assert (v1.size() == v2.size());
double max = 0.0; double max = 0.0;
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < v1.size(); i++) { max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
double diff = abs (exp(v1[i]) - exp(v2[i])); FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
if (diff > max) {
max = diff;
}
}
} else { } else {
for (size_t i = 0; i < v1.size(); i++) { max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
double diff = abs (v1[i] - v2[i]); FuncObject::max<double>(), FuncObject::abs_diff<double>());
if (diff > max) {
max = diff;
}
}
} }
return max; return max;
} }
@ -392,7 +387,9 @@ getMaxNorm (const Params& v1, const Params& v2)
double double
pow (double base, unsigned iexp) pow (double base, unsigned iexp)
{ {
return Globals::logDomain ? base * iexp : std::pow (base, iexp); return Globals::logDomain
? base * iexp
: std::pow (base, iexp);
} }
@ -400,8 +397,10 @@ pow (double base, unsigned iexp)
double double
pow (double base, double exp) pow (double base, double exp)
{ {
// assumes that `expoent' is never in log domain // `expoent' should not be in log domain
return Globals::logDomain ? base * exp : std::pow (base, exp); return Globals::logDomain
? base * exp
: std::pow (base, exp);
} }

View File

@ -41,6 +41,9 @@ template <typename K, typename V> bool contains (
template <typename T> size_t indexOf (const vector<T>&, const T&); template <typename T> size_t indexOf (const vector<T>&, const T&);
template <class Operation>
void apply_n_times (Params& v1, const Params& v2, unsigned repetitions, Operation);
template <typename T> void log (vector<T>&); template <typename T> void log (vector<T>&);
template <typename T> void exp (vector<T>&); template <typename T> void exp (vector<T>&);
@ -54,10 +57,6 @@ template <> std::string toString (const bool&);
double logSum (double, double); double logSum (double, double);
void add (Params&, const Params&, unsigned);
void multiply (Params&, const Params&, unsigned);
unsigned maxUnsigned (void); unsigned maxUnsigned (void);
unsigned stringToUnsigned (string); unsigned stringToUnsigned (string);
@ -153,10 +152,29 @@ Util::indexOf (const vector<T>& v, const T& e)
template <class Operation> void
Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
Operation unary_op)
{
Params::iterator first = v1.begin();
Params::const_iterator last = v1.end();
Params::const_iterator first2 = v2.begin();
Params::const_iterator last2 = v2.end();
while (first != last) {
for (first2 = v2.begin(); first2 != last2; ++first2) {
std::transform (first, first + repetitions, first,
std::bind1st (unary_op, *first2));
first += repetitions;
}
}
}
template <typename T> void template <typename T> void
Util::log (vector<T>& v) Util::log (vector<T>& v)
{ {
transform (v.begin(), v.end(), v.begin(), ::log); std::transform (v.begin(), v.end(), v.begin(), ::log);
} }
@ -164,7 +182,7 @@ Util::log (vector<T>& v)
template <typename T> void template <typename T> void
Util::exp (vector<T>& v) Util::exp (vector<T>& v)
{ {
transform (v.begin(), v.end(), v.begin(), ::exp); std::transform (v.begin(), v.end(), v.begin(), ::exp);
} }
@ -224,36 +242,6 @@ Util::logSum (double x, double y)
inline void
Util::add (Params& v1, const Params& v2, unsigned repetitions)
{
for (size_t count = 0; count < v1.size(); ) {
for (size_t i = 0; i < v2.size(); i++) {
for (unsigned r = 0; r < repetitions; r++) {
v1[count] += v2[i];
count ++;
}
}
}
}
inline void
Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
{
for (size_t count = 0; count < v1.size(); ) {
for (size_t i = 0; i < v2.size(); i++) {
for (unsigned r = 0; r < repetitions; r++) {
v1[count] *= v2[i];
count ++;
}
}
}
}
inline unsigned inline unsigned
Util::maxUnsigned (void) Util::maxUnsigned (void)
{ {
@ -273,7 +261,6 @@ inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; } inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; } inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
void normalize (Params&); void normalize (Params&);
double getL1Distance (const Params&, const Params&); double getL1Distance (const Params&, const Params&);
@ -296,7 +283,7 @@ template <typename T>
void operator+=(std::vector<T>& v, double val) void operator+=(std::vector<T>& v, double val)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (plus<double>(), val)); std::bind2nd (plus<double>(), val));
} }
@ -305,7 +292,7 @@ template <typename T>
void operator-=(std::vector<T>& v, double val) void operator-=(std::vector<T>& v, double val)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (minus<double>(), val)); std::bind2nd (minus<double>(), val));
} }
@ -314,7 +301,7 @@ template <typename T>
void operator*=(std::vector<T>& v, double val) void operator*=(std::vector<T>& v, double val)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (multiplies<double>(), val)); std::bind2nd (multiplies<double>(), val));
} }
@ -323,7 +310,7 @@ template <typename T>
void operator/=(std::vector<T>& v, double val) void operator/=(std::vector<T>& v, double val)
{ {
std::transform (v.begin(), v.end(), v.begin(), std::transform (v.begin(), v.end(), v.begin(),
std::bind1st (divides<double>(), val)); std::bind2nd (divides<double>(), val));
} }
@ -395,5 +382,41 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
return os; return os;
} }
namespace FuncObject {
template<typename T>
struct max : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
return x < y ? y : x;
}
};
template <typename T>
struct abs_diff : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
return std::abs (x - y);
}
};
template <typename T>
struct abs_diff_exp : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
return std::abs (std::exp (x) - std::exp (y));
}
};
}
#endif // HORUS_UTIL_H #endif // HORUS_UTIL_H

View File

@ -48,6 +48,7 @@ VarElimSolver::printSolverFlags (void) const
ss << "elim_heuristic=" ; ss << "elim_heuristic=" ;
ElimHeuristic eh = ElimGraph::elimHeuristic; ElimHeuristic eh = ElimGraph::elimHeuristic;
switch (eh) { switch (eh) {
case SEQUENTIAL: ss << "sequential"; break;
case MIN_NEIGHBORS: ss << "min_neighbors"; break; case MIN_NEIGHBORS: ss << "min_neighbors"; break;
case MIN_WEIGHT: ss << "min_weight"; break; case MIN_WEIGHT: ss << "min_weight"; break;
case MIN_FILL: ss << "min_fill"; break; case MIN_FILL: ss << "min_fill"; break;

View File

@ -0,0 +1,288 @@
#include "WeightedBpSolver.h"
WeightedBpSolver::~WeightedBpSolver (void)
{
for (size_t i = 0; i < links_.size(); i++) {
delete links_[i];
}
links_.clear();
}
Params
WeightedBpSolver::getPosterioriOf (VarId vid)
{
if (runned_ == false) {
runSolver();
}
VarNode* var = fg.getVarNode (vid);
assert (var != 0);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->range(), LogAware::multIdenty());
const BpLinks& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
probs += l->powMessage();
}
LogAware::normalize (probs);
Util::exp (probs);
} else {
for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
probs *= l->powMessage();
}
LogAware::normalize (probs);
}
}
return probs;
}
void
WeightedBpSolver::createLinks (void)
{
if (Globals::verbosity > 0) {
cout << "compressed factor graph contains " ;
cout << fg.nrVarNodes() << " variables and " ;
cout << fg.nrFacNodes() << " factors " << endl;
cout << endl;
}
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
const VarNodes& neighs = facNodes[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
if (Globals::verbosity > 1) {
cout << "creating link " ;
cout << facNodes[i]->getLabel();
cout << " -- " ;
cout << neighs[j]->label();
cout << " idx=" << j << ", weight=" << weights_[i][j] << endl;
}
links_.push_back (new WeightedLink (
facNodes[i], neighs[j], j, weights_[i][j]));
}
}
if (Globals::verbosity > 1) {
cout << endl;
}
}
void
WeightedBpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
if (Globals::verbosity >= 1) {
cout << "calculating " << links_[i]->toString() << endl;
}
}
return;
}
for (size_t c = 0; c < links_.size(); c++) {
if (Globals::verbosity > 1) {
cout << endl << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->residual() << endl;
}
}
SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it;
if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->residual() < BpOptions::accuracy) {
return;
}
link->updateMessage();
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) {
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[j]->toString() << endl;
}
calculateMessage (links[j]);
BpLinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]);
}
}
}
// in counting bp, the message that a variable X sends to
// to a factor F depends on the message that F sent to the X
const BpLinks& links = ninf(link->facNode())->getLinks();
for (size_t i = 0; i < links.size(); i++) {
if (links[i]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[i]->toString() << endl;
}
calculateMessage (links[i]);
BpLinkMap::iterator iter = linkMap_.find (links[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[i]);
}
}
}
}
void
WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
{
WeightedLink* link = static_cast<WeightedLink*> (_link);
FacNode* src = link->facNode();
const VarNode* dst = link->varNode();
const BpLinks& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned reps = 1;
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
}
reps *= links[i]->varNode()->range();
}
} else {
for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
}
reps *= links[i]->varNode()->range();
}
}
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
assert (msgProduct.size() == src->factor().size());
if (Globals::logDomain) {
result.params() += src->factor().params();
} else {
result.params() *= src->factor().params();
}
if (Constants::SHOW_BP_CALCS) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExceptIndex (link->index());
if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl;
}
link->nextMessage() = result.params();
LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->message() << endl;
cout << " next msg: " << link->nextMessage() << endl;
}
}
Params
WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
{
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
const VarNode* src = link->varNode();
const FacNode* dst = link->facNode();
Params msg;
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
double value = link->message()[src->getEvidence()];
if (Constants::SHOW_BP_CALCS) {
msg[src->getEvidence()] = value;
cout << msg << "^" << link->weight() << "-1" ;
}
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
} else {
msg = link->message();
if (Constants::SHOW_BP_CALCS) {
cout << msg << "^" << link->weight() << "-1" ;
}
LogAware::pow (msg, link->weight() - 1);
}
const BpLinks& links = ninf(src)->getLinks();
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
if ( ! (l->facNode() == dst && l->index() == link->index())) {
msg += l->powMessage();
}
}
} else {
for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
if ( ! (l->facNode() == dst && l->index() == link->index())) {
msg *= l->powMessage();
if (Constants::SHOW_BP_CALCS) {
cout << " x " << l->nextMessage() << "^" << link->weight();
}
}
}
}
if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg;
}
return msg;
}
void
WeightedBpSolver::printLinkInformation (void) const
{
for (size_t i = 0; i < links_.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
cout << l->toString() << ":" << endl;
cout << " curr msg = " << l->message() << endl;
cout << " next msg = " << l->nextMessage() << endl;
cout << " pow msg = " << l->powMessage() << endl;
cout << " index = " << l->index() << endl;
cout << " weight = " << l->weight() << endl;
cout << " residual = " << l->residual() << endl;
}
}

View File

@ -0,0 +1,61 @@
#ifndef HORUS_WEIGHTEDBPSOLVER_H
#define HORUS_WEIGHTEDBPSOLVER_H
#include "BpSolver.h"
class WeightedLink : public BpLink
{
public:
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
: BpLink (fn, vn), index_(idx), weight_(weight),
pwdMsg_(vn->range(), LogAware::one()) { }
size_t index (void) const { return index_; }
unsigned weight (void) const { return weight_; }
const Params& powMessage (void) const { return pwdMsg_; }
void updateMessage (void)
{
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
LogAware::pow (pwdMsg_, weight_);
}
private:
size_t index_;
unsigned weight_;
Params pwdMsg_;
};
class WeightedBpSolver : public BpSolver
{
public:
WeightedBpSolver (const FactorGraph& fg,
const vector<vector<unsigned>>& weights)
: BpSolver (fg), weights_(weights) { }
~WeightedBpSolver (void);
Params getPosterioriOf (VarId);
private:
void createLinks (void);
void maxResidualSchedule (void);
void calcFactorToVarMsg (BpLink*);
Params getVarToFactorMsg (const BpLink*) const;
void printLinkInformation (void) const;
vector<vector<unsigned>> weights_;
};
#endif // HORUS_WEIGHTEDBPSOLVER_H