revamp debugging plataform
This commit is contained in:
parent
d86e2c8386
commit
56475cacbc
@ -153,9 +153,8 @@ BpSolver::runSolver (void)
|
|||||||
nIters_ = 0;
|
nIters_ = 0;
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||||
nIters_ ++;
|
nIters_ ++;
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 1) {
|
||||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||||
// cout << endl;
|
|
||||||
}
|
}
|
||||||
switch (BpOptions::schedule) {
|
switch (BpOptions::schedule) {
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
@ -178,12 +177,8 @@ BpSolver::runSolver (void)
|
|||||||
maxResidualSchedule();
|
maxResidualSchedule();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
}
|
if (Globals::verbosity > 0) {
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
cout << "Sum-Product converged in " ;
|
cout << "Sum-Product converged in " ;
|
||||||
cout << nIters_ << " iterations" << endl;
|
cout << nIters_ << " iterations" << endl;
|
||||||
@ -191,6 +186,7 @@ BpSolver::runSolver (void)
|
|||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
unsigned size = fg_->varNodes().size();
|
unsigned size = fg_->varNodes().size();
|
||||||
if (Constants::COLLECT_STATS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
@ -232,7 +228,7 @@ BpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned c = 0; c < links_.size(); c++) {
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << "current residuals:" << endl;
|
cout << "current residuals:" << endl;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); it ++) {
|
it != sortedOrder_.end(); it ++) {
|
||||||
@ -266,7 +262,7 @@ BpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 1) {
|
||||||
Util::printDashedLine();
|
Util::printDashedLine();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -291,13 +287,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
|||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
if (links[i]->getVariable() != dst) {
|
if (links[i]->getVariable() != dst) {
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
cout << " message from " << links[i]->getVariable()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
repetitions *= links[i]->getVariable()->range();
|
repetitions *= links[i]->getVariable()->range();
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -309,13 +305,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
|||||||
} else {
|
} else {
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
if (links[i]->getVariable() != dst) {
|
if (links[i]->getVariable() != dst) {
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
cout << " message from " << links[i]->getVariable()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
repetitions *= links[i]->getVariable()->range();
|
repetitions *= links[i]->getVariable()->range();
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -328,18 +324,18 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
|||||||
Factor result (src->factor().arguments(),
|
Factor result (src->factor().arguments(),
|
||||||
src->factor().ranges(), msgProduct);
|
src->factor().ranges(), msgProduct);
|
||||||
result.multiply (src->factor());
|
result.multiply (src->factor());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message product: " << msgProduct << endl;
|
cout << " message product: " << msgProduct << endl;
|
||||||
cout << " original factor: " << src->factor().params() << endl;
|
cout << " original factor: " << src->factor().params() << endl;
|
||||||
cout << " factor product: " << result.params() << endl;
|
cout << " factor product: " << result.params() << endl;
|
||||||
}
|
}
|
||||||
result.sumOutAllExcept (dst->varId());
|
result.sumOutAllExcept (dst->varId());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " marginalized: " << result.params() << endl;
|
cout << " marginalized: " << result.params() << endl;
|
||||||
}
|
}
|
||||||
link->getNextMessage() = result.params();
|
link->getNextMessage() = result.params();
|
||||||
LogAware::normalize (link->getNextMessage());
|
LogAware::normalize (link->getNextMessage());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " curr msg: " << link->getMessage() << endl;
|
cout << " curr msg: " << link->getMessage() << endl;
|
||||||
cout << " next msg: " << link->getNextMessage() << endl;
|
cout << " next msg: " << link->getNextMessage() << endl;
|
||||||
}
|
}
|
||||||
@ -359,7 +355,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
} else {
|
} else {
|
||||||
msg.resize (src->range(), LogAware::one());
|
msg.resize (src->range(), LogAware::one());
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << msg;
|
cout << msg;
|
||||||
}
|
}
|
||||||
const SpLinkSet& links = ninf (src)->getLinks();
|
const SpLinkSet& links = ninf (src)->getLinks();
|
||||||
@ -367,7 +363,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
Util::add (msg, links[i]->getMessage());
|
Util::add (msg, links[i]->getMessage());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << links[i]->getMessage();
|
cout << " x " << links[i]->getMessage();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -376,13 +372,13 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
Util::multiply (msg, links[i]->getMessage());
|
Util::multiply (msg, links[i]->getMessage());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << links[i]->getMessage();
|
cout << " x " << links[i]->getMessage();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " = " << msg;
|
cout << " = " << msg;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
@ -472,7 +468,16 @@ BpSolver::converged (void)
|
|||||||
if (links_.size() == 0) {
|
if (links_.size() == 0) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (nIters_ <= 1) {
|
if (nIters_ == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
if (nIters_ == 1) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "no residuals" << endl << endl;
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool converged = true;
|
bool converged = true;
|
||||||
@ -486,7 +491,7 @@ BpSolver::converged (void)
|
|||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
double residual = links_[i]->getResidual();
|
double residual = links_[i]->getResidual();
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
}
|
}
|
||||||
if (residual > BpOptions::accuracy) {
|
if (residual > BpOptions::accuracy) {
|
||||||
@ -494,7 +499,7 @@ BpSolver::converged (void)
|
|||||||
if (Constants::DEBUG == 0) break;
|
if (Constants::DEBUG == 0) break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -128,7 +128,7 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||||
{
|
{
|
||||||
if (Constants::DEBUG >= 3) {
|
if (Globals::verbosity > 2) {
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
cout << "calculating & updating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateFactor2VariableMsg (link);
|
calculateFactor2VariableMsg (link);
|
||||||
@ -140,7 +140,7 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||||
{
|
{
|
||||||
if (Constants::DEBUG >= 3) {
|
if (Globals::verbosity > 2) {
|
||||||
cout << "calculating " << link->toString() << endl;
|
cout << "calculating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateFactor2VariableMsg (link);
|
calculateFactor2VariableMsg (link);
|
||||||
@ -152,7 +152,7 @@ class BpSolver : public Solver
|
|||||||
void updateMessage (SpLink* link)
|
void updateMessage (SpLink* link)
|
||||||
{
|
{
|
||||||
link->updateMessage();
|
link->updateMessage();
|
||||||
if (Constants::DEBUG >= 3) {
|
if (Globals::verbosity > 2) {
|
||||||
cout << "updating " << link->toString() << endl;
|
cout << "updating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,14 +24,6 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
|||||||
Statistics::updateCompressingStatistics (nrGroundVars,
|
Statistics::updateCompressingStatistics (nrGroundVars,
|
||||||
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
|
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << "uncompressed factor graph:" << endl;
|
|
||||||
cout << " " << fg.nrVarNodes() << " variables " << endl;
|
|
||||||
cout << " " << fg.nrFacNodes() << " factors " << endl;
|
|
||||||
cout << "compressed factor graph:" << endl;
|
|
||||||
cout << " " << fg_->nrVarNodes() << " variables " << endl;
|
|
||||||
cout << " " << fg_->nrFacNodes() << " factors " << endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -124,14 +116,24 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
|||||||
void
|
void
|
||||||
CbpSolver::createLinks (void)
|
CbpSolver::createLinks (void)
|
||||||
{
|
{
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
cout << "original factor graph has " ;
|
||||||
|
cout << fg.nrVarNodes() << " variables and " ;
|
||||||
|
cout << fg.nrFacNodes() << " factors " << endl;
|
||||||
|
cout << "compressed factor graph has " ;
|
||||||
|
cout << fg_->nrVarNodes() << " variables and " ;
|
||||||
|
cout << fg_->nrFacNodes() << " factors " << endl;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
const FacClusters& fcs = cfg_->facClusters();
|
const FacClusters& fcs = cfg_->facClusters();
|
||||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||||
const VarClusters& vcs = fcs[i]->varClusters();
|
const VarClusters& vcs = fcs[i]->varClusters();
|
||||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||||
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
|
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << "creating edge " ;
|
cout << "creating link " ;
|
||||||
cout << fcs[i]->representative()->getLabel() << " -> " ;
|
cout << fcs[i]->representative()->getLabel();
|
||||||
|
cout << " -- " ;
|
||||||
cout << vcs[j]->representative()->label();
|
cout << vcs[j]->representative()->label();
|
||||||
cout << " idx=" << j << ", count=" << count << endl;
|
cout << " idx=" << j << ", count=" << count << endl;
|
||||||
}
|
}
|
||||||
@ -139,6 +141,9 @@ CbpSolver::createLinks (void)
|
|||||||
fcs[i]->representative(), vcs[j]->representative(), j, count));
|
fcs[i]->representative(), vcs[j]->representative(), j, count));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -151,7 +156,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
calculateMessage (links_[i]);
|
calculateMessage (links_[i]);
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
if (Globals::verbosity >= 1) {
|
||||||
cout << "calculating " << links_[i]->toString() << endl;
|
cout << "calculating " << links_[i]->toString() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -159,7 +164,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned c = 0; c < links_.size(); c++) {
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << endl << "current residuals:" << endl;
|
cout << endl << "current residuals:" << endl;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); it ++) {
|
it != sortedOrder_.end(); it ++) {
|
||||||
@ -170,7 +175,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
SpLink* link = *it;
|
SpLink* link = *it;
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity >= 1) {
|
||||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
}
|
}
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
if (link->getResidual() < BpOptions::accuracy) {
|
||||||
@ -187,7 +192,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
for (unsigned j = 0; j < links.size(); j++) {
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
if (links[j]->getVariable() != link->getVariable()) {
|
if (links[j]->getVariable() != link->getVariable()) {
|
||||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << " calculating " << links[j]->toString() << endl;
|
cout << " calculating " << links[j]->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateMessage (links[j]);
|
calculateMessage (links[j]);
|
||||||
@ -202,7 +207,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getVariable() != link->getVariable()) {
|
if (links[i]->getVariable() != link->getVariable()) {
|
||||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << " calculating " << links[i]->toString() << endl;
|
cout << " calculating " << links[i]->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateMessage (links[i]);
|
calculateMessage (links[i]);
|
||||||
@ -235,13 +240,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
|
|||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||||
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
cout << " message from " << links[i]->getVariable()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
repetitions *= links[i]->getVariable()->range();
|
repetitions *= links[i]->getVariable()->range();
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -254,13 +259,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
|
|||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||||
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
cout << " message from " << links[i]->getVariable()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
repetitions *= links[i]->getVariable()->range();
|
repetitions *= links[i]->getVariable()->range();
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -282,18 +287,18 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
|
|||||||
result[i] *= src->factor()[i];
|
result[i] *= src->factor()[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message product: " << msgProduct << endl;
|
cout << " message product: " << msgProduct << endl;
|
||||||
cout << " original factor: " << src->factor().params() << endl;
|
cout << " original factor: " << src->factor().params() << endl;
|
||||||
cout << " factor product: " << result.params() << endl;
|
cout << " factor product: " << result.params() << endl;
|
||||||
}
|
}
|
||||||
result.sumOutAllExceptIndex (link->index());
|
result.sumOutAllExceptIndex (link->index());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " marginalized: " << result.params() << endl;
|
cout << " marginalized: " << result.params() << endl;
|
||||||
}
|
}
|
||||||
link->getNextMessage() = result.params();
|
link->getNextMessage() = result.params();
|
||||||
LogAware::normalize (link->getNextMessage());
|
LogAware::normalize (link->getNextMessage());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " curr msg: " << link->getMessage() << endl;
|
cout << " curr msg: " << link->getMessage() << endl;
|
||||||
cout << " next msg: " << link->getNextMessage() << endl;
|
cout << " next msg: " << link->getNextMessage() << endl;
|
||||||
}
|
}
|
||||||
@ -311,14 +316,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
|
|||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
double value = link->getMessage()[src->getEvidence()];
|
double value = link->getMessage()[src->getEvidence()];
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
msg[src->getEvidence()] = value;
|
msg[src->getEvidence()] = value;
|
||||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||||
}
|
}
|
||||||
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
|
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
|
||||||
} else {
|
} else {
|
||||||
msg = link->getMessage();
|
msg = link->getMessage();
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||||
}
|
}
|
||||||
LogAware::pow (msg, link->nrEdges() - 1);
|
LogAware::pow (msg, link->nrEdges() - 1);
|
||||||
@ -337,13 +342,13 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
|
|||||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||||
Util::multiply (msg, cl->poweredMessage());
|
Util::multiply (msg, cl->poweredMessage());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
|
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " = " << msg;
|
cout << " = " << msg;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
|
@ -151,13 +151,11 @@ Factor::sumOutIndex (unsigned idx)
|
|||||||
void
|
void
|
||||||
Factor::sumOutAllExceptIndex (unsigned idx)
|
Factor::sumOutAllExceptIndex (unsigned idx)
|
||||||
{
|
{
|
||||||
int i = (int)idx;
|
while (args_.size() > idx + 1) {
|
||||||
while (args_.size() > i + 1) {
|
|
||||||
sumOutLastVariable();
|
sumOutLastVariable();
|
||||||
}
|
}
|
||||||
while (i > 0) {
|
for (unsigned i = 0; i < idx; i++) {
|
||||||
sumOutFirstVariable();
|
sumOutFirstVariable();
|
||||||
i -- ;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -542,6 +542,18 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FoveSolver::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "fove [" ;
|
||||||
|
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FoveSolver::absorveEvidence (
|
FoveSolver::absorveEvidence (
|
||||||
ParfactorList& pfList,
|
ParfactorList& pfList,
|
||||||
@ -568,7 +580,7 @@ FoveSolver::absorveEvidence (
|
|||||||
}
|
}
|
||||||
pfList.add (newPfs);
|
pfList.add (newPfs);
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
|
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||||
Util::printAsteriskLine();
|
Util::printAsteriskLine();
|
||||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||||
@ -603,20 +615,26 @@ FoveSolver::countNormalize (
|
|||||||
void
|
void
|
||||||
FoveSolver::runSolver (const Grounds& query)
|
FoveSolver::runSolver (const Grounds& query)
|
||||||
{
|
{
|
||||||
|
largestCost_ = std::log (0);
|
||||||
shatterAgainstQuery (query);
|
shatterAgainstQuery (query);
|
||||||
runWeakBayesBall (query);
|
runWeakBayesBall (query);
|
||||||
while (true) {
|
while (true) {
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 2) {
|
||||||
Util::printDashedLine();
|
Util::printDashedLine();
|
||||||
pfList_.print();
|
pfList_.print();
|
||||||
|
if (Globals::verbosity > 3) {
|
||||||
LiftedOperator::printValidOps (pfList_, query);
|
LiftedOperator::printValidOps (pfList_, query);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
LiftedOperator* op = getBestOperation (query);
|
LiftedOperator* op = getBestOperation (query);
|
||||||
if (op == 0) {
|
if (op == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 1) {
|
||||||
cout << "best operation: " << op->toString() << endl;
|
cout << "best operation: " << op->toString();
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
op->apply();
|
op->apply();
|
||||||
delete op;
|
delete op;
|
||||||
@ -630,6 +648,10 @@ FoveSolver::runSolver (const Grounds& query)
|
|||||||
++ pfIter;
|
++ pfIter;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
cout << "largest cost = " << std::exp (largestCost_) << endl;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
(*pfList_.begin())->reorderAccordingGrounds (query);
|
(*pfList_.begin())->reorderAccordingGrounds (query);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -649,6 +671,9 @@ FoveSolver::getBestOperation (const Grounds& query)
|
|||||||
bestCost = cost;
|
bestCost = cost;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (bestCost > largestCost_) {
|
||||||
|
largestCost_ = bestCost;
|
||||||
|
}
|
||||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||||
if (validOps[i] != bestOp) {
|
if (validOps[i] != bestOp) {
|
||||||
delete validOps[i];
|
delete validOps[i];
|
||||||
@ -699,18 +724,21 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
|
|||||||
}
|
}
|
||||||
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
bool foundNotRequired = false;
|
||||||
while (it != pfList_.end()) {
|
while (it != pfList_.end()) {
|
||||||
if (Util::contains (requiredPfs, *it) == false) {
|
if (Util::contains (requiredPfs, *it) == false) {
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
if (foundNotRequired == false) {
|
||||||
|
Util::printHeader ("PARFACTORS TO DISCARD");
|
||||||
|
foundNotRequired = true;
|
||||||
|
}
|
||||||
|
(*it)->print();
|
||||||
|
}
|
||||||
it = pfList_.removeAndDelete (it);
|
it = pfList_.removeAndDelete (it);
|
||||||
} else {
|
} else {
|
||||||
++ it;
|
++ it;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printHeader ("REQUIRED PARFACTORS");
|
|
||||||
pfList_.print();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -750,8 +778,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
|
|||||||
}
|
}
|
||||||
pfList_.add (newPfs);
|
pfList_.add (newPfs);
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 2) {
|
||||||
cout << endl;
|
|
||||||
Util::printAsteriskLine();
|
Util::printAsteriskLine();
|
||||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
|
@ -116,6 +116,8 @@ class FoveSolver
|
|||||||
|
|
||||||
Params getJointDistributionOf (const Grounds&);
|
Params getJointDistributionOf (const Grounds&);
|
||||||
|
|
||||||
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
static void absorveEvidence (
|
static void absorveEvidence (
|
||||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||||
|
|
||||||
@ -133,6 +135,8 @@ class FoveSolver
|
|||||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||||
|
|
||||||
ParfactorList pfList_;
|
ParfactorList pfList_;
|
||||||
|
|
||||||
|
double largestCost_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_FOVESOLVER_H
|
#endif // HORUS_FOVESOLVER_H
|
||||||
|
@ -39,6 +39,8 @@ namespace Globals {
|
|||||||
|
|
||||||
extern bool logDomain;
|
extern bool logDomain;
|
||||||
|
|
||||||
|
extern unsigned verbosity;
|
||||||
|
|
||||||
extern InfAlgorithms infAlgorithm;
|
extern InfAlgorithms infAlgorithm;
|
||||||
|
|
||||||
};
|
};
|
||||||
@ -47,12 +49,14 @@ extern InfAlgorithms infAlgorithm;
|
|||||||
namespace Constants {
|
namespace Constants {
|
||||||
|
|
||||||
// level of debug information
|
// level of debug information
|
||||||
const unsigned DEBUG = 0;
|
const unsigned DEBUG = 3;
|
||||||
|
|
||||||
|
const bool SHOW_BP_CALCS = false;
|
||||||
|
|
||||||
const int NO_EVIDENCE = -1;
|
const int NO_EVIDENCE = -1;
|
||||||
|
|
||||||
// number of digits to show when printing a parameter
|
// number of digits to show when printing a parameter
|
||||||
const unsigned PRECISION = 5;
|
const unsigned PRECISION = 6;
|
||||||
|
|
||||||
const bool COLLECT_STATS = false;
|
const bool COLLECT_STATS = false;
|
||||||
|
|
||||||
|
@ -174,8 +174,10 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
|||||||
default:
|
default:
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
solver->printSolverFlags();
|
solver->printSolverFlags();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
}
|
||||||
if (queryIds.size() == 0) {
|
if (queryIds.size() == 0) {
|
||||||
solver->printAllPosterioris();
|
solver->printAllPosterioris();
|
||||||
} else {
|
} else {
|
||||||
|
@ -64,7 +64,7 @@ int createLiftedNetwork (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LiftedUtils::printSymbolDictionary();
|
// LiftedUtils::printSymbolDictionary();
|
||||||
if (Constants::DEBUG > 2) {
|
if (Globals::verbosity > 2) {
|
||||||
Util::printHeader ("INITIAL PARFACTORS");
|
Util::printHeader ("INITIAL PARFACTORS");
|
||||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||||
parfactors[i]->print();
|
parfactors[i]->print();
|
||||||
@ -73,7 +73,7 @@ int createLiftedNetwork (void)
|
|||||||
|
|
||||||
ParfactorList* pfList = new ParfactorList (parfactors);
|
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||||
|
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Globals::verbosity > 2) {
|
||||||
Util::printHeader ("SHATTERED PARFACTORS");
|
Util::printHeader ("SHATTERED PARFACTORS");
|
||||||
pfList->print();
|
pfList->print();
|
||||||
}
|
}
|
||||||
@ -300,6 +300,10 @@ runLiftedSolver (void)
|
|||||||
jointList = YAP_TailOfTerm (jointList);
|
jointList = YAP_TailOfTerm (jointList);
|
||||||
}
|
}
|
||||||
FoveSolver solver (pfListCopy);
|
FoveSolver solver (pfListCopy);
|
||||||
|
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||||
|
solver.printSolverFlags();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
if (queryVars.size() == 1) {
|
if (queryVars.size() == 1) {
|
||||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
||||||
} else {
|
} else {
|
||||||
@ -376,7 +380,9 @@ void runVeSolver (
|
|||||||
}
|
}
|
||||||
// VarElimSolver solver (*mfg);
|
// VarElimSolver solver (*mfg);
|
||||||
VarElimSolver solver (*fg); //FIXME
|
VarElimSolver solver (*fg); //FIXME
|
||||||
// solver.printSolverFlags();
|
if (Globals::verbosity > 0 && i == 0) {
|
||||||
|
solver.printSolverFlags();
|
||||||
|
}
|
||||||
results.push_back (solver.solveQuery (tasks[i]));
|
results.push_back (solver.solveQuery (tasks[i]));
|
||||||
if (fg->isFromBayesNetwork()) {
|
if (fg->isFromBayesNetwork()) {
|
||||||
delete mfg;
|
delete mfg;
|
||||||
@ -410,7 +416,9 @@ void runBpSolver (
|
|||||||
cerr << "error: unknow solver" << endl;
|
cerr << "error: unknow solver" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
// solver->printSolverFlags();
|
if (Globals::verbosity > 0) {
|
||||||
|
solver->printSolverFlags();
|
||||||
|
}
|
||||||
results.reserve (tasks.size());
|
results.reserve (tasks.size());
|
||||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
results.push_back (solver->solveQuery (tasks[i]));
|
results.push_back (solver->solveQuery (tasks[i]));
|
||||||
@ -508,7 +516,11 @@ setHorusFlag (void)
|
|||||||
{
|
{
|
||||||
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||||
string value;
|
string value;
|
||||||
if (key == "accuracy") {
|
if (key == "verbosity") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << (int) YAP_IntOfTerm (YAP_ARG2);
|
||||||
|
ss >> value;
|
||||||
|
} else if (key == "accuracy") {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
|
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
|
||||||
ss >> value;
|
ss >> value;
|
||||||
|
@ -11,10 +11,9 @@
|
|||||||
namespace Globals {
|
namespace Globals {
|
||||||
bool logDomain = false;
|
bool logDomain = false;
|
||||||
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
unsigned verbosity = 0;
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
InfAlgorithms infAlgorithm = InfAlgorithms::VE;
|
||||||
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -227,7 +226,11 @@ bool
|
|||||||
setHorusFlag (string key, string value)
|
setHorusFlag (string key, string value)
|
||||||
{
|
{
|
||||||
bool returnVal = true;
|
bool returnVal = true;
|
||||||
if (key == "inf_alg") {
|
if (key == "verbosity") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << value;
|
||||||
|
ss >> Globals::verbosity;
|
||||||
|
} else if (key == "inf_alg") {
|
||||||
if ( value == "ve") {
|
if ( value == "ve") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||||
} else if (value == "bp") {
|
} else if (value == "bp") {
|
||||||
|
@ -16,6 +16,14 @@ VarElimSolver::~VarElimSolver (void)
|
|||||||
Params
|
Params
|
||||||
VarElimSolver::solveQuery (VarIds queryVids)
|
VarElimSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "Solving query on " ;
|
||||||
|
for (unsigned i = 0; i < queryVids.size(); i++) {
|
||||||
|
if (i != 0) cout << ", " ;
|
||||||
|
cout << fg.getVarNode (queryVids[i])->label();
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
factorList_.clear();
|
factorList_.clear();
|
||||||
varFactors_.clear();
|
varFactors_.clear();
|
||||||
elimOrder_.clear();
|
elimOrder_.clear();
|
||||||
@ -77,9 +85,19 @@ VarElimSolver::createFactorList (void)
|
|||||||
void
|
void
|
||||||
VarElimSolver::absorveEvidence (void)
|
VarElimSolver::absorveEvidence (void)
|
||||||
{
|
{
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printDashedLine();
|
||||||
|
cout << "(initial factor list)" << endl;
|
||||||
|
printActiveFactors();
|
||||||
|
}
|
||||||
const VarNodes& varNodes = fg.varNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
if (varNodes[i]->hasEvidence()) {
|
if (varNodes[i]->hasEvidence()) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "-> aborving evidence on ";
|
||||||
|
cout << varNodes[i]->label() << " = " ;
|
||||||
|
cout << varNodes[i]->getEvidence() << endl;
|
||||||
|
}
|
||||||
const vector<unsigned>& idxs =
|
const vector<unsigned>& idxs =
|
||||||
varFactors_.find (varNodes[i]->varId())->second;
|
varFactors_.find (varNodes[i]->varId())->second;
|
||||||
for (unsigned j = 0; j < idxs.size(); j++) {
|
for (unsigned j = 0; j < idxs.size(); j++) {
|
||||||
@ -108,12 +126,16 @@ VarElimSolver::findEliminationOrder (const VarIds& vids)
|
|||||||
void
|
void
|
||||||
VarElimSolver::processFactorList (const VarIds& vids)
|
VarElimSolver::processFactorList (const VarIds& vids)
|
||||||
{
|
{
|
||||||
|
totalFactorSize_ = 0;
|
||||||
|
largestFactorSize_ = 0;
|
||||||
for (unsigned i = 0; i < elimOrder_.size(); i++) {
|
for (unsigned i = 0; i < elimOrder_.size(); i++) {
|
||||||
if (Constants::DEBUG >= 3) {
|
if (Globals::verbosity >= 2) {
|
||||||
|
if (Globals::verbosity >= 3) {
|
||||||
|
Util::printDashedLine();
|
||||||
printActiveFactors();
|
printActiveFactors();
|
||||||
|
}
|
||||||
cout << "-> summing out " ;
|
cout << "-> summing out " ;
|
||||||
VarNode* vn = fg.getVarNode (elimOrder_[i]);
|
cout << fg.getVarNode (elimOrder_[i])->label() << endl;
|
||||||
cout << vn->label() << endl;
|
|
||||||
}
|
}
|
||||||
eliminate (elimOrder_[i]);
|
eliminate (elimOrder_[i]);
|
||||||
}
|
}
|
||||||
@ -137,6 +159,11 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
|||||||
finalFactor->reorderArguments (unobservedVids);
|
finalFactor->reorderArguments (unobservedVids);
|
||||||
finalFactor->normalize();
|
finalFactor->normalize();
|
||||||
factorList_.push_back (finalFactor);
|
factorList_.push_back (finalFactor);
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
cout << "total factor size: " << totalFactorSize_ << endl;
|
||||||
|
cout << "largest factor size: " << largestFactorSize_ << endl;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -158,6 +185,10 @@ VarElimSolver::eliminate (VarId elimVar)
|
|||||||
factorList_[idx] = 0;
|
factorList_[idx] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
totalFactorSize_ += result->size();
|
||||||
|
if (result->size() > largestFactorSize_) {
|
||||||
|
largestFactorSize_ = result->size();
|
||||||
|
}
|
||||||
if (result != 0 && result->nrArguments() != 1) {
|
if (result != 0 && result->nrArguments() != 1) {
|
||||||
result->sumOut (elimVar);
|
result->sumOut (elimVar);
|
||||||
factorList_.push_back (result);
|
factorList_.push_back (result);
|
||||||
@ -175,14 +206,11 @@ VarElimSolver::eliminate (VarId elimVar)
|
|||||||
void
|
void
|
||||||
VarElimSolver::printActiveFactors (void)
|
VarElimSolver::printActiveFactors (void)
|
||||||
{
|
{
|
||||||
cout << endl;
|
|
||||||
Util::printDashedLine();
|
|
||||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||||
if (factorList_[i] != 0) {
|
if (factorList_[i] != 0) {
|
||||||
cout << factorList_[i]->getLabel() << " " ;
|
cout << factorList_[i]->getLabel() << " " ;
|
||||||
cout << factorList_[i]->params() << endl;
|
cout << factorList_[i]->params() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,8 +35,10 @@ class VarElimSolver : public Solver
|
|||||||
|
|
||||||
void printActiveFactors (void);
|
void printActiveFactors (void);
|
||||||
|
|
||||||
vector<Factor*> factorList_;
|
Factors factorList_;
|
||||||
VarIds elimOrder_;
|
VarIds elimOrder_;
|
||||||
|
unsigned largestFactorSize_;
|
||||||
|
unsigned totalFactorSize_;
|
||||||
unordered_map<VarId, vector<unsigned>> varFactors_;
|
unordered_map<VarId, vector<unsigned>> varFactors_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user