revamp debugging plataform

This commit is contained in:
Tiago Gomes 2012-04-29 20:07:09 +01:00
parent d86e2c8386
commit 56475cacbc
12 changed files with 184 additions and 94 deletions

View File

@ -153,9 +153,8 @@ BpSolver::runSolver (void)
nIters_ = 0; nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) { while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++; nIters_ ++;
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 1) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_)); Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
// cout << endl;
} }
switch (BpOptions::schedule) { switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM: case BpOptions::Schedule::SEQ_RANDOM:
@ -178,12 +177,8 @@ BpSolver::runSolver (void)
maxResidualSchedule(); maxResidualSchedule();
break; break;
} }
if (Constants::DEBUG >= 2) {
cout << endl;
}
} }
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 0) {
cout << endl;
if (nIters_ < BpOptions::maxIter) { if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ; cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl; cout << nIters_ << " iterations" << endl;
@ -191,6 +186,7 @@ BpSolver::runSolver (void)
cout << "The maximum number of iterations was hit, terminating..." ; cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl; cout << endl;
} }
cout << endl;
} }
unsigned size = fg_->varNodes().size(); unsigned size = fg_->varNodes().size();
if (Constants::COLLECT_STATS) { if (Constants::COLLECT_STATS) {
@ -232,7 +228,7 @@ BpSolver::maxResidualSchedule (void)
} }
for (unsigned c = 0; c < links_.size(); c++) { for (unsigned c = 0; c < links_.size(); c++) {
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 1) {
cout << "current residuals:" << endl; cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin(); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) { it != sortedOrder_.end(); it ++) {
@ -266,7 +262,7 @@ BpSolver::maxResidualSchedule (void)
} }
} }
} }
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 1) {
Util::printDashedLine(); Util::printDashedLine();
} }
} }
@ -291,13 +287,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
if (Globals::logDomain) { if (Globals::logDomain) {
for (int i = links.size() - 1; i >= 0; i--) { for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) { if (links[i]->getVariable() != dst) {
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label(); cout << " message from " << links[i]->getVariable()->label();
cout << ": " ; cout << ": " ;
} }
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions); Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range(); repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << endl; cout << endl;
} }
} else { } else {
@ -309,13 +305,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
} else { } else {
for (int i = links.size() - 1; i >= 0; i--) { for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) { if (links[i]->getVariable() != dst) {
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label(); cout << " message from " << links[i]->getVariable()->label();
cout << ": " ; cout << ": " ;
} }
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions); Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range(); repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << endl; cout << endl;
} }
} else { } else {
@ -328,18 +324,18 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
Factor result (src->factor().arguments(), Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct); src->factor().ranges(), msgProduct);
result.multiply (src->factor()); result.multiply (src->factor());
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " message product: " << msgProduct << endl; cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl; cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl; cout << " factor product: " << result.params() << endl;
} }
result.sumOutAllExcept (dst->varId()); result.sumOutAllExcept (dst->varId());
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl; cout << " marginalized: " << result.params() << endl;
} }
link->getNextMessage() = result.params(); link->getNextMessage() = result.params();
LogAware::normalize (link->getNextMessage()); LogAware::normalize (link->getNextMessage());
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->getMessage() << endl; cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << link->getNextMessage() << endl; cout << " next msg: " << link->getNextMessage() << endl;
} }
@ -359,7 +355,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
} else { } else {
msg.resize (src->range(), LogAware::one()); msg.resize (src->range(), LogAware::one());
} }
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << msg; cout << msg;
} }
const SpLinkSet& links = ninf (src)->getLinks(); const SpLinkSet& links = ninf (src)->getLinks();
@ -367,7 +363,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { if (links[i]->getFactor() != dst) {
Util::add (msg, links[i]->getMessage()); Util::add (msg, links[i]->getMessage());
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << links[i]->getMessage(); cout << " x " << links[i]->getMessage();
} }
} }
@ -376,13 +372,13 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) { if (links[i]->getFactor() != dst) {
Util::multiply (msg, links[i]->getMessage()); Util::multiply (msg, links[i]->getMessage());
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << links[i]->getMessage(); cout << " x " << links[i]->getMessage();
} }
} }
} }
} }
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg; cout << " = " << msg;
} }
return msg; return msg;
@ -472,7 +468,16 @@ BpSolver::converged (void)
if (links_.size() == 0) { if (links_.size() == 0) {
return true; return true;
} }
if (nIters_ <= 1) { if (nIters_ == 0) {
return false;
}
if (Globals::verbosity > 2) {
cout << endl;
}
if (nIters_ == 1) {
if (Globals::verbosity > 1) {
cout << "no residuals" << endl << endl;
}
return false; return false;
} }
bool converged = true; bool converged = true;
@ -486,7 +491,7 @@ BpSolver::converged (void)
} else { } else {
for (unsigned i = 0; i < links_.size(); i++) { for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual(); double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl; cout << links_[i]->toString() + " residual = " << residual << endl;
} }
if (residual > BpOptions::accuracy) { if (residual > BpOptions::accuracy) {
@ -494,7 +499,7 @@ BpSolver::converged (void)
if (Constants::DEBUG == 0) break; if (Constants::DEBUG == 0) break;
} }
} }
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 1) {
cout << endl; cout << endl;
} }
} }

