some renamings

This commit is contained in:
Tiago Gomes
2012-05-28 19:41:24 +01:00
parent 64b53e8180
commit 62283f353c
4 changed files with 98 additions and 104 deletions

View File

@@ -63,14 +63,14 @@ CbpSolver::getPosterioriOf (VarId vid)
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
probs += l->poweredMessage();
probs += l->powMessage();
}
LogAware::normalize (probs);
Util::exp (probs);
} else {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
probs *= l->poweredMessage();
probs *= l->powMessage();
}
LogAware::normalize (probs);
}
@@ -146,7 +146,7 @@ CbpSolver::maxResidualSchedule (void)
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->getResidual() << endl;
cout << "residual = " << (*it)->residual() << endl;
}
}
@@ -155,7 +155,7 @@ CbpSolver::maxResidualSchedule (void)
if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->getResidual() < BpOptions::accuracy) {
if (link->residual() < BpOptions::accuracy) {
return;
}
link->updateMessage();
@@ -164,11 +164,11 @@ CbpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (size_t j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) {
if (links[j]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[j]->toString() << endl;
}
@@ -181,9 +181,9 @@ CbpSolver::maxResidualSchedule (void)
}
// in counting bp, the message that a variable X sends to
// to a factor F depends on the message that F sent to the X
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
const SpLinkSet& links = ninf(link->facNode())->getLinks();
for (size_t i = 0; i < links.size(); i++) {
if (links[i]->getVariable() != link->getVariable()) {
if (links[i]->varNode() != link->varNode()) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[i]->toString() << endl;
}
@@ -199,35 +199,35 @@ CbpSolver::maxResidualSchedule (void)
void
CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
CbpSolver::calcFactorToVarMsg (SpLink* _link)
{
CbpSolverLink* link = static_cast<CbpSolverLink*> (_link);
FacNode* src = link->getFactor();
const VarNode* dst = link->getVariable();
FacNode* src = link->facNode();
const VarNode* dst = link->varNode();
const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned msgSize = 1;
for (size_t i = 0; i < links.size(); i++) {
msgSize *= links[i]->getVariable()->range();
msgSize *= links[i]->varNode()->range();
}
unsigned repetitions = 1;
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if ( ! (cl->varNode() == dst && cl->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label();
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
Util::add (msgProduct, getVarToFactorMsg (links[i]), repetitions);
repetitions *= links[i]->varNode()->range();
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
} else {
unsigned range = links[i]->getVariable()->range();
unsigned range = links[i]->varNode()->range();
Util::add (msgProduct, Params (range, 0.0), repetitions);
repetitions *= range;
}
@@ -235,18 +235,18 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
} else {
for (size_t i = links.size(); i-- > 0; ) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if ( ! (cl->varNode() == dst && cl->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label();
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
Util::multiply (msgProduct, getVarToFactorMsg (links[i]), repetitions);
repetitions *= links[i]->varNode()->range();
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
} else {
unsigned range = links[i]->getVariable()->range();
unsigned range = links[i]->varNode()->range();
Util::multiply (msgProduct, Params (range, 1.0), repetitions);
repetitions *= range;
}
@@ -267,35 +267,35 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
}
result.sumOutAllExceptIndex (link->index());
if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl;
cout << " marginalized: " << result.params() << endl;
}
link->getNextMessage() = result.params();
LogAware::normalize (link->getNextMessage());
link->nextMessage() = result.params();
LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << link->getNextMessage() << endl;
cout << " curr msg: " << link->message() << endl;
cout << " next msg: " << link->nextMessage() << endl;
}
}
Params
CbpSolver::getVar2FactorMsg (const SpLink* _link) const
CbpSolver::getVarToFactorMsg (const SpLink* _link) const
{
const CbpSolverLink* link = static_cast<const CbpSolverLink*> (_link);
const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor();
const VarNode* src = link->varNode();
const FacNode* dst = link->facNode();
Params msg;
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()];
double value = link->message()[src->getEvidence()];
if (Constants::SHOW_BP_CALCS) {
msg[src->getEvidence()] = value;
cout << msg << "^" << link->nrEdges() << "-1" ;
}
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
} else {
msg = link->getMessage();
msg = link->message();
if (Constants::SHOW_BP_CALCS) {
cout << msg << "^" << link->nrEdges() << "-1" ;
}
@@ -305,18 +305,18 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
if ( ! (cl->facNode() == dst && cl->index() == link->index())) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
msg += cl->poweredMessage();
msg += cl->powMessage();
}
}
} else {
for (size_t i = 0; i < links.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
msg *= cl->poweredMessage();
if ( ! (cl->facNode() == dst && cl->index() == link->index())) {
msg *= cl->powMessage();
if (Constants::SHOW_BP_CALCS) {
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
cout << " x " << cl->nextMessage() << "^" << link->nrEdges();
}
}
}
@@ -335,12 +335,12 @@ CbpSolver::printLinkInformation (void) const
for (size_t i = 0; i < links_.size(); i++) {
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links_[i]);
cout << cl->toString() << ":" << endl;
cout << " curr msg = " << cl->getMessage() << endl;
cout << " next msg = " << cl->getNextMessage() << endl;
cout << " curr msg = " << cl->message() << endl;
cout << " next msg = " << cl->nextMessage() << endl;
cout << " index = " << cl->index() << endl;
cout << " nr edges = " << cl->nrEdges() << endl;
cout << " powered = " << cl->poweredMessage() << endl;
cout << " residual = " << cl->getResidual() << endl;
cout << " powered = " << cl->powMessage() << endl;
cout << " residual = " << cl->residual() << endl;
}
}