Merge branch 'master' of github.com:tacgomes/yap6.3
This commit is contained in:
commit
31fde92a36
@ -39,18 +39,16 @@ warning :-
|
|||||||
set_solver(ve) :- set_pfl_flag(solver,ve).
|
set_solver(ve) :- set_pfl_flag(solver,ve).
|
||||||
set_solver(jt) :- set_pfl_flag(solver,jt).
|
set_solver(jt) :- set_pfl_flag(solver,jt).
|
||||||
set_solver(gibbs) :- set_pfl_flag(solver,gibbs).
|
set_solver(gibbs) :- set_pfl_flag(solver,gibbs).
|
||||||
set_solver(fove) :- set_pfl_flag(solver,fove).
|
set_solver(fove) :- set_pfl_flag(solver,fove), set_horus_flag(lifted_solver, fove).
|
||||||
set_solver(hve) :- set_pfl_flag(solver,bp), cpp_set_horus_flag(inf_alg, ve).
|
set_solver(lbp) :- set_pfl_flag(solver,fove), set_horus_flag(lifted_solver, lbp).
|
||||||
set_solver(bp) :- set_pfl_flag(solver,bp), cpp_set_horus_flag(inf_alg, bp).
|
set_solver(hve) :- set_pfl_flag(solver,bp), set_horus_flag(ground_solver, ve).
|
||||||
set_solver(cbp) :- set_pfl_flag(solver,bp), cpp_set_horus_flag(inf_alg, cbp).
|
set_solver(bp) :- set_pfl_flag(solver,bp), set_horus_flag(ground_solver, bp).
|
||||||
|
set_solver(cbp) :- set_pfl_flag(solver,bp), set_horus_flag(ground_solver, cbp).
|
||||||
set_solver(S) :- throw(error('unknow solver ', S)).
|
set_solver(S) :- throw(error('unknow solver ', S)).
|
||||||
|
|
||||||
|
|
||||||
set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).
|
set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).
|
||||||
|
|
||||||
%:- cpp_set_horus_flag(inf_alg, ve).
|
|
||||||
%:- cpp_set_horus_flag(inf_alg, bp).
|
|
||||||
%: -cpp_set_horus_flag(inf_alg, cbp).
|
|
||||||
|
|
||||||
:- cpp_set_horus_flag(schedule, seq_fixed).
|
:- cpp_set_horus_flag(schedule, seq_fixed).
|
||||||
%:- cpp_set_horus_flag(schedule, seq_random).
|
%:- cpp_set_horus_flag(schedule, seq_random).
|
||||||
|
@ -9,12 +9,10 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
FactorGraph*
|
||||||
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||||
{
|
{
|
||||||
assert (fg_.isFromBayesNetwork());
|
assert (fg_.bayesianFactors());
|
||||||
Scheduling scheduling;
|
Scheduling scheduling;
|
||||||
for (size_t i = 0; i < queryIds.size(); i++) {
|
for (size_t i = 0; i < queryIds.size(); i++) {
|
||||||
assert (dag_.getNode (queryIds[i]));
|
assert (dag_.getNode (queryIds[i]));
|
||||||
|
@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||||
{
|
{
|
||||||
fg_ = &fg;
|
|
||||||
runned_ = false;
|
runned_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,11 +38,9 @@ Params
|
|||||||
BpSolver::solveQuery (VarIds queryVids)
|
BpSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
assert (queryVids.empty() == false);
|
assert (queryVids.empty() == false);
|
||||||
if (queryVids.size() == 1) {
|
return queryVids.size() == 1
|
||||||
return getPosterioriOf (queryVids[0]);
|
? getPosterioriOf (queryVids[0])
|
||||||
} else {
|
: getJointDistributionOf (queryVids);
|
||||||
return getJointDistributionOf (queryVids);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -61,8 +58,8 @@ BpSolver::printSolverFlags (void) const
|
|||||||
case Sch::PARALLEL: ss << "parallel"; break;
|
case Sch::PARALLEL: ss << "parallel"; break;
|
||||||
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
}
|
}
|
||||||
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
|
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
|
||||||
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
|
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
cout << ss.str() << endl;
|
cout << ss.str() << endl;
|
||||||
@ -76,24 +73,24 @@ BpSolver::getPosterioriOf (VarId vid)
|
|||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
assert (fg_->getVarNode (vid));
|
assert (fg.getVarNode (vid));
|
||||||
VarNode* var = fg_->getVarNode (vid);
|
VarNode* var = fg.getVarNode (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->range(), LogAware::multIdenty());
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const BpLinks& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
probs += links[i]->getMessage();
|
probs += links[i]->message();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
Util::exp (probs);
|
Util::exp (probs);
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
probs *= links[i]->getMessage();
|
probs *= links[i]->message();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
}
|
}
|
||||||
@ -109,7 +106,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
|
VarNode* vn = fg.getVarNode (jointVarIds[0]);
|
||||||
const FacNodes& facNodes = vn->neighbors();
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
size_t idx = facNodes.size();
|
size_t idx = facNodes.size();
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
@ -122,11 +119,11 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
return getJointByConditioning (jointVarIds);
|
return getJointByConditioning (jointVarIds);
|
||||||
} else {
|
} else {
|
||||||
Factor res (facNodes[idx]->factor());
|
Factor res (facNodes[idx]->factor());
|
||||||
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
|
const BpLinks& links = ninf(facNodes[idx])->getLinks();
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
Factor msg ({links[i]->getVariable()->varId()},
|
Factor msg ({links[i]->varNode()->varId()},
|
||||||
{links[i]->getVariable()->range()},
|
{links[i]->varNode()->range()},
|
||||||
getVar2FactorMsg (links[i]));
|
getVarToFactorMsg (links[i]));
|
||||||
res.multiply (msg);
|
res.multiply (msg);
|
||||||
}
|
}
|
||||||
res.sumOutAllExcept (jointVarIds);
|
res.sumOutAllExcept (jointVarIds);
|
||||||
@ -154,7 +151,7 @@ BpSolver::runSolver (void)
|
|||||||
}
|
}
|
||||||
switch (BpOptions::schedule) {
|
switch (BpOptions::schedule) {
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
random_shuffle (links_.begin(), links_.end());
|
std::random_shuffle (links_.begin(), links_.end());
|
||||||
// no break
|
// no break
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
case BpOptions::Schedule::SEQ_FIXED:
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
@ -192,11 +189,11 @@ BpSolver::runSolver (void)
|
|||||||
void
|
void
|
||||||
BpSolver::createLinks (void)
|
BpSolver::createLinks (void)
|
||||||
{
|
{
|
||||||
const FacNodes& facNodes = fg_->facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
const VarNodes& neighbors = facNodes[i]->neighbors();
|
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||||
for (size_t j = 0; j < neighbors.size(); j++) {
|
for (size_t j = 0; j < neighbors.size(); j++) {
|
||||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
links_.push_back (new BpLink (facNodes[i], neighbors[j]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -221,13 +218,13 @@ BpSolver::maxResidualSchedule (void)
|
|||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); ++it) {
|
it != sortedOrder_.end(); ++it) {
|
||||||
cout << " " << setw (30) << left << (*it)->toString();
|
cout << " " << setw (30) << left << (*it)->toString();
|
||||||
cout << "residual = " << (*it)->getResidual() << endl;
|
cout << "residual = " << (*it)->residual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
SpLink* link = *it;
|
BpLink* link = *it;
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
if (link->residual() < BpOptions::accuracy) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
updateMessage (link);
|
updateMessage (link);
|
||||||
@ -236,14 +233,14 @@ BpSolver::maxResidualSchedule (void)
|
|||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
// update the messages that depend on message source --> destin
|
||||||
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||||
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||||
if (factorNeighbors[i] != link->getFactor()) {
|
if (factorNeighbors[i] != link->facNode()) {
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
for (size_t j = 0; j < links.size(); j++) {
|
for (size_t j = 0; j < links.size(); j++) {
|
||||||
if (links[j]->getVariable() != link->getVariable()) {
|
if (links[j]->varNode() != link->varNode()) {
|
||||||
calculateMessage (links[j]);
|
calculateMessage (links[j]);
|
||||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
BpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||||
sortedOrder_.erase (iter->second);
|
sortedOrder_.erase (iter->second);
|
||||||
iter->second = sortedOrder_.insert (links[j]);
|
iter->second = sortedOrder_.insert (links[j]);
|
||||||
}
|
}
|
||||||
@ -259,54 +256,45 @@ BpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
BpSolver::calcFactorToVarMsg (BpLink* link)
|
||||||
{
|
{
|
||||||
FacNode* src = link->getFactor();
|
FacNode* src = link->facNode();
|
||||||
const VarNode* dst = link->getVariable();
|
const VarNode* dst = link->varNode();
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
const BpLinks& links = ninf(src)->getLinks();
|
||||||
// calculate the product of messages that were sent
|
// calculate the product of messages that were sent
|
||||||
// to factor `src', except from var `dst'
|
// to factor `src', except from var `dst'
|
||||||
unsigned msgSize = 1;
|
unsigned reps = 1;
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
|
||||||
msgSize *= links[i]->getVariable()->range();
|
|
||||||
}
|
|
||||||
unsigned repetitions = 1;
|
|
||||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
if (links[i]->getVariable() != dst) {
|
if (links[i]->varNode() != dst) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
repetitions *= links[i]->getVariable()->range();
|
reps, std::plus<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
unsigned range = links[i]->getVariable()->range();
|
|
||||||
Util::add (msgProduct, Params (range, 0.0), repetitions);
|
|
||||||
repetitions *= range;
|
|
||||||
}
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
if (links[i]->getVariable() != dst) {
|
if (links[i]->varNode() != dst) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
repetitions *= links[i]->getVariable()->range();
|
reps, std::multiplies<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
unsigned range = links[i]->getVariable()->range();
|
|
||||||
Util::multiply (msgProduct, Params (range, 1.0), repetitions);
|
|
||||||
repetitions *= range;
|
|
||||||
}
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Factor result (src->factor().arguments(),
|
Factor result (src->factor().arguments(),
|
||||||
@ -321,21 +309,20 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
|||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " marginalized: " << result.params() << endl;
|
cout << " marginalized: " << result.params() << endl;
|
||||||
}
|
}
|
||||||
link->getNextMessage() = result.params();
|
link->nextMessage() = result.params();
|
||||||
LogAware::normalize (link->getNextMessage());
|
LogAware::normalize (link->nextMessage());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " curr msg: " << link->getMessage() << endl;
|
cout << " curr msg: " << link->message() << endl;
|
||||||
cout << " next msg: " << link->getNextMessage() << endl;
|
cout << " next msg: " << link->nextMessage() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BpSolver::getVar2FactorMsg (const SpLink* link) const
|
BpSolver::getVarToFactorMsg (const BpLink* link) const
|
||||||
{
|
{
|
||||||
const VarNode* src = link->getVariable();
|
const VarNode* src = link->varNode();
|
||||||
const FacNode* dst = link->getFactor();
|
|
||||||
Params msg;
|
Params msg;
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
@ -346,25 +333,24 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << msg;
|
cout << msg;
|
||||||
}
|
}
|
||||||
const SpLinkSet& links = ninf (src)->getLinks();
|
BpLinks::const_iterator it;
|
||||||
|
const BpLinks& links = ninf (src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
SpLinkSet::const_iterator it;
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
for (it = links.begin(); it != links.end(); ++ it) {
|
msg += (*it)->message();
|
||||||
msg += (*it)->getMessage();
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << (*it)->getMessage();
|
cout << " x " << (*it)->message();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg -= link->getMessage();
|
msg -= link->message();
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
if (links[i]->getFactor() != dst) {
|
msg *= (*it)->message();
|
||||||
msg *= links[i]->getMessage();
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
cout << " x " << (*it)->message();
|
||||||
cout << " x " << links[i]->getMessage();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
msg /= link->message();
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " = " << msg;
|
cout << " = " << msg;
|
||||||
@ -379,12 +365,12 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
{
|
{
|
||||||
VarNodes jointVars;
|
VarNodes jointVars;
|
||||||
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||||
assert (fg_->getVarNode (jointVarIds[i]));
|
assert (fg.getVarNode (jointVarIds[i]));
|
||||||
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
|
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
FactorGraph* fg = new FactorGraph (*fg_);
|
FactorGraph* tempFg = new FactorGraph (fg);
|
||||||
BpSolver solver (*fg);
|
BpSolver solver (*tempFg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||||
|
|
||||||
@ -396,7 +382,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
Vars observedVars;
|
Vars observedVars;
|
||||||
Ranges observedRanges;
|
Ranges observedRanges;
|
||||||
for (size_t j = 0; j < observedVids.size(); j++) {
|
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||||
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
|
||||||
observedRanges.push_back (observedVars.back()->range());
|
observedRanges.push_back (observedVars.back()->range());
|
||||||
}
|
}
|
||||||
Indexer indexer (observedRanges, false);
|
Indexer indexer (observedRanges, false);
|
||||||
@ -404,7 +390,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
for (size_t j = 0; j < observedVars.size(); j++) {
|
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||||
observedVars[j]->setEvidence (indexer[j]);
|
observedVars[j]->setEvidence (indexer[j]);
|
||||||
}
|
}
|
||||||
BpSolver solver (*fg);
|
BpSolver solver (*tempFg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
@ -431,22 +417,22 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
void
|
void
|
||||||
BpSolver::initializeSolver (void)
|
BpSolver::initializeSolver (void)
|
||||||
{
|
{
|
||||||
const VarNodes& varNodes = fg_->varNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
varsI_.reserve (varNodes.size());
|
varsI_.reserve (varNodes.size());
|
||||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
varsI_.push_back (new SPNodeInfo());
|
varsI_.push_back (new SPNodeInfo());
|
||||||
}
|
}
|
||||||
const FacNodes& facNodes = fg_->facNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
facsI_.reserve (facNodes.size());
|
facsI_.reserve (facNodes.size());
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
facsI_.push_back (new SPNodeInfo());
|
facsI_.push_back (new SPNodeInfo());
|
||||||
}
|
}
|
||||||
createLinks();
|
createLinks();
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
FacNode* src = links_[i]->getFactor();
|
FacNode* src = links_[i]->facNode();
|
||||||
VarNode* dst = links_[i]->getVariable();
|
VarNode* dst = links_[i]->varNode();
|
||||||
ninf (dst)->addSpLink (links_[i]);
|
ninf (dst)->addBpLink (links_[i]);
|
||||||
ninf (src)->addSpLink (links_[i]);
|
ninf (src)->addBpLink (links_[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -472,7 +458,7 @@ BpSolver::converged (void)
|
|||||||
}
|
}
|
||||||
bool converged = true;
|
bool converged = true;
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||||
if (maxResidual > BpOptions::accuracy) {
|
if (maxResidual > BpOptions::accuracy) {
|
||||||
converged = false;
|
converged = false;
|
||||||
} else {
|
} else {
|
||||||
@ -480,7 +466,7 @@ BpSolver::converged (void)
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
double residual = links_[i]->getResidual();
|
double residual = links_[i]->residual();
|
||||||
if (Globals::verbosity > 1) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
}
|
}
|
||||||
@ -504,13 +490,13 @@ void
|
|||||||
BpSolver::printLinkInformation (void) const
|
BpSolver::printLinkInformation (void) const
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
SpLink* l = links_[i];
|
BpLink* l = links_[i];
|
||||||
cout << l->toString() << ":" << endl;
|
cout << l->toString() << ":" << endl;
|
||||||
cout << " curr msg = " ;
|
cout << " curr msg = " ;
|
||||||
cout << l->getMessage() << endl;
|
cout << l->message() << endl;
|
||||||
cout << " next msg = " ;
|
cout << " next msg = " ;
|
||||||
cout << l->getNextMessage() << endl;
|
cout << l->nextMessage() << endl;
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
cout << " residual = " << l->residual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,10 +13,10 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
class SpLink
|
class BpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
SpLink (FacNode* fn, VarNode* vn)
|
BpLink (FacNode* fn, VarNode* vn)
|
||||||
{
|
{
|
||||||
fac_ = fn;
|
fac_ = fn;
|
||||||
var_ = vn;
|
var_ = vn;
|
||||||
@ -24,23 +24,20 @@ class SpLink
|
|||||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||||
currMsg_ = &v1_;
|
currMsg_ = &v1_;
|
||||||
nextMsg_ = &v2_;
|
nextMsg_ = &v2_;
|
||||||
msgSended_ = false;
|
|
||||||
residual_ = 0.0;
|
residual_ = 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual ~SpLink (void) { };
|
virtual ~BpLink (void) { };
|
||||||
|
|
||||||
FacNode* getFactor (void) const { return fac_; }
|
FacNode* facNode (void) const { return fac_; }
|
||||||
|
|
||||||
VarNode* getVariable (void) const { return var_; }
|
VarNode* varNode (void) const { return var_; }
|
||||||
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
const Params& message (void) const { return *currMsg_; }
|
||||||
|
|
||||||
Params& getNextMessage (void) { return *nextMsg_; }
|
Params& nextMessage (void) { return *nextMsg_; }
|
||||||
|
|
||||||
bool messageWasSended (void) const { return msgSended_; }
|
double residual (void) const { return residual_; }
|
||||||
|
|
||||||
double getResidual (void) const { return residual_; }
|
|
||||||
|
|
||||||
void clearResidual (void) { residual_ = 0.0; }
|
void clearResidual (void) { residual_ = 0.0; }
|
||||||
|
|
||||||
@ -52,7 +49,6 @@ class SpLink
|
|||||||
virtual void updateMessage (void)
|
virtual void updateMessage (void)
|
||||||
{
|
{
|
||||||
swap (currMsg_, nextMsg_);
|
swap (currMsg_, nextMsg_);
|
||||||
msgSended_ = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
string toString (void) const
|
string toString (void) const
|
||||||
@ -71,20 +67,19 @@ class SpLink
|
|||||||
Params v2_;
|
Params v2_;
|
||||||
Params* currMsg_;
|
Params* currMsg_;
|
||||||
Params* nextMsg_;
|
Params* nextMsg_;
|
||||||
bool msgSended_;
|
|
||||||
double residual_;
|
double residual_;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef vector<SpLink*> SpLinkSet;
|
typedef vector<BpLink*> BpLinks;
|
||||||
|
|
||||||
|
|
||||||
class SPNodeInfo
|
class SPNodeInfo
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
void addBpLink (BpLink* link) { links_.push_back (link); }
|
||||||
const SpLinkSet& getLinks (void) { return links_; }
|
const BpLinks& getLinks (void) { return links_; }
|
||||||
private:
|
private:
|
||||||
SpLinkSet links_;
|
BpLinks links_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -110,9 +105,9 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
virtual void maxResidualSchedule (void);
|
virtual void maxResidualSchedule (void);
|
||||||
|
|
||||||
virtual void calculateFactor2VariableMsg (SpLink*);
|
virtual void calcFactorToVarMsg (BpLink*);
|
||||||
|
|
||||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
virtual Params getVarToFactorMsg (const BpLink*) const;
|
||||||
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
@ -126,30 +121,30 @@ class BpSolver : public Solver
|
|||||||
return facsI_[fac->getIndex()];
|
return facsI_[fac->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 2) {
|
if (Globals::verbosity > 2) {
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
cout << "calculating & updating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateFactor2VariableMsg (link);
|
calcFactorToVarMsg (link);
|
||||||
if (calcResidual) {
|
if (calcResidual) {
|
||||||
link->updateResidual();
|
link->updateResidual();
|
||||||
}
|
}
|
||||||
link->updateMessage();
|
link->updateMessage();
|
||||||
}
|
}
|
||||||
|
|
||||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
void calculateMessage (BpLink* link, bool calcResidual = true)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 2) {
|
if (Globals::verbosity > 2) {
|
||||||
cout << "calculating " << link->toString() << endl;
|
cout << "calculating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateFactor2VariableMsg (link);
|
calcFactorToVarMsg (link);
|
||||||
if (calcResidual) {
|
if (calcResidual) {
|
||||||
link->updateResidual();
|
link->updateResidual();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void updateMessage (SpLink* link)
|
void updateMessage (BpLink* link)
|
||||||
{
|
{
|
||||||
link->updateMessage();
|
link->updateMessage();
|
||||||
if (Globals::verbosity > 2) {
|
if (Globals::verbosity > 2) {
|
||||||
@ -159,24 +154,23 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
struct CompareResidual
|
struct CompareResidual
|
||||||
{
|
{
|
||||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
inline bool operator() (const BpLink* link1, const BpLink* link2)
|
||||||
{
|
{
|
||||||
return link1->getResidual() > link2->getResidual();
|
return link1->residual() > link2->residual();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
SpLinkSet links_;
|
BpLinks links_;
|
||||||
unsigned nIters_;
|
unsigned nIters_;
|
||||||
vector<SPNodeInfo*> varsI_;
|
vector<SPNodeInfo*> varsI_;
|
||||||
vector<SPNodeInfo*> facsI_;
|
vector<SPNodeInfo*> facsI_;
|
||||||
bool runned_;
|
bool runned_;
|
||||||
const FactorGraph* fg_;
|
|
||||||
|
|
||||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
typedef multiset<BpLink*, CompareResidual> SortedOrder;
|
||||||
SortedOrder sortedOrder_;
|
SortedOrder sortedOrder_;
|
||||||
|
|
||||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||||
SpLinkMap linkMap_;
|
BpLinkMap linkMap_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void initializeSolver (void);
|
void initializeSolver (void);
|
||||||
|
@ -1,318 +0,0 @@
|
|||||||
#include "CFactorGraph.h"
|
|
||||||
#include "Factor.h"
|
|
||||||
|
|
||||||
|
|
||||||
bool CFactorGraph::checkForIdenticalFactors = true;
|
|
||||||
|
|
||||||
CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
|
||||||
: freeColor_(0), groundFg_(&fg)
|
|
||||||
{
|
|
||||||
findIdenticalFactors();
|
|
||||||
setInitialColors();
|
|
||||||
createGroups();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CFactorGraph::~CFactorGraph (void)
|
|
||||||
{
|
|
||||||
for (size_t i = 0; i < varClusters_.size(); i++) {
|
|
||||||
delete varClusters_[i];
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
|
||||||
delete facClusters_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CFactorGraph::findIdenticalFactors()
|
|
||||||
{
|
|
||||||
const FacNodes& facNodes = groundFg_->facNodes();
|
|
||||||
if (checkForIdenticalFactors == false ||
|
|
||||||
facNodes.size() == 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
|
||||||
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
|
||||||
}
|
|
||||||
unsigned groupCount = 1;
|
|
||||||
for (size_t i = 0; i < facNodes.size() - 1; i++) {
|
|
||||||
Factor& f1 = facNodes[i]->factor();
|
|
||||||
if (f1.distId() != Util::maxUnsigned()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
f1.setDistId (groupCount);
|
|
||||||
for (size_t j = i + 1; j < facNodes.size(); j++) {
|
|
||||||
Factor& f2 = facNodes[j]->factor();
|
|
||||||
if (f2.distId() != Util::maxUnsigned()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (f1.size() == f2.size() &&
|
|
||||||
f1.ranges() == f2.ranges() &&
|
|
||||||
f1.params() == f2.params()) {
|
|
||||||
f2.setDistId (groupCount);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
groupCount ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CFactorGraph::setInitialColors (void)
|
|
||||||
{
|
|
||||||
varColors_.resize (groundFg_->nrVarNodes());
|
|
||||||
facColors_.resize (groundFg_->nrFacNodes());
|
|
||||||
|
|
||||||
// create the initial variable colors
|
|
||||||
VarColorMap colorMap;
|
|
||||||
const VarNodes& varNodes = groundFg_->varNodes();
|
|
||||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
|
||||||
unsigned range = varNodes[i]->range();
|
|
||||||
VarColorMap::iterator it = colorMap.find (range);
|
|
||||||
if (it == colorMap.end()) {
|
|
||||||
it = colorMap.insert (make_pair (
|
|
||||||
range, Colors (range + 1, -1))).first;
|
|
||||||
}
|
|
||||||
unsigned idx = varNodes[i]->hasEvidence()
|
|
||||||
? varNodes[i]->getEvidence()
|
|
||||||
: range;
|
|
||||||
Colors& stateColors = it->second;
|
|
||||||
if (stateColors[idx] == -1) {
|
|
||||||
stateColors[idx] = getFreeColor();
|
|
||||||
}
|
|
||||||
setColor (varNodes[i], stateColors[idx]);
|
|
||||||
}
|
|
||||||
const FacNodes& facNodes = groundFg_->facNodes();
|
|
||||||
// create the initial factor colors
|
|
||||||
DistColorMap distColors;
|
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
|
||||||
unsigned distId = facNodes[i]->factor().distId();
|
|
||||||
DistColorMap::iterator it = distColors.find (distId);
|
|
||||||
if (it == distColors.end()) {
|
|
||||||
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
|
||||||
}
|
|
||||||
setColor (facNodes[i], it->second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CFactorGraph::createGroups (void)
|
|
||||||
{
|
|
||||||
VarSignMap varGroups;
|
|
||||||
FacSignMap facGroups;
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool groupsHaveChanged = true;
|
|
||||||
const VarNodes& varNodes = groundFg_->varNodes();
|
|
||||||
const FacNodes& facNodes = groundFg_->facNodes();
|
|
||||||
|
|
||||||
while (groupsHaveChanged || nIters == 1) {
|
|
||||||
nIters ++;
|
|
||||||
|
|
||||||
// set a new color to the variables with the same signature
|
|
||||||
size_t prevVarGroupsSize = varGroups.size();
|
|
||||||
varGroups.clear();
|
|
||||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
|
||||||
const VarSignature& signature = getSignature (varNodes[i]);
|
|
||||||
VarSignMap::iterator it = varGroups.find (signature);
|
|
||||||
if (it == varGroups.end()) {
|
|
||||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
|
||||||
}
|
|
||||||
it->second.push_back (varNodes[i]);
|
|
||||||
}
|
|
||||||
for (VarSignMap::iterator it = varGroups.begin();
|
|
||||||
it != varGroups.end(); ++it) {
|
|
||||||
Color newColor = getFreeColor();
|
|
||||||
VarNodes& groupMembers = it->second;
|
|
||||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
|
||||||
setColor (groupMembers[i], newColor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t prevFactorGroupsSize = facGroups.size();
|
|
||||||
facGroups.clear();
|
|
||||||
// set a new color to the factors with the same signature
|
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
|
||||||
const FacSignature& signature = getSignature (facNodes[i]);
|
|
||||||
FacSignMap::iterator it = facGroups.find (signature);
|
|
||||||
if (it == facGroups.end()) {
|
|
||||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
|
||||||
}
|
|
||||||
it->second.push_back (facNodes[i]);
|
|
||||||
}
|
|
||||||
for (FacSignMap::iterator it = facGroups.begin();
|
|
||||||
it != facGroups.end(); ++it) {
|
|
||||||
Color newColor = getFreeColor();
|
|
||||||
FacNodes& groupMembers = it->second;
|
|
||||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
|
||||||
setColor (groupMembers[i], newColor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
|
||||||
|| prevFactorGroupsSize != facGroups.size();
|
|
||||||
}
|
|
||||||
// printGroups (varGroups, facGroups);
|
|
||||||
createClusters (varGroups, facGroups);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CFactorGraph::createClusters (
|
|
||||||
const VarSignMap& varGroups,
|
|
||||||
const FacSignMap& facGroups)
|
|
||||||
{
|
|
||||||
varClusters_.reserve (varGroups.size());
|
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
|
||||||
it != varGroups.end(); ++it) {
|
|
||||||
const VarNodes& groupVars = it->second;
|
|
||||||
VarCluster* vc = new VarCluster (groupVars);
|
|
||||||
for (size_t i = 0; i < groupVars.size(); i++) {
|
|
||||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
|
||||||
}
|
|
||||||
varClusters_.push_back (vc);
|
|
||||||
}
|
|
||||||
|
|
||||||
facClusters_.reserve (facGroups.size());
|
|
||||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
|
||||||
it != facGroups.end(); ++it) {
|
|
||||||
FacNode* groupFactor = it->second[0];
|
|
||||||
const VarNodes& neighs = groupFactor->neighbors();
|
|
||||||
VarClusters varClusters;
|
|
||||||
varClusters.reserve (neighs.size());
|
|
||||||
for (size_t i = 0; i < neighs.size(); i++) {
|
|
||||||
VarId vid = neighs[i]->varId();
|
|
||||||
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
|
||||||
}
|
|
||||||
facClusters_.push_back (new FacCluster (it->second, varClusters));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarSignature
|
|
||||||
CFactorGraph::getSignature (const VarNode* varNode)
|
|
||||||
{
|
|
||||||
const FacNodes& neighs = varNode->neighbors();
|
|
||||||
VarSignature sign;
|
|
||||||
sign.reserve (neighs.size() + 1);
|
|
||||||
for (size_t i = 0; i < neighs.size(); i++) {
|
|
||||||
sign.push_back (make_pair (
|
|
||||||
getColor (neighs[i]),
|
|
||||||
neighs[i]->factor().indexOf (varNode->varId())));
|
|
||||||
}
|
|
||||||
std::sort (sign.begin(), sign.end());
|
|
||||||
sign.push_back (make_pair (getColor (varNode), 0));
|
|
||||||
return sign;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FacSignature
|
|
||||||
CFactorGraph::getSignature (const FacNode* facNode)
|
|
||||||
{
|
|
||||||
const VarNodes& neighs = facNode->neighbors();
|
|
||||||
FacSignature sign;
|
|
||||||
sign.reserve (neighs.size() + 1);
|
|
||||||
for (size_t i = 0; i < neighs.size(); i++) {
|
|
||||||
sign.push_back (getColor (neighs[i]));
|
|
||||||
}
|
|
||||||
sign.push_back (getColor (facNode));
|
|
||||||
return sign;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
|
||||||
CFactorGraph::getGroundFactorGraph (void)
|
|
||||||
{
|
|
||||||
FactorGraph* fg = new FactorGraph();
|
|
||||||
for (size_t i = 0; i < varClusters_.size(); i++) {
|
|
||||||
VarNode* newVar = new VarNode (varClusters_[i]->first());
|
|
||||||
varClusters_[i]->setRepresentative (newVar);
|
|
||||||
fg->addVarNode (newVar);
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
|
||||||
Vars vars;
|
|
||||||
const VarClusters& clusters = facClusters_[i]->varClusters();
|
|
||||||
for (size_t j = 0; j < clusters.size(); j++) {
|
|
||||||
vars.push_back (clusters[j]->representative());
|
|
||||||
}
|
|
||||||
const Factor& groundFac = facClusters_[i]->first()->factor();
|
|
||||||
FacNode* fn = new FacNode (Factor (
|
|
||||||
vars, groundFac.params(), groundFac.distId()));
|
|
||||||
facClusters_[i]->setRepresentative (fn);
|
|
||||||
fg->addFacNode (fn);
|
|
||||||
for (size_t j = 0; j < vars.size(); j++) {
|
|
||||||
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
CFactorGraph::getEdgeCount (
|
|
||||||
const FacCluster* fc,
|
|
||||||
const VarCluster* vc,
|
|
||||||
size_t index) const
|
|
||||||
{
|
|
||||||
unsigned count = 0;
|
|
||||||
VarId reprVid = vc->representative()->varId();
|
|
||||||
VarNode* groundVar = groundFg_->getVarNode (reprVid);
|
|
||||||
const FacNodes& neighs = groundVar->neighbors();
|
|
||||||
for (size_t i = 0; i < neighs.size(); i++) {
|
|
||||||
FacNodes::const_iterator it;
|
|
||||||
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
|
|
||||||
if (it != fc->members().end() &&
|
|
||||||
(*it)->factor().indexOf (reprVid) == index) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CFactorGraph::printGroups (
|
|
||||||
const VarSignMap& varGroups,
|
|
||||||
const FacSignMap& facGroups) const
|
|
||||||
{
|
|
||||||
unsigned count = 1;
|
|
||||||
cout << "variable groups:" << endl;
|
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
|
||||||
it != varGroups.end(); ++it) {
|
|
||||||
const VarNodes& groupMembers = it->second;
|
|
||||||
if (groupMembers.size() > 0) {
|
|
||||||
cout << count << ": " ;
|
|
||||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
|
||||||
cout << groupMembers[i]->label() << " " ;
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count = 1;
|
|
||||||
cout << endl << "factor groups:" << endl;
|
|
||||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
|
||||||
it != facGroups.end(); ++it) {
|
|
||||||
const FacNodes& groupMembers = it->second;
|
|
||||||
if (groupMembers.size() > 0) {
|
|
||||||
cout << ++count << ": " ;
|
|
||||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
|
||||||
cout << groupMembers[i]->getLabel() << " " ;
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,176 +0,0 @@
|
|||||||
#ifndef HORUS_CFACTORGRAPH_H
|
|
||||||
#define HORUS_CFACTORGRAPH_H
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
|
||||||
#include "Factor.h"
|
|
||||||
#include "Util.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
class VarCluster;
|
|
||||||
class FacCluster;
|
|
||||||
class VarSignatureHash;
|
|
||||||
class FacSignatureHash;
|
|
||||||
|
|
||||||
typedef long Color;
|
|
||||||
typedef vector<Color> Colors;
|
|
||||||
typedef vector<std::pair<Color,unsigned>> VarSignature;
|
|
||||||
typedef vector<Color> FacSignature;
|
|
||||||
|
|
||||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
|
||||||
typedef unordered_map<unsigned, Colors> VarColorMap;
|
|
||||||
|
|
||||||
typedef unordered_map<VarSignature, VarNodes, VarSignatureHash> VarSignMap;
|
|
||||||
typedef unordered_map<FacSignature, FacNodes, FacSignatureHash> FacSignMap;
|
|
||||||
|
|
||||||
typedef vector<VarCluster*> VarClusters;
|
|
||||||
typedef vector<FacCluster*> FacClusters;
|
|
||||||
|
|
||||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
|
||||||
|
|
||||||
|
|
||||||
struct VarSignatureHash
|
|
||||||
{
|
|
||||||
size_t operator() (const VarSignature &sig) const
|
|
||||||
{
|
|
||||||
size_t val = hash<size_t>()(sig.size());
|
|
||||||
for (size_t i = 0; i < sig.size(); i++) {
|
|
||||||
val ^= hash<size_t>()(sig[i].first);
|
|
||||||
val ^= hash<size_t>()(sig[i].second);
|
|
||||||
}
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
struct FacSignatureHash
|
|
||||||
{
|
|
||||||
size_t operator() (const FacSignature &sig) const
|
|
||||||
{
|
|
||||||
size_t val = hash<size_t>()(sig.size());
|
|
||||||
for (size_t i = 0; i < sig.size(); i++) {
|
|
||||||
val ^= hash<size_t>()(sig[i]);
|
|
||||||
}
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VarCluster
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
VarCluster (const VarNodes& vs) : members_(vs) { }
|
|
||||||
|
|
||||||
const VarNode* first (void) const { return members_.front(); }
|
|
||||||
|
|
||||||
const VarNodes& members (void) const { return members_; }
|
|
||||||
|
|
||||||
VarNode* representative (void) const { return repr_; }
|
|
||||||
|
|
||||||
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
VarNodes members_;
|
|
||||||
VarNode* repr_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class FacCluster
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
|
||||||
: members_(fcs), varClusters_(vcs) { }
|
|
||||||
|
|
||||||
const FacNode* first (void) const { return members_.front(); }
|
|
||||||
|
|
||||||
const FacNodes& members (void) const { return members_; }
|
|
||||||
|
|
||||||
VarClusters& varClusters (void) { return varClusters_; }
|
|
||||||
|
|
||||||
FacNode* representative (void) const { return repr_; }
|
|
||||||
|
|
||||||
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
FacNodes members_;
|
|
||||||
VarClusters varClusters_;
|
|
||||||
FacNode* repr_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class CFactorGraph
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
CFactorGraph (const FactorGraph&);
|
|
||||||
|
|
||||||
~CFactorGraph (void);
|
|
||||||
|
|
||||||
const VarClusters& varClusters (void) { return varClusters_; }
|
|
||||||
|
|
||||||
const FacClusters& facClusters (void) { return facClusters_; }
|
|
||||||
|
|
||||||
VarNode* getEquivalent (VarId vid)
|
|
||||||
{
|
|
||||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
|
||||||
return vc->representative();
|
|
||||||
}
|
|
||||||
|
|
||||||
FactorGraph* getGroundFactorGraph (void);
|
|
||||||
|
|
||||||
unsigned getEdgeCount (const FacCluster*,
|
|
||||||
const VarCluster*, size_t index) const;
|
|
||||||
|
|
||||||
static bool checkForIdenticalFactors;
|
|
||||||
|
|
||||||
private:
|
|
||||||
Color getFreeColor (void)
|
|
||||||
{
|
|
||||||
++ freeColor_;
|
|
||||||
return freeColor_ - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
Color getColor (const VarNode* vn) const
|
|
||||||
{
|
|
||||||
return varColors_[vn->getIndex()];
|
|
||||||
}
|
|
||||||
Color getColor (const FacNode* fn) const {
|
|
||||||
return facColors_[fn->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
void setColor (const VarNode* vn, Color c)
|
|
||||||
{
|
|
||||||
varColors_[vn->getIndex()] = c;
|
|
||||||
}
|
|
||||||
|
|
||||||
void setColor (const FacNode* fn, Color c)
|
|
||||||
{
|
|
||||||
facColors_[fn->getIndex()] = c;
|
|
||||||
}
|
|
||||||
|
|
||||||
void findIdenticalFactors (void);
|
|
||||||
|
|
||||||
void setInitialColors (void);
|
|
||||||
|
|
||||||
void createGroups (void);
|
|
||||||
|
|
||||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
|
||||||
|
|
||||||
VarSignature getSignature (const VarNode*);
|
|
||||||
|
|
||||||
FacSignature getSignature (const FacNode*);
|
|
||||||
|
|
||||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
|
||||||
|
|
||||||
Color freeColor_;
|
|
||||||
Colors varColors_;
|
|
||||||
Colors facColors_;
|
|
||||||
VarClusters varClusters_;
|
|
||||||
FacClusters facClusters_;
|
|
||||||
VarId2VarCluster vid2VarCluster_;
|
|
||||||
const FactorGraph* groundFg_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_CFACTORGRAPH_H
|
|
||||||
|
|
@ -1,22 +1,32 @@
|
|||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
#include "WeightedBpSolver.h"
|
||||||
|
|
||||||
|
|
||||||
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
bool CbpSolver::checkForIdenticalFactors = true;
|
||||||
|
|
||||||
|
|
||||||
|
CbpSolver::CbpSolver (const FactorGraph& fg)
|
||||||
|
: Solver (fg), freeColor_(0)
|
||||||
{
|
{
|
||||||
cfg_ = new CFactorGraph (fg);
|
findIdenticalFactors();
|
||||||
fg_ = cfg_->getGroundFactorGraph();
|
setInitialColors();
|
||||||
|
createGroups();
|
||||||
|
compressedFg_ = getCompressedFactorGraph();
|
||||||
|
solver_ = new WeightedBpSolver (*compressedFg_, getWeights());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CbpSolver::~CbpSolver (void)
|
CbpSolver::~CbpSolver (void)
|
||||||
{
|
{
|
||||||
delete cfg_;
|
delete solver_;
|
||||||
delete fg_;
|
delete compressedFg_;
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||||
delete links_[i];
|
delete varClusters_[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
delete facClusters_[i];
|
||||||
}
|
}
|
||||||
links_.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -38,7 +48,7 @@ CbpSolver::printSolverFlags (void) const
|
|||||||
ss << ",accuracy=" << BpOptions::accuracy;
|
ss << ",accuracy=" << BpOptions::accuracy;
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << ",chkif=" <<
|
ss << ",chkif=" <<
|
||||||
Util::toString (CFactorGraph::checkForIdenticalFactors);
|
Util::toString (CbpSolver::checkForIdenticalFactors);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
cout << ss.str() << endl;
|
cout << ss.str() << endl;
|
||||||
}
|
}
|
||||||
@ -46,305 +56,345 @@ CbpSolver::printSolverFlags (void) const
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
CbpSolver::getPosterioriOf (VarId vid)
|
CbpSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
if (runned_ == false) {
|
assert (queryVids.empty() == false);
|
||||||
runSolver();
|
Params res;
|
||||||
}
|
if (queryVids.size() == 1) {
|
||||||
assert (cfg_->getEquivalent (vid));
|
res = solver_->getPosterioriOf (getRepresentative (queryVids[0]));
|
||||||
VarNode* var = cfg_->getEquivalent (vid);
|
|
||||||
Params probs;
|
|
||||||
if (var->hasEvidence()) {
|
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
|
||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->range(), LogAware::multIdenty());
|
VarNode* vn = fg.getVarNode (queryVids[0]);
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
if (Globals::logDomain) {
|
size_t idx = facNodes.size();
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
if (facNodes[i]->factor().contains (queryVids)) {
|
||||||
probs += l->poweredMessage();
|
idx = i;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
cout << endl;
|
||||||
Util::exp (probs);
|
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
|
||||||
probs *= l->poweredMessage();
|
|
||||||
}
|
|
||||||
LogAware::normalize (probs);
|
|
||||||
}
|
}
|
||||||
|
if (idx == facNodes.size()) {
|
||||||
|
cerr << "error: only joint distributions on variables of some " ;
|
||||||
|
cerr << "clique are supported with the current solver" ;
|
||||||
|
cerr << endl;
|
||||||
|
exit (1);
|
||||||
|
}
|
||||||
|
VarIds representatives;
|
||||||
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
|
representatives.push_back (getRepresentative (queryVids[i]));
|
||||||
|
}
|
||||||
|
res = solver_->getJointDistributionOf (representatives);
|
||||||
}
|
}
|
||||||
return probs;
|
return res;
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
|
||||||
{
|
|
||||||
VarIds eqVarIds;
|
|
||||||
for (size_t i = 0; i < jointVids.size(); i++) {
|
|
||||||
VarNode* vn = cfg_->getEquivalent (jointVids[i]);
|
|
||||||
eqVarIds.push_back (vn->varId());
|
|
||||||
}
|
|
||||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::createLinks (void)
|
CbpSolver::findIdenticalFactors()
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 0) {
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
cout << "compressed factor graph contains " ;
|
if (checkForIdenticalFactors == false ||
|
||||||
cout << fg_->nrVarNodes() << " variables and " ;
|
facNodes.size() == 1) {
|
||||||
cout << fg_->nrFacNodes() << " factors " << endl;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
const FacClusters& fcs = cfg_->facClusters();
|
|
||||||
for (size_t i = 0; i < fcs.size(); i++) {
|
|
||||||
const VarClusters& vcs = fcs[i]->varClusters();
|
|
||||||
for (size_t j = 0; j < vcs.size(); j++) {
|
|
||||||
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
|
|
||||||
if (Globals::verbosity > 1) {
|
|
||||||
cout << "creating link " ;
|
|
||||||
cout << fcs[i]->representative()->getLabel();
|
|
||||||
cout << " -- " ;
|
|
||||||
cout << vcs[j]->representative()->label();
|
|
||||||
cout << " idx=" << j << ", count=" << count << endl;
|
|
||||||
}
|
|
||||||
links_.push_back (new CbpSolverLink (
|
|
||||||
fcs[i]->representative(), vcs[j]->representative(), j, count));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (Globals::verbosity > 1) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CbpSolver::maxResidualSchedule (void)
|
|
||||||
{
|
|
||||||
if (nIters_ == 1) {
|
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
|
||||||
if (Globals::verbosity >= 1) {
|
|
||||||
cout << "calculating " << links_[i]->toString() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
for (size_t c = 0; c < links_.size(); c++) {
|
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
||||||
if (Globals::verbosity > 1) {
|
}
|
||||||
cout << endl << "current residuals:" << endl;
|
unsigned groupCount = 1;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (size_t i = 0; i < facNodes.size() - 1; i++) {
|
||||||
it != sortedOrder_.end(); ++it) {
|
Factor& f1 = facNodes[i]->factor();
|
||||||
cout << " " << setw (30) << left << (*it)->toString();
|
if (f1.distId() != Util::maxUnsigned()) {
|
||||||
cout << "residual = " << (*it)->getResidual() << endl;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
f1.setDistId (groupCount);
|
||||||
|
for (size_t j = i + 1; j < facNodes.size(); j++) {
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
Factor& f2 = facNodes[j]->factor();
|
||||||
SpLink* link = *it;
|
if (f2.distId() != Util::maxUnsigned()) {
|
||||||
if (Globals::verbosity >= 1) {
|
continue;
|
||||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
}
|
||||||
}
|
if (f1.size() == f2.size() &&
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
f1.ranges() == f2.ranges() &&
|
||||||
return;
|
f1.params() == f2.params()) {
|
||||||
}
|
f2.setDistId (groupCount);
|
||||||
link->updateMessage();
|
|
||||||
link->clearResidual();
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
|
||||||
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
|
||||||
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
|
||||||
for (size_t j = 0; j < links.size(); j++) {
|
|
||||||
if (links[j]->getVariable() != link->getVariable()) {
|
|
||||||
if (Globals::verbosity > 1) {
|
|
||||||
cout << " calculating " << links[j]->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateMessage (links[j]);
|
|
||||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (links[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// in counting bp, the message that a variable X sends to
|
|
||||||
// to a factor F depends on the message that F sent to the X
|
|
||||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
|
||||||
if (links[i]->getVariable() != link->getVariable()) {
|
|
||||||
if (Globals::verbosity > 1) {
|
|
||||||
cout << " calculating " << links[i]->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateMessage (links[i]);
|
|
||||||
SpLinkMap::iterator iter = linkMap_.find (links[i]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (links[i]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
groupCount ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
|
CbpSolver::setInitialColors (void)
|
||||||
{
|
{
|
||||||
CbpSolverLink* link = static_cast<CbpSolverLink*> (_link);
|
varColors_.resize (fg.nrVarNodes());
|
||||||
FacNode* src = link->getFactor();
|
facColors_.resize (fg.nrFacNodes());
|
||||||
const VarNode* dst = link->getVariable();
|
// create the initial variable colors
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
VarColorMap colorMap;
|
||||||
// calculate the product of messages that were sent
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
// to factor `src', except from var `dst'
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
unsigned msgSize = 1;
|
unsigned range = varNodes[i]->range();
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
VarColorMap::iterator it = colorMap.find (range);
|
||||||
msgSize *= links[i]->getVariable()->range();
|
if (it == colorMap.end()) {
|
||||||
}
|
it = colorMap.insert (make_pair (
|
||||||
unsigned repetitions = 1;
|
range, Colors (range + 1, -1))).first;
|
||||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
|
||||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
|
||||||
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
|
||||||
cout << ": " ;
|
|
||||||
}
|
|
||||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
|
||||||
repetitions *= links[i]->getVariable()->range();
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
unsigned range = links[i]->getVariable()->range();
|
|
||||||
Util::add (msgProduct, Params (range, 0.0), repetitions);
|
|
||||||
repetitions *= range;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
unsigned idx = varNodes[i]->hasEvidence()
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
? varNodes[i]->getEvidence()
|
||||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
: range;
|
||||||
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
Colors& stateColors = it->second;
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (stateColors[idx] == -1) {
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
stateColors[idx] = getNewColor();
|
||||||
cout << ": " ;
|
|
||||||
}
|
|
||||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
|
||||||
repetitions *= links[i]->getVariable()->range();
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
unsigned range = links[i]->getVariable()->range();
|
|
||||||
Util::multiply (msgProduct, Params (range, 1.0), repetitions);
|
|
||||||
repetitions *= range;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
setColor (varNodes[i], stateColors[idx]);
|
||||||
}
|
}
|
||||||
Factor result (src->factor().arguments(),
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
src->factor().ranges(), msgProduct);
|
// create the initial factor colors
|
||||||
assert (msgProduct.size() == src->factor().size());
|
DistColorMap distColors;
|
||||||
if (Globals::logDomain) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
for (size_t i = 0; i < result.size(); i++) {
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
result[i] += src->factor()[i];
|
DistColorMap::iterator it = distColors.find (distId);
|
||||||
}
|
if (it == distColors.end()) {
|
||||||
} else {
|
it = distColors.insert (make_pair (distId, getNewColor())).first;
|
||||||
for (size_t i = 0; i < result.size(); i++) {
|
|
||||||
result[i] *= src->factor()[i];
|
|
||||||
}
|
}
|
||||||
|
setColor (facNodes[i], it->second);
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
cout << " message product: " << msgProduct << endl;
|
|
||||||
cout << " original factor: " << src->factor().params() << endl;
|
|
||||||
cout << " factor product: " << result.params() << endl;
|
|
||||||
}
|
|
||||||
result.sumOutAllExceptIndex (link->index());
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
cout << " marginalized: " << result.params() << endl;
|
|
||||||
}
|
|
||||||
link->getNextMessage() = result.params();
|
|
||||||
LogAware::normalize (link->getNextMessage());
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
cout << " curr msg: " << link->getMessage() << endl;
|
|
||||||
cout << " next msg: " << link->getNextMessage() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
CbpSolver::getVar2FactorMsg (const SpLink* _link) const
|
|
||||||
{
|
|
||||||
const CbpSolverLink* link = static_cast<const CbpSolverLink*> (_link);
|
|
||||||
const VarNode* src = link->getVariable();
|
|
||||||
const FacNode* dst = link->getFactor();
|
|
||||||
Params msg;
|
|
||||||
if (src->hasEvidence()) {
|
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
|
||||||
double value = link->getMessage()[src->getEvidence()];
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
msg[src->getEvidence()] = value;
|
|
||||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
|
||||||
}
|
|
||||||
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
|
|
||||||
} else {
|
|
||||||
msg = link->getMessage();
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
|
||||||
}
|
|
||||||
LogAware::pow (msg, link->nrEdges() - 1);
|
|
||||||
}
|
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
|
||||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
|
||||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
|
||||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
|
||||||
msg += cl->poweredMessage();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
|
||||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
|
||||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
|
||||||
msg *= cl->poweredMessage();
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
|
||||||
cout << " = " << msg;
|
|
||||||
}
|
|
||||||
return msg;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CbpSolver::printLinkInformation (void) const
|
CbpSolver::createGroups (void)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
VarSignMap varGroups;
|
||||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links_[i]);
|
FacSignMap facGroups;
|
||||||
cout << cl->toString() << ":" << endl;
|
unsigned nIters = 0;
|
||||||
cout << " curr msg = " << cl->getMessage() << endl;
|
bool groupsHaveChanged = true;
|
||||||
cout << " next msg = " << cl->getNextMessage() << endl;
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
cout << " index = " << cl->index() << endl;
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
cout << " nr edges = " << cl->nrEdges() << endl;
|
|
||||||
cout << " powered = " << cl->poweredMessage() << endl;
|
while (groupsHaveChanged || nIters == 1) {
|
||||||
cout << " residual = " << cl->getResidual() << endl;
|
nIters ++;
|
||||||
|
|
||||||
|
// set a new color to the variables with the same signature
|
||||||
|
size_t prevVarGroupsSize = varGroups.size();
|
||||||
|
varGroups.clear();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
const VarSignature& signature = getSignature (varNodes[i]);
|
||||||
|
VarSignMap::iterator it = varGroups.find (signature);
|
||||||
|
if (it == varGroups.end()) {
|
||||||
|
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (varNodes[i]);
|
||||||
|
}
|
||||||
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); ++it) {
|
||||||
|
Color newColor = getNewColor();
|
||||||
|
VarNodes& groupMembers = it->second;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t prevFactorGroupsSize = facGroups.size();
|
||||||
|
facGroups.clear();
|
||||||
|
// set a new color to the factors with the same signature
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
const FacSignature& signature = getSignature (facNodes[i]);
|
||||||
|
FacSignMap::iterator it = facGroups.find (signature);
|
||||||
|
if (it == facGroups.end()) {
|
||||||
|
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (facNodes[i]);
|
||||||
|
}
|
||||||
|
for (FacSignMap::iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); ++it) {
|
||||||
|
Color newColor = getNewColor();
|
||||||
|
FacNodes& groupMembers = it->second;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||||
|
|| prevFactorGroupsSize != facGroups.size();
|
||||||
|
}
|
||||||
|
// printGroups (varGroups, facGroups);
|
||||||
|
createClusters (varGroups, facGroups);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CbpSolver::createClusters (
|
||||||
|
const VarSignMap& varGroups,
|
||||||
|
const FacSignMap& facGroups)
|
||||||
|
{
|
||||||
|
varClusters_.reserve (varGroups.size());
|
||||||
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); ++it) {
|
||||||
|
const VarNodes& groupVars = it->second;
|
||||||
|
VarCluster* vc = new VarCluster (groupVars);
|
||||||
|
for (size_t i = 0; i < groupVars.size(); i++) {
|
||||||
|
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||||
|
}
|
||||||
|
varClusters_.push_back (vc);
|
||||||
|
}
|
||||||
|
|
||||||
|
facClusters_.reserve (facGroups.size());
|
||||||
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); ++it) {
|
||||||
|
FacNode* groupFactor = it->second[0];
|
||||||
|
const VarNodes& neighs = groupFactor->neighbors();
|
||||||
|
VarClusters varClusters;
|
||||||
|
varClusters.reserve (neighs.size());
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
VarId vid = neighs[i]->varId();
|
||||||
|
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
||||||
|
}
|
||||||
|
facClusters_.push_back (new FacCluster (it->second, varClusters));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarSignature
|
||||||
|
CbpSolver::getSignature (const VarNode* varNode)
|
||||||
|
{
|
||||||
|
const FacNodes& neighs = varNode->neighbors();
|
||||||
|
VarSignature sign;
|
||||||
|
sign.reserve (neighs.size() + 1);
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
sign.push_back (make_pair (
|
||||||
|
getColor (neighs[i]),
|
||||||
|
neighs[i]->factor().indexOf (varNode->varId())));
|
||||||
|
}
|
||||||
|
std::sort (sign.begin(), sign.end());
|
||||||
|
sign.push_back (make_pair (getColor (varNode), 0));
|
||||||
|
return sign;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FacSignature
|
||||||
|
CbpSolver::getSignature (const FacNode* facNode)
|
||||||
|
{
|
||||||
|
const VarNodes& neighs = facNode->neighbors();
|
||||||
|
FacSignature sign;
|
||||||
|
sign.reserve (neighs.size() + 1);
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
sign.push_back (getColor (neighs[i]));
|
||||||
|
}
|
||||||
|
sign.push_back (getColor (facNode));
|
||||||
|
return sign;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
CbpSolver::getCompressedFactorGraph (void)
|
||||||
|
{
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||||
|
VarNode* newVar = new VarNode (varClusters_[i]->first());
|
||||||
|
varClusters_[i]->setRepresentative (newVar);
|
||||||
|
fg->addVarNode (newVar);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
Vars vars;
|
||||||
|
const VarClusters& clusters = facClusters_[i]->varClusters();
|
||||||
|
for (size_t j = 0; j < clusters.size(); j++) {
|
||||||
|
vars.push_back (clusters[j]->representative());
|
||||||
|
}
|
||||||
|
const Factor& groundFac = facClusters_[i]->first()->factor();
|
||||||
|
FacNode* fn = new FacNode (Factor (
|
||||||
|
vars, groundFac.params(), groundFac.distId()));
|
||||||
|
facClusters_[i]->setRepresentative (fn);
|
||||||
|
fg->addFacNode (fn);
|
||||||
|
for (size_t j = 0; j < vars.size(); j++) {
|
||||||
|
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<vector<unsigned>>
|
||||||
|
CbpSolver::getWeights (void) const
|
||||||
|
{
|
||||||
|
vector<vector<unsigned>> weights;
|
||||||
|
weights.reserve (facClusters_.size());
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
const VarClusters& neighs = facClusters_[i]->varClusters();
|
||||||
|
weights.push_back ({ });
|
||||||
|
weights.back().reserve (neighs.size());
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
weights.back().push_back (getWeight (
|
||||||
|
facClusters_[i], neighs[j], j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
CbpSolver::getWeight (
|
||||||
|
const FacCluster* fc,
|
||||||
|
const VarCluster* vc,
|
||||||
|
size_t index) const
|
||||||
|
{
|
||||||
|
unsigned weight = 0;
|
||||||
|
VarId reprVid = vc->representative()->varId();
|
||||||
|
VarNode* groundVar = fg.getVarNode (reprVid);
|
||||||
|
const FacNodes& neighs = groundVar->neighbors();
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
FacNodes::const_iterator it;
|
||||||
|
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
|
||||||
|
if (it != fc->members().end() &&
|
||||||
|
(*it)->factor().indexOf (reprVid) == index) {
|
||||||
|
weight ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CbpSolver::printGroups (
|
||||||
|
const VarSignMap& varGroups,
|
||||||
|
const FacSignMap& facGroups) const
|
||||||
|
{
|
||||||
|
unsigned count = 1;
|
||||||
|
cout << "variable groups:" << endl;
|
||||||
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); ++it) {
|
||||||
|
const VarNodes& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << count << ": " ;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->label() << " " ;
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
count = 1;
|
||||||
|
cout << endl << "factor groups:" << endl;
|
||||||
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); ++it) {
|
||||||
|
const FacNodes& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << ++count << ": " ;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,67 +1,183 @@
|
|||||||
#ifndef HORUS_CBP_H
|
#ifndef HORUS_CBPSOLVER_H
|
||||||
#define HORUS_CBP_H
|
#define HORUS_CBPSOLVER_H
|
||||||
|
|
||||||
#include "BpSolver.h"
|
#include <unordered_map>
|
||||||
#include "CFactorGraph.h"
|
|
||||||
|
|
||||||
class Factor;
|
#include "Solver.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Util.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
class CbpSolverLink : public SpLink
|
class VarCluster;
|
||||||
|
class FacCluster;
|
||||||
|
class VarSignHash;
|
||||||
|
class FacSignHash;
|
||||||
|
class WeightedBpSolver;
|
||||||
|
|
||||||
|
typedef long Color;
|
||||||
|
typedef vector<Color> Colors;
|
||||||
|
typedef vector<std::pair<Color,unsigned>> VarSignature;
|
||||||
|
typedef vector<Color> FacSignature;
|
||||||
|
|
||||||
|
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||||
|
typedef unordered_map<unsigned, Colors> VarColorMap;
|
||||||
|
|
||||||
|
typedef unordered_map<VarSignature, VarNodes, VarSignHash> VarSignMap;
|
||||||
|
typedef unordered_map<FacSignature, FacNodes, FacSignHash> FacSignMap;
|
||||||
|
|
||||||
|
typedef vector<VarCluster*> VarClusters;
|
||||||
|
typedef vector<FacCluster*> FacClusters;
|
||||||
|
|
||||||
|
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||||
|
|
||||||
|
|
||||||
|
struct VarSignHash
|
||||||
{
|
{
|
||||||
public:
|
size_t operator() (const VarSignature &sig) const
|
||||||
CbpSolverLink (FacNode* fn, VarNode* vn, size_t idx, unsigned count)
|
{
|
||||||
: SpLink (fn, vn), index_(idx), nrEdges_(count),
|
size_t val = hash<size_t>()(sig.size());
|
||||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
for (size_t i = 0; i < sig.size(); i++) {
|
||||||
|
val ^= hash<size_t>()(sig[i].first);
|
||||||
size_t index (void) const { return index_; }
|
val ^= hash<size_t>()(sig[i].second);
|
||||||
|
|
||||||
unsigned nrEdges (void) const { return nrEdges_; }
|
|
||||||
|
|
||||||
const Params& poweredMessage (void) const { return pwdMsg_; }
|
|
||||||
|
|
||||||
void updateMessage (void)
|
|
||||||
{
|
|
||||||
pwdMsg_ = *nextMsg_;
|
|
||||||
swap (currMsg_, nextMsg_);
|
|
||||||
msgSended_ = true;
|
|
||||||
LogAware::pow (pwdMsg_, nrEdges_);
|
|
||||||
}
|
}
|
||||||
|
return val;
|
||||||
private:
|
}
|
||||||
size_t index_;
|
|
||||||
unsigned nrEdges_;
|
|
||||||
Params pwdMsg_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct FacSignHash
|
||||||
class CbpSolver : public BpSolver
|
|
||||||
{
|
{
|
||||||
public:
|
size_t operator() (const FacSignature &sig) const
|
||||||
|
{
|
||||||
|
size_t val = hash<size_t>()(sig.size());
|
||||||
|
for (size_t i = 0; i < sig.size(); i++) {
|
||||||
|
val ^= hash<size_t>()(sig[i]);
|
||||||
|
}
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class VarCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
VarCluster (const VarNodes& vs) : members_(vs) { }
|
||||||
|
|
||||||
|
const VarNode* first (void) const { return members_.front(); }
|
||||||
|
|
||||||
|
const VarNodes& members (void) const { return members_; }
|
||||||
|
|
||||||
|
VarNode* representative (void) const { return repr_; }
|
||||||
|
|
||||||
|
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
VarNodes members_;
|
||||||
|
VarNode* repr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class FacCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
||||||
|
: members_(fcs), varClusters_(vcs) { }
|
||||||
|
|
||||||
|
const FacNode* first (void) const { return members_.front(); }
|
||||||
|
|
||||||
|
const FacNodes& members (void) const { return members_; }
|
||||||
|
|
||||||
|
VarClusters& varClusters (void) { return varClusters_; }
|
||||||
|
|
||||||
|
FacNode* representative (void) const { return repr_; }
|
||||||
|
|
||||||
|
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
FacNodes members_;
|
||||||
|
VarClusters varClusters_;
|
||||||
|
FacNode* repr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class CbpSolver : public Solver
|
||||||
|
{
|
||||||
|
public:
|
||||||
CbpSolver (const FactorGraph& fg);
|
CbpSolver (const FactorGraph& fg);
|
||||||
|
|
||||||
~CbpSolver (void);
|
~CbpSolver (void);
|
||||||
|
|
||||||
void printSolverFlags (void) const;
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
|
||||||
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
|
static bool checkForIdenticalFactors;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Color getNewColor (void)
|
||||||
|
{
|
||||||
|
++ freeColor_;
|
||||||
|
return freeColor_ - 1;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
Color getColor (const VarNode* vn) const
|
||||||
|
{
|
||||||
|
return varColors_[vn->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
void createLinks (void);
|
Color getColor (const FacNode* fn) const
|
||||||
|
{
|
||||||
|
return facColors_[fn->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
void maxResidualSchedule (void);
|
void setColor (const VarNode* vn, Color c)
|
||||||
|
{
|
||||||
|
varColors_[vn->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
void calculateFactor2VariableMsg (SpLink*);
|
void setColor (const FacNode* fn, Color c)
|
||||||
|
{
|
||||||
|
facColors_[fn->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
Params getVar2FactorMsg (const SpLink*) const;
|
void findIdenticalFactors (void);
|
||||||
|
|
||||||
void printLinkInformation (void) const;
|
void setInitialColors (void);
|
||||||
|
|
||||||
CFactorGraph* cfg_;
|
void createGroups (void);
|
||||||
|
|
||||||
|
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||||
|
|
||||||
|
VarSignature getSignature (const VarNode*);
|
||||||
|
|
||||||
|
FacSignature getSignature (const FacNode*);
|
||||||
|
|
||||||
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
|
VarId getRepresentative (VarId vid)
|
||||||
|
{
|
||||||
|
assert (Util::contains (vid2VarCluster_, vid));
|
||||||
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
|
return vc->representative()->varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* getCompressedFactorGraph (void);
|
||||||
|
|
||||||
|
vector<vector<unsigned>> getWeights (void) const;
|
||||||
|
|
||||||
|
unsigned getWeight (const FacCluster*,
|
||||||
|
const VarCluster*, size_t index) const;
|
||||||
|
|
||||||
|
|
||||||
|
Color freeColor_;
|
||||||
|
Colors varColors_;
|
||||||
|
Colors facColors_;
|
||||||
|
VarClusters varClusters_;
|
||||||
|
FacClusters facClusters_;
|
||||||
|
VarId2VarCluster vid2VarCluster_;
|
||||||
|
const FactorGraph* compressedFg_;
|
||||||
|
WeightedBpSolver* solver_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_CBP_H
|
#endif // HORUS_CBPSOLVER_H
|
||||||
|
|
||||||
|
@ -139,12 +139,21 @@ ElimGraph::exportToGraphViz (
|
|||||||
|
|
||||||
VarIds
|
VarIds
|
||||||
ElimGraph::getEliminationOrder (
|
ElimGraph::getEliminationOrder (
|
||||||
const vector<Factor*> factors,
|
const Factors& factors,
|
||||||
VarIds excludedVids)
|
VarIds excludedVids)
|
||||||
{
|
{
|
||||||
|
if (elimHeuristic == ElimHeuristic::SEQUENTIAL) {
|
||||||
|
VarIds allVids;
|
||||||
|
Factors::const_iterator first = factors.begin();
|
||||||
|
Factors::const_iterator end = factors.end();
|
||||||
|
for (; first != end; ++first) {
|
||||||
|
Util::addToVector (allVids, (*first)->arguments());
|
||||||
|
}
|
||||||
|
TinySet<VarId> elimOrder (allVids);
|
||||||
|
elimOrder -= TinySet<VarId> (excludedVids);
|
||||||
|
return elimOrder.elements();
|
||||||
|
}
|
||||||
ElimGraph graph (factors);
|
ElimGraph graph (factors);
|
||||||
// graph.print();
|
|
||||||
// graph.exportToGraphViz ("_egg.dot");
|
|
||||||
return graph.getEliminatingOrder (excludedVids);
|
return graph.getEliminatingOrder (excludedVids);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ using namespace std;
|
|||||||
|
|
||||||
enum ElimHeuristic
|
enum ElimHeuristic
|
||||||
{
|
{
|
||||||
|
SEQUENTIAL,
|
||||||
MIN_NEIGHBORS,
|
MIN_NEIGHBORS,
|
||||||
MIN_WEIGHT,
|
MIN_WEIGHT,
|
||||||
MIN_FILL,
|
MIN_FILL,
|
||||||
@ -45,7 +46,7 @@ class EgNode : public Var
|
|||||||
class ElimGraph
|
class ElimGraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ElimGraph (const vector<Factor*>&); // TODO
|
ElimGraph (const Factors&);
|
||||||
|
|
||||||
~ElimGraph (void);
|
~ElimGraph (void);
|
||||||
|
|
||||||
@ -56,7 +57,7 @@ class ElimGraph
|
|||||||
void exportToGraphViz (const char*, bool = true,
|
void exportToGraphViz (const char*, bool = true,
|
||||||
const VarIds& = VarIds()) const;
|
const VarIds& = VarIds()) const;
|
||||||
|
|
||||||
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
|
static VarIds getEliminationOrder (const Factors&, VarIds);
|
||||||
|
|
||||||
static ElimHeuristic elimHeuristic;
|
static ElimHeuristic elimHeuristic;
|
||||||
|
|
||||||
|
@ -187,7 +187,7 @@ class TFactor
|
|||||||
|
|
||||||
bool contains (const vector<T>& args) const
|
bool contains (const vector<T>& args) const
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < args_.size(); i++) {
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
if (contains (args[i]) == false) {
|
if (contains (args[i]) == false) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
|
|||||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fromBayesNet_ = fg.isFromBayesNetwork();
|
bayesFactors_ = fg.bayesianFactors();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -239,7 +239,7 @@ FactorGraph::isTree (void) const
|
|||||||
DAGraph&
|
DAGraph&
|
||||||
FactorGraph::getStructure (void)
|
FactorGraph::getStructure (void)
|
||||||
{
|
{
|
||||||
assert (fromBayesNet_);
|
assert (bayesFactors_);
|
||||||
if (structure_.empty()) {
|
if (structure_.empty()) {
|
||||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||||
|
@ -65,7 +65,7 @@ class FacNode
|
|||||||
class FactorGraph
|
class FactorGraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
|
FactorGraph (void) : bayesFactors_(false) { }
|
||||||
|
|
||||||
FactorGraph (const FactorGraph&);
|
FactorGraph (const FactorGraph&);
|
||||||
|
|
||||||
@ -74,8 +74,10 @@ class FactorGraph
|
|||||||
const VarNodes& varNodes (void) const { return varNodes_; }
|
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||||
|
|
||||||
const FacNodes& facNodes (void) const { return facNodes_; }
|
const FacNodes& facNodes (void) const { return facNodes_; }
|
||||||
|
|
||||||
|
void setFactorsAsBayesian (void) { bayesFactors_ = true; }
|
||||||
|
|
||||||
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
bool bayesianFactors (void) const { return bayesFactors_ ; }
|
||||||
|
|
||||||
size_t nrVarNodes (void) const { return varNodes_.size(); }
|
size_t nrVarNodes (void) const { return varNodes_.size(); }
|
||||||
|
|
||||||
@ -128,7 +130,7 @@ class FactorGraph
|
|||||||
FacNodes facNodes_;
|
FacNodes facNodes_;
|
||||||
|
|
||||||
DAGraph structure_;
|
DAGraph structure_;
|
||||||
bool fromBayesNet_;
|
bool bayesFactors_;
|
||||||
|
|
||||||
typedef unordered_map<unsigned, VarNode*> VarMap;
|
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||||
VarMap varMap_;
|
VarMap varMap_;
|
||||||
|
@ -630,16 +630,9 @@ GroundOperator::getAffectedFormulas (void)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FoveSolver::getPosterioriOf (const Ground& query)
|
FoveSolver::solveQuery (const Grounds& query)
|
||||||
{
|
|
||||||
return getJointDistributionOf ({query});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
FoveSolver::getJointDistributionOf (const Grounds& query)
|
|
||||||
{
|
{
|
||||||
|
assert (query.empty() == false);
|
||||||
runSolver (query);
|
runSolver (query);
|
||||||
(*pfList_.begin())->normalize();
|
(*pfList_.begin())->normalize();
|
||||||
Params params = (*pfList_.begin())->params();
|
Params params = (*pfList_.begin())->params();
|
||||||
@ -970,7 +963,8 @@ FoveSolver::absorve (
|
|||||||
if (commCt->empty() == false) {
|
if (commCt->empty() == false) {
|
||||||
if (formulas.size() > 1) {
|
if (formulas.size() > 1) {
|
||||||
LogVarSet excl = g->exclusiveLogVars (i);
|
LogVarSet excl = g->exclusiveLogVars (i);
|
||||||
Parfactors countNormPfs = countNormalize (g, excl);
|
Parfactor tempPf (g, commCt);
|
||||||
|
Parfactors countNormPfs = countNormalize (&tempPf, excl);
|
||||||
for (size_t j = 0; j < countNormPfs.size(); j++) {
|
for (size_t j = 0; j < countNormPfs.size(); j++) {
|
||||||
countNormPfs[j]->absorveEvidence (
|
countNormPfs[j]->absorveEvidence (
|
||||||
formulas[i], obsFormula.evidence());
|
formulas[i], obsFormula.evidence());
|
||||||
|
@ -135,9 +135,7 @@ class FoveSolver
|
|||||||
public:
|
public:
|
||||||
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||||
|
|
||||||
Params getPosterioriOf (const Ground&);
|
Params solveQuery (const Grounds&);
|
||||||
|
|
||||||
Params getJointDistributionOf (const Grounds&);
|
|
||||||
|
|
||||||
void printSolverFlags (void) const;
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
@ -28,11 +28,18 @@ typedef vector<unsigned> Ranges;
|
|||||||
typedef unsigned long long ullong;
|
typedef unsigned long long ullong;
|
||||||
|
|
||||||
|
|
||||||
enum InfAlgorithms
|
enum LiftedSolvers
|
||||||
{
|
{
|
||||||
VE, // variable elimination
|
FOVE, // first order variable elimination
|
||||||
BP, // belief propagation
|
LBP, // lifted belief propagation
|
||||||
CBP // counting belief propagation
|
};
|
||||||
|
|
||||||
|
|
||||||
|
enum GroundSolvers
|
||||||
|
{
|
||||||
|
VE, // variable elimination
|
||||||
|
BP, // belief propagation
|
||||||
|
CBP // counting belief propagation
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -43,7 +50,8 @@ extern bool logDomain;
|
|||||||
// level of debug information
|
// level of debug information
|
||||||
extern unsigned verbosity;
|
extern unsigned verbosity;
|
||||||
|
|
||||||
extern InfAlgorithms infAlgorithm;
|
extern LiftedSolvers liftedSolver;
|
||||||
|
extern GroundSolvers groundSolver;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -161,14 +161,14 @@ void
|
|||||||
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||||
{
|
{
|
||||||
Solver* solver = 0;
|
Solver* solver = 0;
|
||||||
switch (Globals::infAlgorithm) {
|
switch (Globals::groundSolver) {
|
||||||
case InfAlgorithms::VE:
|
case GroundSolvers::VE:
|
||||||
solver = new VarElimSolver (fg);
|
solver = new VarElimSolver (fg);
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::BP:
|
case GroundSolvers::BP:
|
||||||
solver = new BpSolver (fg);
|
solver = new BpSolver (fg);
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::CBP:
|
case GroundSolvers::CBP:
|
||||||
solver = new CbpSolver (fg);
|
solver = new CbpSolver (fg);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
@ -178,7 +178,7 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
|||||||
solver->printSolverFlags();
|
solver->printSolverFlags();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
if (queryIds.size() == 0) {
|
if (queryIds.empty()) {
|
||||||
solver->printAllPosterioris();
|
solver->printAllPosterioris();
|
||||||
} else {
|
} else {
|
||||||
solver->printAnswer (queryIds);
|
solver->printAnswer (queryIds);
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "FoveSolver.h"
|
#include "FoveSolver.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
|
#include "LiftedBpSolver.h"
|
||||||
#include "BpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
@ -218,8 +219,10 @@ int
|
|||||||
createGroundNetwork (void)
|
createGroundNetwork (void)
|
||||||
{
|
{
|
||||||
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||||
bool fromBayesNet = factorsType == "bayes";
|
FactorGraph* fg = new FactorGraph();
|
||||||
FactorGraph* fg = new FactorGraph (fromBayesNet);
|
if (factorsType == "bayes") {
|
||||||
|
fg->setFactorsAsBayesian();
|
||||||
|
}
|
||||||
YAP_Term factorList = YAP_ARG2;
|
YAP_Term factorList = YAP_ARG2;
|
||||||
while (factorList != YAP_TermNil()) {
|
while (factorList != YAP_TermNil()) {
|
||||||
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||||
@ -308,15 +311,22 @@ runLiftedSolver (void)
|
|||||||
}
|
}
|
||||||
jointList = YAP_TailOfTerm (jointList);
|
jointList = YAP_TailOfTerm (jointList);
|
||||||
}
|
}
|
||||||
FoveSolver solver (pfListCopy);
|
if (Globals::liftedSolver == LiftedSolvers::FOVE) {
|
||||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
FoveSolver solver (pfListCopy);
|
||||||
solver.printSolverFlags();
|
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||||
cout << endl;
|
solver.printSolverFlags();
|
||||||
}
|
cout << endl;
|
||||||
if (queryVars.size() == 1) {
|
}
|
||||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
results.push_back (solver.solveQuery (queryVars));
|
||||||
|
} else if (Globals::liftedSolver == LiftedSolvers::LBP) {
|
||||||
|
LiftedBpSolver solver (pfListCopy);
|
||||||
|
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||||
|
solver.printSolverFlags();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
results.push_back (solver.solveQuery (queryVars));
|
||||||
} else {
|
} else {
|
||||||
results.push_back (solver.getJointDistributionOf (queryVars));
|
assert (false);
|
||||||
}
|
}
|
||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
@ -352,7 +362,7 @@ runGroundSolver (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
vector<Params> results;
|
vector<Params> results;
|
||||||
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
if (Globals::groundSolver == GroundSolvers::VE) {
|
||||||
runVeSolver (fg, tasks, results);
|
runVeSolver (fg, tasks, results);
|
||||||
} else {
|
} else {
|
||||||
runBpSolver (fg, tasks, results);
|
runBpSolver (fg, tasks, results);
|
||||||
@ -384,7 +394,7 @@ void runVeSolver (
|
|||||||
results.reserve (tasks.size());
|
results.reserve (tasks.size());
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
FactorGraph* mfg = fg;
|
FactorGraph* mfg = fg;
|
||||||
if (fg->isFromBayesNetwork()) {
|
if (fg->bayesianFactors()) {
|
||||||
// mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
// mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||||
}
|
}
|
||||||
// VarElimSolver solver (*mfg);
|
// VarElimSolver solver (*mfg);
|
||||||
@ -394,7 +404,7 @@ void runVeSolver (
|
|||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
results.push_back (solver.solveQuery (tasks[i]));
|
results.push_back (solver.solveQuery (tasks[i]));
|
||||||
if (fg->isFromBayesNetwork()) {
|
if (fg->bayesianFactors()) {
|
||||||
// delete mfg;
|
// delete mfg;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -413,14 +423,14 @@ void runBpSolver (
|
|||||||
}
|
}
|
||||||
Solver* solver = 0;
|
Solver* solver = 0;
|
||||||
FactorGraph* mfg = fg;
|
FactorGraph* mfg = fg;
|
||||||
if (fg->isFromBayesNetwork()) {
|
if (fg->bayesianFactors()) {
|
||||||
//mfg = BayesBall::getMinimalFactorGraph (
|
//mfg = BayesBall::getMinimalFactorGraph (
|
||||||
// *fg, VarIds (vids.begin(),vids.end()));
|
// *fg, VarIds (vids.begin(),vids.end()));
|
||||||
}
|
}
|
||||||
if (Globals::infAlgorithm == InfAlgorithms::BP) {
|
if (Globals::groundSolver == GroundSolvers::BP) {
|
||||||
solver = new BpSolver (*fg); // FIXME
|
solver = new BpSolver (*fg); // FIXME
|
||||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
} else if (Globals::groundSolver == GroundSolvers::CBP) {
|
||||||
CFactorGraph::checkForIdenticalFactors = false;
|
CbpSolver::checkForIdenticalFactors = false;
|
||||||
solver = new CbpSolver (*fg); // FIXME
|
solver = new CbpSolver (*fg); // FIXME
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: unknow solver" << endl;
|
cerr << "error: unknow solver" << endl;
|
||||||
@ -434,7 +444,7 @@ void runBpSolver (
|
|||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
results.push_back (solver->solveQuery (tasks[i]));
|
results.push_back (solver->solveQuery (tasks[i]));
|
||||||
}
|
}
|
||||||
if (fg->isFromBayesNetwork()) {
|
if (fg->bayesianFactors()) {
|
||||||
//delete mfg;
|
//delete mfg;
|
||||||
}
|
}
|
||||||
delete solver;
|
delete solver;
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
#define HORUS_INDEXER_H
|
#define HORUS_INDEXER_H
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <numeric>
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
@ -14,10 +14,9 @@ class Indexer
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Indexer (const Ranges& ranges, bool calcOffsets = true)
|
Indexer (const Ranges& ranges, bool calcOffsets = true)
|
||||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges)
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||||
|
size_(Util::sizeExpected (ranges))
|
||||||
{
|
{
|
||||||
size_ = std::accumulate (ranges.begin(), ranges.end(), 1,
|
|
||||||
std::multiplies<unsigned>());
|
|
||||||
if (calcOffsets) {
|
if (calcOffsets) {
|
||||||
calculateOffsets();
|
calculateOffsets();
|
||||||
}
|
}
|
||||||
|
148
packages/CLPBN/horus/LiftedBpSolver.cpp
Normal file
148
packages/CLPBN/horus/LiftedBpSolver.cpp
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
#include "LiftedBpSolver.h"
|
||||||
|
#include "WeightedBpSolver.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "FoveSolver.h"
|
||||||
|
|
||||||
|
|
||||||
|
LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList)
|
||||||
|
: pfList_(pfList)
|
||||||
|
{
|
||||||
|
refineParfactors();
|
||||||
|
solver_ = new WeightedBpSolver (*getFactorGraph(), getWeights());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
LiftedBpSolver::solveQuery (const Grounds& query)
|
||||||
|
{
|
||||||
|
assert (query.empty() == false);
|
||||||
|
Params res;
|
||||||
|
vector<PrvGroup> groups = getQueryGroups (query);
|
||||||
|
if (query.size() == 1) {
|
||||||
|
res = solver_->getPosterioriOf (groups[0]);
|
||||||
|
} else {
|
||||||
|
VarIds queryVids;
|
||||||
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
|
queryVids.push_back (groups[i]);
|
||||||
|
}
|
||||||
|
res = solver_->getJointDistributionOf (queryVids);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedBpSolver::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "lifted bp [" ;
|
||||||
|
ss << "schedule=" ;
|
||||||
|
typedef BpOptions::Schedule Sch;
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
|
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
|
case Sch::PARALLEL: ss << "parallel"; break;
|
||||||
|
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
|
}
|
||||||
|
ss << ",max_iter=" << BpOptions::maxIter;
|
||||||
|
ss << ",accuracy=" << BpOptions::accuracy;
|
||||||
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedBpSolver::refineParfactors (void)
|
||||||
|
{
|
||||||
|
while (iterate() == false);
|
||||||
|
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printHeader ("AFTER REFINEMENT");
|
||||||
|
pfList_.print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
LiftedBpSolver::iterate (void)
|
||||||
|
{
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
const ProbFormulas& args = (*it)->arguments();
|
||||||
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
|
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||||
|
if ((*it)->constr()->isCountNormalized (lvs) == false) {
|
||||||
|
Parfactors pfs = FoveSolver::countNormalize (*it, lvs);
|
||||||
|
it = pfList_.removeAndDelete (it);
|
||||||
|
pfList_.add (pfs);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<PrvGroup>
|
||||||
|
LiftedBpSolver::getQueryGroups (const Grounds& query)
|
||||||
|
{
|
||||||
|
vector<PrvGroup> queryGroups;
|
||||||
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
for (; it != pfList_.end(); ++it) {
|
||||||
|
if ((*it)->containsGround (query[i])) {
|
||||||
|
queryGroups.push_back ((*it)->findGroup (query[i]));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (queryGroups.size() == query.size());
|
||||||
|
return queryGroups;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
LiftedBpSolver::getFactorGraph (void)
|
||||||
|
{
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
for (; it != pfList_.end(); ++it) {
|
||||||
|
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||||
|
VarIds varIds;
|
||||||
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
|
varIds.push_back (groups[i]);
|
||||||
|
}
|
||||||
|
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
|
||||||
|
}
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<vector<unsigned>>
|
||||||
|
LiftedBpSolver::getWeights (void) const
|
||||||
|
{
|
||||||
|
vector<vector<unsigned>> weights;
|
||||||
|
weights.reserve (pfList_.size());
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
for (; it != pfList_.end(); ++it) {
|
||||||
|
const ProbFormulas& args = (*it)->arguments();
|
||||||
|
weights.push_back ({ });
|
||||||
|
weights.back().reserve (args.size());
|
||||||
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
|
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||||
|
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
34
packages/CLPBN/horus/LiftedBpSolver.h
Normal file
34
packages/CLPBN/horus/LiftedBpSolver.h
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
#ifndef HORUS_LIFTEDBPSOLVER_H
|
||||||
|
#define HORUS_LIFTEDBPSOLVER_H
|
||||||
|
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
class FactorGraph;
|
||||||
|
class WeightedBpSolver;
|
||||||
|
|
||||||
|
class LiftedBpSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LiftedBpSolver (const ParfactorList& pfList);
|
||||||
|
|
||||||
|
Params solveQuery (const Grounds&);
|
||||||
|
|
||||||
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void refineParfactors (void);
|
||||||
|
|
||||||
|
bool iterate (void);
|
||||||
|
|
||||||
|
vector<PrvGroup> getQueryGroups (const Grounds&);
|
||||||
|
|
||||||
|
FactorGraph* getFactorGraph (void);
|
||||||
|
|
||||||
|
vector<vector<unsigned>> getWeights (void) const;
|
||||||
|
|
||||||
|
ParfactorList pfList_;
|
||||||
|
WeightedBpSolver* solver_;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_LIFTEDBPSOLVER_H
|
@ -50,7 +50,6 @@ HEADERS = \
|
|||||||
$(srcdir)/ElimGraph.h \
|
$(srcdir)/ElimGraph.h \
|
||||||
$(srcdir)/FactorGraph.h \
|
$(srcdir)/FactorGraph.h \
|
||||||
$(srcdir)/Factor.h \
|
$(srcdir)/Factor.h \
|
||||||
$(srcdir)/CFactorGraph.h \
|
|
||||||
$(srcdir)/ConstraintTree.h \
|
$(srcdir)/ConstraintTree.h \
|
||||||
$(srcdir)/Solver.h \
|
$(srcdir)/Solver.h \
|
||||||
$(srcdir)/VarElimSolver.h \
|
$(srcdir)/VarElimSolver.h \
|
||||||
@ -65,6 +64,8 @@ HEADERS = \
|
|||||||
$(srcdir)/ParfactorList.h \
|
$(srcdir)/ParfactorList.h \
|
||||||
$(srcdir)/LiftedUtils.h \
|
$(srcdir)/LiftedUtils.h \
|
||||||
$(srcdir)/TinySet.h \
|
$(srcdir)/TinySet.h \
|
||||||
|
$(srcdir)/LiftedBpSolver.h \
|
||||||
|
$(srcdir)/WeightedBpSolver.h \
|
||||||
$(srcdir)/Util.h \
|
$(srcdir)/Util.h \
|
||||||
$(srcdir)/Horus.h
|
$(srcdir)/Horus.h
|
||||||
|
|
||||||
@ -74,7 +75,6 @@ CPP_SOURCES = \
|
|||||||
$(srcdir)/ElimGraph.cpp \
|
$(srcdir)/ElimGraph.cpp \
|
||||||
$(srcdir)/FactorGraph.cpp \
|
$(srcdir)/FactorGraph.cpp \
|
||||||
$(srcdir)/Factor.cpp \
|
$(srcdir)/Factor.cpp \
|
||||||
$(srcdir)/CFactorGraph.cpp \
|
|
||||||
$(srcdir)/ConstraintTree.cpp \
|
$(srcdir)/ConstraintTree.cpp \
|
||||||
$(srcdir)/Var.cpp \
|
$(srcdir)/Var.cpp \
|
||||||
$(srcdir)/Solver.cpp \
|
$(srcdir)/Solver.cpp \
|
||||||
@ -88,6 +88,8 @@ CPP_SOURCES = \
|
|||||||
$(srcdir)/ParfactorList.cpp \
|
$(srcdir)/ParfactorList.cpp \
|
||||||
$(srcdir)/LiftedUtils.cpp \
|
$(srcdir)/LiftedUtils.cpp \
|
||||||
$(srcdir)/Util.cpp \
|
$(srcdir)/Util.cpp \
|
||||||
|
$(srcdir)/LiftedBpSolver.cpp \
|
||||||
|
$(srcdir)/WeightedBpSolver.cpp \
|
||||||
$(srcdir)/HorusYap.cpp \
|
$(srcdir)/HorusYap.cpp \
|
||||||
$(srcdir)/HorusCli.cpp
|
$(srcdir)/HorusCli.cpp
|
||||||
|
|
||||||
@ -97,7 +99,6 @@ OBJS = \
|
|||||||
ElimGraph.o \
|
ElimGraph.o \
|
||||||
FactorGraph.o \
|
FactorGraph.o \
|
||||||
Factor.o \
|
Factor.o \
|
||||||
CFactorGraph.o \
|
|
||||||
ConstraintTree.o \
|
ConstraintTree.o \
|
||||||
Var.o \
|
Var.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
@ -111,6 +112,8 @@ OBJS = \
|
|||||||
ParfactorList.o \
|
ParfactorList.o \
|
||||||
LiftedUtils.o \
|
LiftedUtils.o \
|
||||||
Util.o \
|
Util.o \
|
||||||
|
LiftedBpSolver.o \
|
||||||
|
WeightedBpSolver.o \
|
||||||
HorusYap.o
|
HorusYap.o
|
||||||
|
|
||||||
HCLI_OBJS = \
|
HCLI_OBJS = \
|
||||||
@ -119,7 +122,6 @@ HCLI_OBJS = \
|
|||||||
ElimGraph.o \
|
ElimGraph.o \
|
||||||
FactorGraph.o \
|
FactorGraph.o \
|
||||||
Factor.o \
|
Factor.o \
|
||||||
CFactorGraph.o \
|
|
||||||
ConstraintTree.o \
|
ConstraintTree.o \
|
||||||
Var.o \
|
Var.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
@ -131,6 +133,7 @@ HCLI_OBJS = \
|
|||||||
ProbFormula.o \
|
ProbFormula.o \
|
||||||
Histogram.o \
|
Histogram.o \
|
||||||
ParfactorList.o \
|
ParfactorList.o \
|
||||||
|
WeightedBpSolver.o \
|
||||||
LiftedUtils.o \
|
LiftedUtils.o \
|
||||||
Util.o \
|
Util.o \
|
||||||
HorusCli.o
|
HorusCli.o
|
||||||
|
@ -14,14 +14,16 @@ Solver::printAnswer (const VarIds& vids)
|
|||||||
unobservedVids.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Params res = solveQuery (unobservedVids);
|
if (unobservedVids.empty() == false) {
|
||||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
Params res = solveQuery (unobservedVids);
|
||||||
for (size_t i = 0; i < res.size(); i++) {
|
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||||
cout << "P(" << stateLines[i] << ") = " ;
|
for (size_t i = 0; i < res.size(); i++) {
|
||||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
cout << "P(" << stateLines[i] << ") = " ;
|
||||||
|
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -29,14 +31,10 @@ Solver::printAnswer (const VarIds& vids)
|
|||||||
void
|
void
|
||||||
Solver::printAllPosterioris (void)
|
Solver::printAllPosterioris (void)
|
||||||
{
|
{
|
||||||
VarIds vids;
|
VarNodes vars = fg.varNodes();
|
||||||
const VarNodes& vars = fg.varNodes();
|
std::sort (vars.begin(), vars.end(), sortByVarId());
|
||||||
for (size_t i = 0; i < vars.size(); i++) {
|
for (size_t i = 0; i < vars.size(); i++) {
|
||||||
vids.push_back (vars[i]->varId());
|
printAnswer ({vars[i]->varId()});
|
||||||
}
|
|
||||||
std::sort (vids.begin(), vids.end());
|
|
||||||
for (size_t i = 0; i < vids.size(); i++) {
|
|
||||||
printAnswer ({vids[i]});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
- Check if evidence remains in the compressed factor graph
|
|
||||||
- Consider using hashs instead of vectors of colors to calculate the groups in
|
|
||||||
counting bp
|
|
||||||
- Find a way to decrease the time required to find an
|
- Find a way to decrease the time required to find an
|
||||||
elimination order for variable elimination
|
elimination order for variable elimination
|
||||||
- Add a sequential elimination heuristic
|
- Consider using hashs instead of vectors of colors to calculate the groups in
|
||||||
|
counting bp
|
||||||
|
|
||||||
|
@ -12,7 +12,10 @@ class TinySet
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
TinySet (const TinySet& s)
|
typedef typename vector<T>::iterator iterator;
|
||||||
|
typedef typename vector<T>::const_iterator const_iterator;
|
||||||
|
|
||||||
|
TinySet (const TinySet& s)
|
||||||
: vec_(s.vec_), cmp_(s.cmp_) { }
|
: vec_(s.vec_), cmp_(s.cmp_) { }
|
||||||
|
|
||||||
TinySet (const Compare& cmp = Compare())
|
TinySet (const Compare& cmp = Compare())
|
||||||
@ -25,11 +28,10 @@ class TinySet
|
|||||||
: vec_(elements), cmp_(cmp)
|
: vec_(elements), cmp_(cmp)
|
||||||
{
|
{
|
||||||
std::sort (begin(), end(), cmp_);
|
std::sort (begin(), end(), cmp_);
|
||||||
|
iterator it = unique_cmp (begin(), end());
|
||||||
|
vec_.resize (it - begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef typename vector<T>::iterator iterator;
|
|
||||||
typedef typename vector<T>::const_iterator const_iterator;
|
|
||||||
|
|
||||||
iterator insert (const T& t)
|
iterator insert (const T& t)
|
||||||
{
|
{
|
||||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||||
@ -224,11 +226,25 @@ class TinySet
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
iterator unique_cmp (iterator first, iterator last)
|
||||||
|
{
|
||||||
|
if (first == last) {
|
||||||
|
return last;
|
||||||
|
}
|
||||||
|
iterator result = first;
|
||||||
|
while (++first != last) {
|
||||||
|
if (cmp_(*result, *first)) {
|
||||||
|
*(++result) = *first;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ++result;
|
||||||
|
}
|
||||||
|
|
||||||
bool consistent (void) const
|
bool consistent (void) const
|
||||||
{
|
{
|
||||||
typename vector<T>::size_type i;
|
typename vector<T>::size_type i;
|
||||||
for (i = 0; i < vec_.size() - 1; i++) {
|
for (i = 0; i < vec_.size() - 1; i++) {
|
||||||
if (cmp_(vec_[i], vec_[i + 1]) == false) {
|
if ( ! cmp_(vec_[i], vec_[i + 1])) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,9 +13,11 @@ bool logDomain = false;
|
|||||||
|
|
||||||
unsigned verbosity = 0;
|
unsigned verbosity = 0;
|
||||||
|
|
||||||
InfAlgorithms infAlgorithm = InfAlgorithms::VE;
|
LiftedSolvers liftedSolver = LiftedSolvers::FOVE;
|
||||||
};
|
|
||||||
|
|
||||||
|
GroundSolvers groundSolver = GroundSolvers::VE;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -71,7 +73,6 @@ stringToDouble (string str)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
double
|
||||||
factorial (unsigned num)
|
factorial (unsigned num)
|
||||||
{
|
{
|
||||||
@ -128,8 +129,8 @@ nrCombinations (unsigned n, unsigned k)
|
|||||||
size_t
|
size_t
|
||||||
sizeExpected (const Ranges& ranges)
|
sizeExpected (const Ranges& ranges)
|
||||||
{
|
{
|
||||||
return std::accumulate (
|
return std::accumulate (ranges.begin(),
|
||||||
ranges.begin(), ranges.end(), 1, multiplies<unsigned>());
|
ranges.end(), 1, multiplies<unsigned>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -208,20 +209,32 @@ setHorusFlag (string key, string value)
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << value;
|
ss << value;
|
||||||
ss >> Globals::verbosity;
|
ss >> Globals::verbosity;
|
||||||
} else if (key == "inf_alg") {
|
} else if (key == "lifted_solver") {
|
||||||
|
if ( value == "fove") {
|
||||||
|
Globals::liftedSolver = LiftedSolvers::FOVE;
|
||||||
|
} else if (value == "lbp") {
|
||||||
|
Globals::liftedSolver = LiftedSolvers::LBP;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else if (key == "ground_solver") {
|
||||||
if ( value == "ve") {
|
if ( value == "ve") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
Globals::groundSolver = GroundSolvers::VE;
|
||||||
} else if (value == "bp") {
|
} else if (value == "bp") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
Globals::groundSolver = GroundSolvers::BP;
|
||||||
} else if (value == "cbp") {
|
} else if (value == "cbp") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
Globals::groundSolver = GroundSolvers::CBP;
|
||||||
} else {
|
} else {
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
cerr << "for `" << key << "'" << endl;
|
cerr << "for `" << key << "'" << endl;
|
||||||
returnVal = false;
|
returnVal = false;
|
||||||
}
|
}
|
||||||
} else if (key == "elim_heuristic") {
|
} else if (key == "elim_heuristic") {
|
||||||
if ( value == "min_neighbors") {
|
if ( value == "sequential") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::SEQUENTIAL;
|
||||||
|
} else if (value == "min_neighbors") {
|
||||||
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
|
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
|
||||||
} else if (value == "min_weight") {
|
} else if (value == "min_weight") {
|
||||||
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT;
|
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT;
|
||||||
@ -323,23 +336,15 @@ namespace LogAware {
|
|||||||
void
|
void
|
||||||
normalize (Params& v)
|
normalize (Params& v)
|
||||||
{
|
{
|
||||||
double sum = LogAware::addIdenty();
|
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = 0; i < v.size(); i++) {
|
double sum = std::accumulate (v.begin(), v.end(),
|
||||||
sum = Util::logSum (sum, v[i]);
|
LogAware::addIdenty(), Util::logSum);
|
||||||
}
|
|
||||||
assert (sum != -numeric_limits<double>::infinity());
|
assert (sum != -numeric_limits<double>::infinity());
|
||||||
for (size_t i = 0; i < v.size(); i++) {
|
v -= sum;
|
||||||
v[i] -= sum;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < v.size(); i++) {
|
double sum = std::accumulate (v.begin(), v.end(), 0.0);
|
||||||
sum += v[i];
|
|
||||||
}
|
|
||||||
assert (sum != 0.0);
|
assert (sum != 0.0);
|
||||||
for (size_t i = 0; i < v.size(); i++) {
|
v /= sum;
|
||||||
v[i] /= sum;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -351,13 +356,11 @@ getL1Distance (const Params& v1, const Params& v2)
|
|||||||
assert (v1.size() == v2.size());
|
assert (v1.size() == v2.size());
|
||||||
double dist = 0.0;
|
double dist = 0.0;
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = 0; i < v1.size(); i++) {
|
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
dist += abs (exp(v1[i]) - exp(v2[i]));
|
std::plus<double>(), FuncObject::abs_diff_exp<double>());
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < v1.size(); i++) {
|
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
dist += abs (v1[i] - v2[i]);
|
std::plus<double>(), FuncObject::abs_diff<double>());
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return dist;
|
return dist;
|
||||||
}
|
}
|
||||||
@ -370,19 +373,11 @@ getMaxNorm (const Params& v1, const Params& v2)
|
|||||||
assert (v1.size() == v2.size());
|
assert (v1.size() == v2.size());
|
||||||
double max = 0.0;
|
double max = 0.0;
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = 0; i < v1.size(); i++) {
|
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
double diff = abs (exp(v1[i]) - exp(v2[i]));
|
FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
|
||||||
if (diff > max) {
|
|
||||||
max = diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < v1.size(); i++) {
|
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
double diff = abs (v1[i] - v2[i]);
|
FuncObject::max<double>(), FuncObject::abs_diff<double>());
|
||||||
if (diff > max) {
|
|
||||||
max = diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return max;
|
return max;
|
||||||
}
|
}
|
||||||
@ -392,7 +387,9 @@ getMaxNorm (const Params& v1, const Params& v2)
|
|||||||
double
|
double
|
||||||
pow (double base, unsigned iexp)
|
pow (double base, unsigned iexp)
|
||||||
{
|
{
|
||||||
return Globals::logDomain ? base * iexp : std::pow (base, iexp);
|
return Globals::logDomain
|
||||||
|
? base * iexp
|
||||||
|
: std::pow (base, iexp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -400,8 +397,10 @@ pow (double base, unsigned iexp)
|
|||||||
double
|
double
|
||||||
pow (double base, double exp)
|
pow (double base, double exp)
|
||||||
{
|
{
|
||||||
// assumes that `expoent' is never in log domain
|
// `expoent' should not be in log domain
|
||||||
return Globals::logDomain ? base * exp : std::pow (base, exp);
|
return Globals::logDomain
|
||||||
|
? base * exp
|
||||||
|
: std::pow (base, exp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,6 +41,9 @@ template <typename K, typename V> bool contains (
|
|||||||
|
|
||||||
template <typename T> size_t indexOf (const vector<T>&, const T&);
|
template <typename T> size_t indexOf (const vector<T>&, const T&);
|
||||||
|
|
||||||
|
template <class Operation>
|
||||||
|
void apply_n_times (Params& v1, const Params& v2, unsigned repetitions, Operation);
|
||||||
|
|
||||||
template <typename T> void log (vector<T>&);
|
template <typename T> void log (vector<T>&);
|
||||||
|
|
||||||
template <typename T> void exp (vector<T>&);
|
template <typename T> void exp (vector<T>&);
|
||||||
@ -54,10 +57,6 @@ template <> std::string toString (const bool&);
|
|||||||
|
|
||||||
double logSum (double, double);
|
double logSum (double, double);
|
||||||
|
|
||||||
void add (Params&, const Params&, unsigned);
|
|
||||||
|
|
||||||
void multiply (Params&, const Params&, unsigned);
|
|
||||||
|
|
||||||
unsigned maxUnsigned (void);
|
unsigned maxUnsigned (void);
|
||||||
|
|
||||||
unsigned stringToUnsigned (string);
|
unsigned stringToUnsigned (string);
|
||||||
@ -153,10 +152,29 @@ Util::indexOf (const vector<T>& v, const T& e)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <class Operation> void
|
||||||
|
Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
|
||||||
|
Operation unary_op)
|
||||||
|
{
|
||||||
|
Params::iterator first = v1.begin();
|
||||||
|
Params::const_iterator last = v1.end();
|
||||||
|
Params::const_iterator first2 = v2.begin();
|
||||||
|
Params::const_iterator last2 = v2.end();
|
||||||
|
while (first != last) {
|
||||||
|
for (first2 = v2.begin(); first2 != last2; ++first2) {
|
||||||
|
std::transform (first, first + repetitions, first,
|
||||||
|
std::bind1st (unary_op, *first2));
|
||||||
|
first += repetitions;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::log (vector<T>& v)
|
Util::log (vector<T>& v)
|
||||||
{
|
{
|
||||||
transform (v.begin(), v.end(), v.begin(), ::log);
|
std::transform (v.begin(), v.end(), v.begin(), ::log);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -164,7 +182,7 @@ Util::log (vector<T>& v)
|
|||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::exp (vector<T>& v)
|
Util::exp (vector<T>& v)
|
||||||
{
|
{
|
||||||
transform (v.begin(), v.end(), v.begin(), ::exp);
|
std::transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -224,36 +242,6 @@ Util::logSum (double x, double y)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
|
||||||
{
|
|
||||||
for (size_t count = 0; count < v1.size(); ) {
|
|
||||||
for (size_t i = 0; i < v2.size(); i++) {
|
|
||||||
for (unsigned r = 0; r < repetitions; r++) {
|
|
||||||
v1[count] += v2[i];
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
|
|
||||||
{
|
|
||||||
for (size_t count = 0; count < v1.size(); ) {
|
|
||||||
for (size_t i = 0; i < v2.size(); i++) {
|
|
||||||
for (unsigned r = 0; r < repetitions; r++) {
|
|
||||||
v1[count] *= v2[i];
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline unsigned
|
inline unsigned
|
||||||
Util::maxUnsigned (void)
|
Util::maxUnsigned (void)
|
||||||
{
|
{
|
||||||
@ -273,7 +261,6 @@ inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
|
|||||||
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
|
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
|
||||||
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
|
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
|
||||||
|
|
||||||
|
|
||||||
void normalize (Params&);
|
void normalize (Params&);
|
||||||
|
|
||||||
double getL1Distance (const Params&, const Params&);
|
double getL1Distance (const Params&, const Params&);
|
||||||
@ -296,7 +283,7 @@ template <typename T>
|
|||||||
void operator+=(std::vector<T>& v, double val)
|
void operator+=(std::vector<T>& v, double val)
|
||||||
{
|
{
|
||||||
std::transform (v.begin(), v.end(), v.begin(),
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
std::bind1st (plus<double>(), val));
|
std::bind2nd (plus<double>(), val));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -305,7 +292,7 @@ template <typename T>
|
|||||||
void operator-=(std::vector<T>& v, double val)
|
void operator-=(std::vector<T>& v, double val)
|
||||||
{
|
{
|
||||||
std::transform (v.begin(), v.end(), v.begin(),
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
std::bind1st (minus<double>(), val));
|
std::bind2nd (minus<double>(), val));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -314,7 +301,7 @@ template <typename T>
|
|||||||
void operator*=(std::vector<T>& v, double val)
|
void operator*=(std::vector<T>& v, double val)
|
||||||
{
|
{
|
||||||
std::transform (v.begin(), v.end(), v.begin(),
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
std::bind1st (multiplies<double>(), val));
|
std::bind2nd (multiplies<double>(), val));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -323,7 +310,7 @@ template <typename T>
|
|||||||
void operator/=(std::vector<T>& v, double val)
|
void operator/=(std::vector<T>& v, double val)
|
||||||
{
|
{
|
||||||
std::transform (v.begin(), v.end(), v.begin(),
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
std::bind1st (divides<double>(), val));
|
std::bind2nd (divides<double>(), val));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -395,5 +382,41 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
|
|||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace FuncObject {
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct max : public std::binary_function<T, T, T>
|
||||||
|
{
|
||||||
|
T operator() (const T& x, const T& y) const
|
||||||
|
{
|
||||||
|
return x < y ? y : x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct abs_diff : public std::binary_function<T, T, T>
|
||||||
|
{
|
||||||
|
T operator() (const T& x, const T& y) const
|
||||||
|
{
|
||||||
|
return std::abs (x - y);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct abs_diff_exp : public std::binary_function<T, T, T>
|
||||||
|
{
|
||||||
|
T operator() (const T& x, const T& y) const
|
||||||
|
{
|
||||||
|
return std::abs (std::exp (x) - std::exp (y));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
#endif // HORUS_UTIL_H
|
#endif // HORUS_UTIL_H
|
||||||
|
|
||||||
|
@ -48,6 +48,7 @@ VarElimSolver::printSolverFlags (void) const
|
|||||||
ss << "elim_heuristic=" ;
|
ss << "elim_heuristic=" ;
|
||||||
ElimHeuristic eh = ElimGraph::elimHeuristic;
|
ElimHeuristic eh = ElimGraph::elimHeuristic;
|
||||||
switch (eh) {
|
switch (eh) {
|
||||||
|
case SEQUENTIAL: ss << "sequential"; break;
|
||||||
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
||||||
case MIN_WEIGHT: ss << "min_weight"; break;
|
case MIN_WEIGHT: ss << "min_weight"; break;
|
||||||
case MIN_FILL: ss << "min_fill"; break;
|
case MIN_FILL: ss << "min_fill"; break;
|
||||||
|
288
packages/CLPBN/horus/WeightedBpSolver.cpp
Normal file
288
packages/CLPBN/horus/WeightedBpSolver.cpp
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
#include "WeightedBpSolver.h"
|
||||||
|
|
||||||
|
|
||||||
|
WeightedBpSolver::~WeightedBpSolver (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
links_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
WeightedBpSolver::getPosterioriOf (VarId vid)
|
||||||
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
|
VarNode* var = fg.getVarNode (vid);
|
||||||
|
assert (var != 0);
|
||||||
|
Params probs;
|
||||||
|
if (var->hasEvidence()) {
|
||||||
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
|
} else {
|
||||||
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
|
const BpLinks& links = ninf(var)->getLinks();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
|
probs += l->powMessage();
|
||||||
|
}
|
||||||
|
LogAware::normalize (probs);
|
||||||
|
Util::exp (probs);
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
|
probs *= l->powMessage();
|
||||||
|
}
|
||||||
|
LogAware::normalize (probs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
WeightedBpSolver::createLinks (void)
|
||||||
|
{
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
cout << "compressed factor graph contains " ;
|
||||||
|
cout << fg.nrVarNodes() << " variables and " ;
|
||||||
|
cout << fg.nrFacNodes() << " factors " << endl;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "creating link " ;
|
||||||
|
cout << facNodes[i]->getLabel();
|
||||||
|
cout << " -- " ;
|
||||||
|
cout << neighs[j]->label();
|
||||||
|
cout << " idx=" << j << ", weight=" << weights_[i][j] << endl;
|
||||||
|
}
|
||||||
|
links_.push_back (new WeightedLink (
|
||||||
|
facNodes[i], neighs[j], j, weights_[i][j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
WeightedBpSolver::maxResidualSchedule (void)
|
||||||
|
{
|
||||||
|
if (nIters_ == 1) {
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
calculateMessage (links_[i]);
|
||||||
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
|
if (Globals::verbosity >= 1) {
|
||||||
|
cout << "calculating " << links_[i]->toString() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t c = 0; c < links_.size(); c++) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << endl << "current residuals:" << endl;
|
||||||
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
it != sortedOrder_.end(); ++it) {
|
||||||
|
cout << " " << setw (30) << left << (*it)->toString();
|
||||||
|
cout << "residual = " << (*it)->residual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
BpLink* link = *it;
|
||||||
|
if (Globals::verbosity >= 1) {
|
||||||
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
|
}
|
||||||
|
if (link->residual() < BpOptions::accuracy) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
link->updateMessage();
|
||||||
|
link->clearResidual();
|
||||||
|
sortedOrder_.erase (it);
|
||||||
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
|
// update the messages that depend on message source --> destin
|
||||||
|
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||||
|
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||||
|
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
|
for (size_t j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->varNode() != link->varNode()) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << " calculating " << links[j]->toString() << endl;
|
||||||
|
}
|
||||||
|
calculateMessage (links[j]);
|
||||||
|
BpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// in counting bp, the message that a variable X sends to
|
||||||
|
// to a factor F depends on the message that F sent to the X
|
||||||
|
const BpLinks& links = ninf(link->facNode())->getLinks();
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
if (links[i]->varNode() != link->varNode()) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << " calculating " << links[i]->toString() << endl;
|
||||||
|
}
|
||||||
|
calculateMessage (links[i]);
|
||||||
|
BpLinkMap::iterator iter = linkMap_.find (links[i]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
|
||||||
|
{
|
||||||
|
WeightedLink* link = static_cast<WeightedLink*> (_link);
|
||||||
|
FacNode* src = link->facNode();
|
||||||
|
const VarNode* dst = link->varNode();
|
||||||
|
const BpLinks& links = ninf(src)->getLinks();
|
||||||
|
// calculate the product of messages that were sent
|
||||||
|
// to factor `src', except from var `dst'
|
||||||
|
unsigned reps = 1;
|
||||||
|
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
|
||||||
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::plus<double>());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::multiplies<double>());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Factor result (src->factor().arguments(),
|
||||||
|
src->factor().ranges(), msgProduct);
|
||||||
|
assert (msgProduct.size() == src->factor().size());
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
result.params() += src->factor().params();
|
||||||
|
} else {
|
||||||
|
result.params() *= src->factor().params();
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message product: " << msgProduct << endl;
|
||||||
|
cout << " original factor: " << src->factor().params() << endl;
|
||||||
|
cout << " factor product: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
result.sumOutAllExceptIndex (link->index());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " marginalized: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
link->nextMessage() = result.params();
|
||||||
|
LogAware::normalize (link->nextMessage());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " curr msg: " << link->message() << endl;
|
||||||
|
cout << " next msg: " << link->nextMessage() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
|
||||||
|
{
|
||||||
|
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
|
||||||
|
const VarNode* src = link->varNode();
|
||||||
|
const FacNode* dst = link->facNode();
|
||||||
|
Params msg;
|
||||||
|
if (src->hasEvidence()) {
|
||||||
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
|
double value = link->message()[src->getEvidence()];
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
msg[src->getEvidence()] = value;
|
||||||
|
cout << msg << "^" << link->weight() << "-1" ;
|
||||||
|
}
|
||||||
|
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
|
||||||
|
} else {
|
||||||
|
msg = link->message();
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << msg << "^" << link->weight() << "-1" ;
|
||||||
|
}
|
||||||
|
LogAware::pow (msg, link->weight() - 1);
|
||||||
|
}
|
||||||
|
const BpLinks& links = ninf(src)->getLinks();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
|
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||||
|
msg += l->powMessage();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
|
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||||
|
msg *= l->powMessage();
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " x " << l->nextMessage() << "^" << link->weight();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " = " << msg;
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
WeightedBpSolver::printLinkInformation (void) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
|
||||||
|
cout << l->toString() << ":" << endl;
|
||||||
|
cout << " curr msg = " << l->message() << endl;
|
||||||
|
cout << " next msg = " << l->nextMessage() << endl;
|
||||||
|
cout << " pow msg = " << l->powMessage() << endl;
|
||||||
|
cout << " index = " << l->index() << endl;
|
||||||
|
cout << " weight = " << l->weight() << endl;
|
||||||
|
cout << " residual = " << l->residual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
61
packages/CLPBN/horus/WeightedBpSolver.h
Normal file
61
packages/CLPBN/horus/WeightedBpSolver.h
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
#ifndef HORUS_WEIGHTEDBPSOLVER_H
|
||||||
|
#define HORUS_WEIGHTEDBPSOLVER_H
|
||||||
|
|
||||||
|
#include "BpSolver.h"
|
||||||
|
|
||||||
|
class WeightedLink : public BpLink
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
|
||||||
|
: BpLink (fn, vn), index_(idx), weight_(weight),
|
||||||
|
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||||
|
|
||||||
|
size_t index (void) const { return index_; }
|
||||||
|
|
||||||
|
unsigned weight (void) const { return weight_; }
|
||||||
|
|
||||||
|
const Params& powMessage (void) const { return pwdMsg_; }
|
||||||
|
|
||||||
|
void updateMessage (void)
|
||||||
|
{
|
||||||
|
pwdMsg_ = *nextMsg_;
|
||||||
|
swap (currMsg_, nextMsg_);
|
||||||
|
LogAware::pow (pwdMsg_, weight_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t index_;
|
||||||
|
unsigned weight_;
|
||||||
|
Params pwdMsg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedBpSolver : public BpSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
WeightedBpSolver (const FactorGraph& fg,
|
||||||
|
const vector<vector<unsigned>>& weights)
|
||||||
|
: BpSolver (fg), weights_(weights) { }
|
||||||
|
|
||||||
|
~WeightedBpSolver (void);
|
||||||
|
|
||||||
|
Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
void createLinks (void);
|
||||||
|
|
||||||
|
void maxResidualSchedule (void);
|
||||||
|
|
||||||
|
void calcFactorToVarMsg (BpLink*);
|
||||||
|
|
||||||
|
Params getVarToFactorMsg (const BpLink*) const;
|
||||||
|
|
||||||
|
void printLinkInformation (void) const;
|
||||||
|
|
||||||
|
vector<vector<unsigned>> weights_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_WEIGHTEDBPSOLVER_H
|
||||||
|
|
Reference in New Issue
Block a user