View File

@ -128,7 +128,7 @@ class BpSolver : public Solver
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true) void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
{ {
if (Constants::DEBUG >= 3) { if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl; cout << "calculating & updating " << link->toString() << endl;
} }
calculateFactor2VariableMsg (link); calculateFactor2VariableMsg (link);
@ -140,7 +140,7 @@ class BpSolver : public Solver
void calculateMessage (SpLink* link, bool calcResidual = true) void calculateMessage (SpLink* link, bool calcResidual = true)
{ {
if (Constants::DEBUG >= 3) { if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl; cout << "calculating " << link->toString() << endl;
} }
calculateFactor2VariableMsg (link); calculateFactor2VariableMsg (link);
@ -152,7 +152,7 @@ class BpSolver : public Solver
void updateMessage (SpLink* link) void updateMessage (SpLink* link)
{ {
link->updateMessage(); link->updateMessage();
if (Constants::DEBUG >= 3) { if (Globals::verbosity > 2) {
cout << "updating " << link->toString() << endl; cout << "updating " << link->toString() << endl;
} }
} }

View File

@ -24,14 +24,6 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
Statistics::updateCompressingStatistics (nrGroundVars, Statistics::updateCompressingStatistics (nrGroundVars,
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless); nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
} }
if (Constants::DEBUG >= 5) {
cout << "uncompressed factor graph:" << endl;
cout << " " << fg.nrVarNodes() << " variables " << endl;
cout << " " << fg.nrFacNodes() << " factors " << endl;
cout << "compressed factor graph:" << endl;
cout << " " << fg_->nrVarNodes() << " variables " << endl;
cout << " " << fg_->nrFacNodes() << " factors " << endl;
}
} }
@ -124,14 +116,24 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
void void
CbpSolver::createLinks (void) CbpSolver::createLinks (void)
{ {
if (Globals::verbosity > 0) {
cout << "original factor graph has " ;
cout << fg.nrVarNodes() << " variables and " ;
cout << fg.nrFacNodes() << " factors " << endl;
cout << "compressed factor graph has " ;
cout << fg_->nrVarNodes() << " variables and " ;
cout << fg_->nrFacNodes() << " factors " << endl;
cout << endl;
}
const FacClusters& fcs = cfg_->facClusters(); const FacClusters& fcs = cfg_->facClusters();
for (unsigned i = 0; i < fcs.size(); i++) { for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusters& vcs = fcs[i]->varClusters(); const VarClusters& vcs = fcs[i]->varClusters();
for (unsigned j = 0; j < vcs.size(); j++) { for (unsigned j = 0; j < vcs.size(); j++) {
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j); unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
if (Constants::DEBUG >= 5) { if (Globals::verbosity > 1) {
cout << "creating edge " ; cout << "creating link " ;
cout << fcs[i]->representative()->getLabel() << " -> " ; cout << fcs[i]->representative()->getLabel();
cout << " -- " ;
cout << vcs[j]->representative()->label(); cout << vcs[j]->representative()->label();
cout << " idx=" << j << ", count=" << count << endl; cout << " idx=" << j << ", count=" << count << endl;
} }
@ -139,6 +141,9 @@ CbpSolver::createLinks (void)
fcs[i]->representative(), vcs[j]->representative(), j, count)); fcs[i]->representative(), vcs[j]->representative(), j, count));
} }
} }
if (Globals::verbosity > 1) {
cout << endl;
}
} }
@ -151,7 +156,7 @@ CbpSolver::maxResidualSchedule (void)
calculateMessage (links_[i]); calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]); SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it)); linkMap_.insert (make_pair (links_[i], it));
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) { if (Globals::verbosity >= 1) {
cout << "calculating " << links_[i]->toString() << endl; cout << "calculating " << links_[i]->toString() << endl;
} }
} }
@ -159,7 +164,7 @@ CbpSolver::maxResidualSchedule (void)
} }
for (unsigned c = 0; c < links_.size(); c++) { for (unsigned c = 0; c < links_.size(); c++) {
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 1) {
cout << endl << "current residuals:" << endl; cout << endl << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin(); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) { it != sortedOrder_.end(); it ++) {
@ -170,7 +175,7 @@ CbpSolver::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin(); SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it; SpLink* link = *it;
if (Constants::DEBUG >= 2) { if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl; cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
} }
if (link->getResidual() < BpOptions::accuracy) { if (link->getResidual() < BpOptions::accuracy) {
@ -187,7 +192,7 @@ CbpSolver::maxResidualSchedule (void)
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (unsigned j = 0; j < links.size(); j++) { for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) { if (links[j]->getVariable() != link->getVariable()) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) { if (Globals::verbosity > 1) {
cout << " calculating " << links[j]->toString() << endl; cout << " calculating " << links[j]->toString() << endl;
} }
calculateMessage (links[j]); calculateMessage (links[j]);
@ -202,7 +207,7 @@ CbpSolver::maxResidualSchedule (void)
const SpLinkSet& links = ninf(link->getFactor())->getLinks(); const SpLinkSet& links = ninf(link->getFactor())->getLinks();
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getVariable() != link->getVariable()) { if (links[i]->getVariable() != link->getVariable()) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) { if (Globals::verbosity > 1) {
cout << " calculating " << links[i]->toString() << endl; cout << " calculating " << links[i]->toString() << endl;
} }
calculateMessage (links[i]); calculateMessage (links[i]);
@ -235,13 +240,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
for (int i = links.size() - 1; i >= 0; i--) { for (int i = links.size() - 1; i >= 0; i--) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]); const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) { if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label(); cout << " message from " << links[i]->getVariable()->label();
cout << ": " ; cout << ": " ;
} }
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions); Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range(); repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << endl; cout << endl;
} }
} else { } else {
@ -254,13 +259,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
for (int i = links.size() - 1; i >= 0; i--) { for (int i = links.size() - 1; i >= 0; i--) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]); const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) { if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label(); cout << " message from " << links[i]->getVariable()->label();
cout << ": " ; cout << ": " ;
} }
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions); Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range(); repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << endl; cout << endl;
} }
} else { } else {
@ -282,18 +287,18 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
result[i] *= src->factor()[i]; result[i] *= src->factor()[i];
} }
} }
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " message product: " << msgProduct << endl; cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl; cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl; cout << " factor product: " << result.params() << endl;
} }
result.sumOutAllExceptIndex (link->index()); result.sumOutAllExceptIndex (link->index());
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl; cout << " marginalized: " << result.params() << endl;
} }
link->getNextMessage() = result.params(); link->getNextMessage() = result.params();
LogAware::normalize (link->getNextMessage()); LogAware::normalize (link->getNextMessage());
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->getMessage() << endl; cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << link->getNextMessage() << endl; cout << " next msg: " << link->getNextMessage() << endl;
} }
@ -311,14 +316,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()]; double value = link->getMessage()[src->getEvidence()];
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
msg[src->getEvidence()] = value; msg[src->getEvidence()] = value;
cout << msg << "^" << link->nrEdges() << "-1" ; cout << msg << "^" << link->nrEdges() << "-1" ;
} }
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1); msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
} else { } else {
msg = link->getMessage(); msg = link->getMessage();
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << msg << "^" << link->nrEdges() << "-1" ; cout << msg << "^" << link->nrEdges() << "-1" ;
} }
LogAware::pow (msg, link->nrEdges() - 1); LogAware::pow (msg, link->nrEdges() - 1);
@ -337,13 +342,13 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]); CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) { if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
Util::multiply (msg, cl->poweredMessage()); Util::multiply (msg, cl->poweredMessage());
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges(); cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
} }
} }
} }
} }
if (Constants::DEBUG >= 5) { if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg; cout << " = " << msg;
} }
return msg; return msg;

