add support to (real) lifted belief propagation
This commit is contained in:
@@ -82,7 +82,7 @@ BpSolver::getPosterioriOf (VarId vid)
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
const BpLinks& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
probs += links[i]->message();
|
||||
@@ -120,7 +120,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
return getJointByConditioning (jointVarIds);
|
||||
} else {
|
||||
Factor res (facNodes[idx]->factor());
|
||||
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
|
||||
const BpLinks& links = ninf(facNodes[idx])->getLinks();
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
Factor msg ({links[i]->varNode()->varId()},
|
||||
{links[i]->varNode()->range()},
|
||||
@@ -194,7 +194,7 @@ BpSolver::createLinks (void)
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||
for (size_t j = 0; j < neighbors.size(); j++) {
|
||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
||||
links_.push_back (new BpLink (facNodes[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -224,7 +224,7 @@ BpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
BpLink* link = *it;
|
||||
if (link->residual() < BpOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
@@ -237,11 +237,11 @@ BpSolver::maxResidualSchedule (void)
|
||||
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->facNode()) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (size_t j = 0; j < links.size(); j++) {
|
||||
if (links[j]->varNode() != link->varNode()) {
|
||||
calculateMessage (links[j]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
BpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
@@ -257,11 +257,11 @@ BpSolver::maxResidualSchedule (void)
|
||||
|
||||
|
||||
void
|
||||
BpSolver::calcFactorToVarMsg (SpLink* link)
|
||||
BpSolver::calcFactorToVarMsg (BpLink* link)
|
||||
{
|
||||
FacNode* src = link->facNode();
|
||||
const VarNode* dst = link->varNode();
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
const BpLinks& links = ninf(src)->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned reps = 1;
|
||||
@@ -321,7 +321,7 @@ BpSolver::calcFactorToVarMsg (SpLink* link)
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getVarToFactorMsg (const SpLink* link) const
|
||||
BpSolver::getVarToFactorMsg (const BpLink* link) const
|
||||
{
|
||||
const VarNode* src = link->varNode();
|
||||
Params msg;
|
||||
@@ -334,8 +334,8 @@ BpSolver::getVarToFactorMsg (const SpLink* link) const
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << msg;
|
||||
}
|
||||
SpLinkSet::const_iterator it;
|
||||
const SpLinkSet& links = ninf (src)->getLinks();
|
||||
BpLinks::const_iterator it;
|
||||
const BpLinks& links = ninf (src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (it = links.begin(); it != links.end(); ++it) {
|
||||
msg += (*it)->message();
|
||||
@@ -432,8 +432,8 @@ BpSolver::initializeSolver (void)
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
FacNode* src = links_[i]->facNode();
|
||||
VarNode* dst = links_[i]->varNode();
|
||||
ninf (dst)->addSpLink (links_[i]);
|
||||
ninf (src)->addSpLink (links_[i]);
|
||||
ninf (dst)->addBpLink (links_[i]);
|
||||
ninf (src)->addBpLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,7 +491,7 @@ void
|
||||
BpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
SpLink* l = links_[i];
|
||||
BpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << l->message() << endl;
|
||||
|
Reference in New Issue
Block a user