View File

@ -151,13 +151,11 @@ Factor::sumOutIndex (unsigned idx)
void void
Factor::sumOutAllExceptIndex (unsigned idx) Factor::sumOutAllExceptIndex (unsigned idx)
{ {
int i = (int)idx; while (args_.size() > idx + 1) {
while (args_.size() > i + 1) {
sumOutLastVariable(); sumOutLastVariable();
} }
while (i > 0) { for (unsigned i = 0; i < idx; i++) {
sumOutFirstVariable(); sumOutFirstVariable();
i -- ;
} }
} }

View File

@ -542,6 +542,18 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
void
FoveSolver::printSolverFlags (void) const
{
stringstream ss;
ss << "fove [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}
void void
FoveSolver::absorveEvidence ( FoveSolver::absorveEvidence (
ParfactorList& pfList, ParfactorList& pfList,
@ -568,7 +580,7 @@ FoveSolver::absorveEvidence (
} }
pfList.add (newPfs); pfList.add (newPfs);
} }
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) { if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine(); Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl; cout << "AFTER EVIDENCE ABSORVED" << endl;
for (unsigned i = 0; i < obsFormulas.size(); i++) { for (unsigned i = 0; i < obsFormulas.size(); i++) {
@ -603,20 +615,26 @@ FoveSolver::countNormalize (
void void
FoveSolver::runSolver (const Grounds& query) FoveSolver::runSolver (const Grounds& query)
{ {
largestCost_ = std::log (0);
shatterAgainstQuery (query); shatterAgainstQuery (query);
runWeakBayesBall (query); runWeakBayesBall (query);
while (true) { while (true) {
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 2) {
Util::printDashedLine(); Util::printDashedLine();
pfList_.print(); pfList_.print();
LiftedOperator::printValidOps (pfList_, query); if (Globals::verbosity > 3) {
LiftedOperator::printValidOps (pfList_, query);
}
} }
LiftedOperator* op = getBestOperation (query); LiftedOperator* op = getBestOperation (query);
if (op == 0) { if (op == 0) {
break; break;
} }
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 1) {
cout << "best operation: " << op->toString() << endl; cout << "best operation: " << op->toString();
if (Globals::verbosity > 2) {
cout << endl;
}
} }
op->apply(); op->apply();
delete op; delete op;
@ -630,6 +648,10 @@ FoveSolver::runSolver (const Grounds& query)
++ pfIter; ++ pfIter;
} }
} }
if (Globals::verbosity > 0) {
cout << "largest cost = " << std::exp (largestCost_) << endl;
cout << endl;
}
(*pfList_.begin())->reorderAccordingGrounds (query); (*pfList_.begin())->reorderAccordingGrounds (query);
} }
@ -649,6 +671,9 @@ FoveSolver::getBestOperation (const Grounds& query)
bestCost = cost; bestCost = cost;
} }
} }
if (bestCost > largestCost_) {
largestCost_ = bestCost;
}
for (unsigned i = 0; i < validOps.size(); i++) { for (unsigned i = 0; i < validOps.size(); i++) {
if (validOps[i] != bestOp) { if (validOps[i] != bestOp) {
delete validOps[i]; delete validOps[i];
@ -699,18 +724,21 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
} }
ParfactorList::iterator it = pfList_.begin(); ParfactorList::iterator it = pfList_.begin();
bool foundNotRequired = false;
while (it != pfList_.end()) { while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false) { if (Util::contains (requiredPfs, *it) == false) {
if (Globals::verbosity > 2) {
if (foundNotRequired == false) {
Util::printHeader ("PARFACTORS TO DISCARD");
foundNotRequired = true;
}
(*it)->print();
}
it = pfList_.removeAndDelete (it); it = pfList_.removeAndDelete (it);
} else { } else {
++ it; ++ it;
} }
} }
if (Constants::DEBUG >= 2) {
Util::printHeader ("REQUIRED PARFACTORS");
pfList_.print();
}
} }
@ -750,8 +778,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
} }
pfList_.add (newPfs); pfList_.add (newPfs);
} }
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 2) {
cout << endl;
Util::printAsteriskLine(); Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl; cout << "SHATTERED AGAINST THE QUERY" << endl;
for (unsigned i = 0; i < query.size(); i++) { for (unsigned i = 0; i < query.size(); i++) {

View File

@ -116,6 +116,8 @@ class FoveSolver
Params getJointDistributionOf (const Grounds&); Params getJointDistributionOf (const Grounds&);
void printSolverFlags (void) const;
static void absorveEvidence ( static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas); ParfactorList& pfList, ObservedFormulas& obsFormulas);
@ -133,6 +135,8 @@ class FoveSolver
static Parfactors absorve (ObservedFormula&, Parfactor*); static Parfactors absorve (ObservedFormula&, Parfactor*);
ParfactorList pfList_; ParfactorList pfList_;
double largestCost_;
}; };
#endif // HORUS_FOVESOLVER_H #endif // HORUS_FOVESOLVER_H

View File

@ -39,6 +39,8 @@ namespace Globals {
extern bool logDomain; extern bool logDomain;
extern unsigned verbosity;
extern InfAlgorithms infAlgorithm; extern InfAlgorithms infAlgorithm;
}; };
@ -47,12 +49,14 @@ extern InfAlgorithms infAlgorithm;
namespace Constants { namespace Constants {
// level of debug information // level of debug information
const unsigned DEBUG = 0; const unsigned DEBUG = 3;
const bool SHOW_BP_CALCS = false;
const int NO_EVIDENCE = -1; const int NO_EVIDENCE = -1;
// number of digits to show when printing a parameter // number of digits to show when printing a parameter
const unsigned PRECISION = 5; const unsigned PRECISION = 6;
const bool COLLECT_STATS = false; const bool COLLECT_STATS = false;

View File

@ -174,8 +174,10 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
default: default:
assert (false); assert (false);
} }
solver->printSolverFlags(); if (Globals::verbosity > 0) {
cout << endl; solver->printSolverFlags();
cout << endl;
}
if (queryIds.size() == 0) { if (queryIds.size() == 0) {
solver->printAllPosterioris(); solver->printAllPosterioris();
} else { } else {

View File

@ -64,7 +64,7 @@ int createLiftedNetwork (void)
} }
// LiftedUtils::printSymbolDictionary(); // LiftedUtils::printSymbolDictionary();
if (Constants::DEBUG > 2) { if (Globals::verbosity > 2) {
Util::printHeader ("INITIAL PARFACTORS"); Util::printHeader ("INITIAL PARFACTORS");
for (unsigned i = 0; i < parfactors.size(); i++) { for (unsigned i = 0; i < parfactors.size(); i++) {
parfactors[i]->print(); parfactors[i]->print();
@ -73,7 +73,7 @@ int createLiftedNetwork (void)
ParfactorList* pfList = new ParfactorList (parfactors); ParfactorList* pfList = new ParfactorList (parfactors);
if (Constants::DEBUG >= 2) { if (Globals::verbosity > 2) {
Util::printHeader ("SHATTERED PARFACTORS"); Util::printHeader ("SHATTERED PARFACTORS");
pfList->print(); pfList->print();
} }
@ -300,6 +300,10 @@ runLiftedSolver (void)
jointList = YAP_TailOfTerm (jointList); jointList = YAP_TailOfTerm (jointList);
} }
FoveSolver solver (pfListCopy); FoveSolver solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
if (queryVars.size() == 1) { if (queryVars.size() == 1) {
results.push_back (solver.getPosterioriOf (queryVars[0])); results.push_back (solver.getPosterioriOf (queryVars[0]));
} else { } else {
@ -376,7 +380,9 @@ void runVeSolver (
} }
// VarElimSolver solver (*mfg); // VarElimSolver solver (*mfg);
VarElimSolver solver (*fg); //FIXME VarElimSolver solver (*fg); //FIXME
// solver.printSolverFlags(); if (Globals::verbosity > 0 && i == 0) {
solver.printSolverFlags();
}
results.push_back (solver.solveQuery (tasks[i])); results.push_back (solver.solveQuery (tasks[i]));
if (fg->isFromBayesNetwork()) { if (fg->isFromBayesNetwork()) {
delete mfg; delete mfg;
@ -410,7 +416,9 @@ void runBpSolver (
cerr << "error: unknow solver" << endl; cerr << "error: unknow solver" << endl;
abort(); abort();
} }
// solver->printSolverFlags(); if (Globals::verbosity > 0) {
solver->printSolverFlags();
}
results.reserve (tasks.size()); results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) { for (unsigned i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i])); results.push_back (solver->solveQuery (tasks[i]));
@ -508,7 +516,11 @@ setHorusFlag (void)
{ {
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1))); string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
string value; string value;
if (key == "accuracy") { if (key == "verbosity") {
stringstream ss;
ss << (int) YAP_IntOfTerm (YAP_ARG2);
ss >> value;
} else if (key == "accuracy") {
stringstream ss; stringstream ss;
ss << (float) YAP_FloatOfTerm (YAP_ARG2); ss << (float) YAP_FloatOfTerm (YAP_ARG2);
ss >> value; ss >> value;

View File

@ -11,10 +11,9 @@
namespace Globals { namespace Globals {
bool logDomain = false; bool logDomain = false;
//InfAlgs infAlgorithm = InfAlgorithms::VE; unsigned verbosity = 0;
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP; InfAlgorithms infAlgorithm = InfAlgorithms::VE;
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
}; };
@ -227,7 +226,11 @@ bool
setHorusFlag (string key, string value) setHorusFlag (string key, string value)
{ {
bool returnVal = true; bool returnVal = true;
if (key == "inf_alg") { if (key == "verbosity") {
stringstream ss;
ss << value;
ss >> Globals::verbosity;
} else if (key == "inf_alg") {
if ( value == "ve") { if ( value == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE; Globals::infAlgorithm = InfAlgorithms::VE;
} else if (value == "bp") { } else if (value == "bp") {

View File

@ -16,6 +16,14 @@ VarElimSolver::~VarElimSolver (void)
Params Params
VarElimSolver::solveQuery (VarIds queryVids) VarElimSolver::solveQuery (VarIds queryVids)
{ {
if (Globals::verbosity > 1) {
cout << "Solving query on " ;
for (unsigned i = 0; i < queryVids.size(); i++) {
if (i != 0) cout << ", " ;
cout << fg.getVarNode (queryVids[i])->label();
}
cout << endl;
}
factorList_.clear(); factorList_.clear();
varFactors_.clear(); varFactors_.clear();
elimOrder_.clear(); elimOrder_.clear();
@ -77,9 +85,19 @@ VarElimSolver::createFactorList (void)
void void
VarElimSolver::absorveEvidence (void) VarElimSolver::absorveEvidence (void)
{ {
if (Globals::verbosity > 2) {
Util::printDashedLine();
cout << "(initial factor list)" << endl;
printActiveFactors();
}
const VarNodes& varNodes = fg.varNodes(); const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) { if (varNodes[i]->hasEvidence()) {
if (Globals::verbosity > 1) {
cout << "-> aborving evidence on ";
cout << varNodes[i]->label() << " = " ;
cout << varNodes[i]->getEvidence() << endl;
}
const vector<unsigned>& idxs = const vector<unsigned>& idxs =
varFactors_.find (varNodes[i]->varId())->second; varFactors_.find (varNodes[i]->varId())->second;
for (unsigned j = 0; j < idxs.size(); j++) { for (unsigned j = 0; j < idxs.size(); j++) {
@ -108,12 +126,16 @@ VarElimSolver::findEliminationOrder (const VarIds& vids)
void void
VarElimSolver::processFactorList (const VarIds& vids) VarElimSolver::processFactorList (const VarIds& vids)
{ {
totalFactorSize_ = 0;
largestFactorSize_ = 0;
for (unsigned i = 0; i < elimOrder_.size(); i++) { for (unsigned i = 0; i < elimOrder_.size(); i++) {
if (Constants::DEBUG >= 3) { if (Globals::verbosity >= 2) {
printActiveFactors(); if (Globals::verbosity >= 3) {
Util::printDashedLine();
printActiveFactors();
}
cout << "-> summing out " ; cout << "-> summing out " ;
VarNode* vn = fg.getVarNode (elimOrder_[i]); cout << fg.getVarNode (elimOrder_[i])->label() << endl;
cout << vn->label() << endl;
} }
eliminate (elimOrder_[i]); eliminate (elimOrder_[i]);
} }
@ -137,6 +159,11 @@ VarElimSolver::processFactorList (const VarIds& vids)
finalFactor->reorderArguments (unobservedVids); finalFactor->reorderArguments (unobservedVids);
finalFactor->normalize(); finalFactor->normalize();
factorList_.push_back (finalFactor); factorList_.push_back (finalFactor);
if (Globals::verbosity > 0) {
cout << "total factor size: " << totalFactorSize_ << endl;
cout << "largest factor size: " << largestFactorSize_ << endl;
cout << endl;
}
} }
@ -158,6 +185,10 @@ VarElimSolver::eliminate (VarId elimVar)
factorList_[idx] = 0; factorList_[idx] = 0;
} }
} }
totalFactorSize_ += result->size();
if (result->size() > largestFactorSize_) {
largestFactorSize_ = result->size();
}
if (result != 0 && result->nrArguments() != 1) { if (result != 0 && result->nrArguments() != 1) {
result->sumOut (elimVar); result->sumOut (elimVar);
factorList_.push_back (result); factorList_.push_back (result);
@ -175,14 +206,11 @@ VarElimSolver::eliminate (VarId elimVar)
void void
VarElimSolver::printActiveFactors (void) VarElimSolver::printActiveFactors (void)
{ {
cout << endl;
Util::printDashedLine();
for (unsigned i = 0; i < factorList_.size(); i++) { for (unsigned i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) { if (factorList_[i] != 0) {
cout << factorList_[i]->getLabel() << " " ; cout << factorList_[i]->getLabel() << " " ;
cout << factorList_[i]->params() << endl; cout << factorList_[i]->params() << endl;
} }
} }
cout << endl;
} }

View File

@ -35,8 +35,10 @@ class VarElimSolver : public Solver
void printActiveFactors (void); void printActiveFactors (void);
vector<Factor*> factorList_; Factors factorList_;
VarIds elimOrder_; VarIds elimOrder_;
unsigned largestFactorSize_;
unsigned totalFactorSize_;
unordered_map<VarId, vector<unsigned>> varFactors_; unordered_map<VarId, vector<unsigned>> varFactors_;
}; };