revamp debugging plataform

This commit is contained in:
Tiago Gomes 2012-04-29 20:07:09 +01:00
parent d86e2c8386
commit 56475cacbc
12 changed files with 184 additions and 94 deletions

View File

@ -153,9 +153,8 @@ BpSolver::runSolver (void)
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
if (Globals::verbosity > 1) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
// cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
@ -178,12 +177,8 @@ BpSolver::runSolver (void)
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
if (Constants::DEBUG >= 2) {
cout << endl;
if (Globals::verbosity > 0) {
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
@ -191,6 +186,7 @@ BpSolver::runSolver (void)
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
cout << endl;
}
unsigned size = fg_->varNodes().size();
if (Constants::COLLECT_STATS) {
@ -232,7 +228,7 @@ BpSolver::maxResidualSchedule (void)
}
for (unsigned c = 0; c < links_.size(); c++) {
if (Constants::DEBUG >= 2) {
if (Globals::verbosity > 1) {
cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
@ -266,7 +262,7 @@ BpSolver::maxResidualSchedule (void)
}
}
}
if (Constants::DEBUG >= 2) {
if (Globals::verbosity > 1) {
Util::printDashedLine();
}
}
@ -291,13 +287,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
if (Globals::logDomain) {
for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) {
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " ;
}
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
} else {
@ -309,13 +305,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
} else {
for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) {
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " ;
}
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
} else {
@ -328,18 +324,18 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
result.multiply (src->factor());
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExcept (dst->varId());
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl;
}
link->getNextMessage() = result.params();
LogAware::normalize (link->getNextMessage());
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << link->getNextMessage() << endl;
}
@ -359,7 +355,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
} else {
msg.resize (src->range(), LogAware::one());
}
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << msg;
}
const SpLinkSet& links = ninf (src)->getLinks();
@ -367,7 +363,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
Util::add (msg, links[i]->getMessage());
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " x " << links[i]->getMessage();
}
}
@ -376,13 +372,13 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
Util::multiply (msg, links[i]->getMessage());
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " x " << links[i]->getMessage();
}
}
}
}
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg;
}
return msg;
@ -472,7 +468,16 @@ BpSolver::converged (void)
if (links_.size() == 0) {
return true;
}
if (nIters_ <= 1) {
if (nIters_ == 0) {
return false;
}
if (Globals::verbosity > 2) {
cout << endl;
}
if (nIters_ == 1) {
if (Globals::verbosity > 1) {
cout << "no residuals" << endl << endl;
}
return false;
}
bool converged = true;
@ -486,7 +491,7 @@ BpSolver::converged (void)
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
@ -494,7 +499,7 @@ BpSolver::converged (void)
if (Constants::DEBUG == 0) break;
}
}
if (Constants::DEBUG >= 2) {
if (Globals::verbosity > 1) {
cout << endl;
}
}

View File

@ -128,7 +128,7 @@ class BpSolver : public Solver
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
@ -140,7 +140,7 @@ class BpSolver : public Solver
void calculateMessage (SpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
@ -152,7 +152,7 @@ class BpSolver : public Solver
void updateMessage (SpLink* link)
{
link->updateMessage();
if (Constants::DEBUG >= 3) {
if (Globals::verbosity > 2) {
cout << "updating " << link->toString() << endl;
}
}

View File

@ -24,14 +24,6 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
Statistics::updateCompressingStatistics (nrGroundVars,
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
}
if (Constants::DEBUG >= 5) {
cout << "uncompressed factor graph:" << endl;
cout << " " << fg.nrVarNodes() << " variables " << endl;
cout << " " << fg.nrFacNodes() << " factors " << endl;
cout << "compressed factor graph:" << endl;
cout << " " << fg_->nrVarNodes() << " variables " << endl;
cout << " " << fg_->nrFacNodes() << " factors " << endl;
}
}
@ -123,15 +115,25 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
void
CbpSolver::createLinks (void)
{
{
if (Globals::verbosity > 0) {
cout << "original factor graph has " ;
cout << fg.nrVarNodes() << " variables and " ;
cout << fg.nrFacNodes() << " factors " << endl;
cout << "compressed factor graph has " ;
cout << fg_->nrVarNodes() << " variables and " ;
cout << fg_->nrFacNodes() << " factors " << endl;
cout << endl;
}
const FacClusters& fcs = cfg_->facClusters();
for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusters& vcs = fcs[i]->varClusters();
for (unsigned j = 0; j < vcs.size(); j++) {
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
if (Constants::DEBUG >= 5) {
cout << "creating edge " ;
cout << fcs[i]->representative()->getLabel() << " -> " ;
if (Globals::verbosity > 1) {
cout << "creating link " ;
cout << fcs[i]->representative()->getLabel();
cout << " -- " ;
cout << vcs[j]->representative()->label();
cout << " idx=" << j << ", count=" << count << endl;
}
@ -139,6 +141,9 @@ CbpSolver::createLinks (void)
fcs[i]->representative(), vcs[j]->representative(), j, count));
}
}
if (Globals::verbosity > 1) {
cout << endl;
}
}
@ -151,7 +156,7 @@ CbpSolver::maxResidualSchedule (void)
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
if (Globals::verbosity >= 1) {
cout << "calculating " << links_[i]->toString() << endl;
}
}
@ -159,7 +164,7 @@ CbpSolver::maxResidualSchedule (void)
}
for (unsigned c = 0; c < links_.size(); c++) {
if (Constants::DEBUG >= 2) {
if (Globals::verbosity > 1) {
cout << endl << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
@ -170,7 +175,7 @@ CbpSolver::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it;
if (Constants::DEBUG >= 2) {
if (Globals::verbosity >= 1) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->getResidual() < BpOptions::accuracy) {
@ -187,7 +192,7 @@ CbpSolver::maxResidualSchedule (void)
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[j]->toString() << endl;
}
calculateMessage (links[j]);
@ -202,7 +207,7 @@ CbpSolver::maxResidualSchedule (void)
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getVariable() != link->getVariable()) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
if (Globals::verbosity > 1) {
cout << " calculating " << links[i]->toString() << endl;
}
calculateMessage (links[i]);
@ -235,13 +240,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
for (int i = links.size() - 1; i >= 0; i--) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " ;
}
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
} else {
@ -254,13 +259,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
for (int i = links.size() - 1; i >= 0; i--) {
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " ;
}
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
} else {
@ -282,18 +287,18 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
result[i] *= src->factor()[i];
}
}
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExceptIndex (link->index());
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl;
}
link->getNextMessage() = result.params();
LogAware::normalize (link->getNextMessage());
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << link->getNextMessage() << endl;
}
@ -311,14 +316,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()];
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
msg[src->getEvidence()] = value;
cout << msg << "^" << link->nrEdges() << "-1" ;
}
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
} else {
msg = link->getMessage();
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << msg << "^" << link->nrEdges() << "-1" ;
}
LogAware::pow (msg, link->nrEdges() - 1);
@ -337,13 +342,13 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
Util::multiply (msg, cl->poweredMessage());
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
}
}
}
}
if (Constants::DEBUG >= 5) {
if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg;
}
return msg;

View File

@ -151,14 +151,12 @@ Factor::sumOutIndex (unsigned idx)
void
Factor::sumOutAllExceptIndex (unsigned idx)
{
int i = (int)idx;
while (args_.size() > i + 1) {
while (args_.size() > idx + 1) {
sumOutLastVariable();
}
while (i > 0) {
for (unsigned i = 0; i < idx; i++) {
sumOutFirstVariable();
i -- ;
}
}
}

View File

@ -542,6 +542,18 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
void
FoveSolver::printSolverFlags (void) const
{
stringstream ss;
ss << "fove [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}
void
FoveSolver::absorveEvidence (
ParfactorList& pfList,
@ -568,7 +580,7 @@ FoveSolver::absorveEvidence (
}
pfList.add (newPfs);
}
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
@ -603,20 +615,26 @@ FoveSolver::countNormalize (
void
FoveSolver::runSolver (const Grounds& query)
{
largestCost_ = std::log (0);
shatterAgainstQuery (query);
runWeakBayesBall (query);
while (true) {
if (Constants::DEBUG >= 2) {
if (Globals::verbosity > 2) {
Util::printDashedLine();
pfList_.print();
LiftedOperator::printValidOps (pfList_, query);
if (Globals::verbosity > 3) {
LiftedOperator::printValidOps (pfList_, query);
}
}
LiftedOperator* op = getBestOperation (query);
if (op == 0) {
break;
}
if (Constants::DEBUG >= 2) {
cout << "best operation: " << op->toString() << endl;
if (Globals::verbosity > 1) {
cout << "best operation: " << op->toString();
if (Globals::verbosity > 2) {
cout << endl;
}
}
op->apply();
delete op;
@ -630,6 +648,10 @@ FoveSolver::runSolver (const Grounds& query)
++ pfIter;
}
}
if (Globals::verbosity > 0) {
cout << "largest cost = " << std::exp (largestCost_) << endl;
cout << endl;
}
(*pfList_.begin())->reorderAccordingGrounds (query);
}
@ -649,6 +671,9 @@ FoveSolver::getBestOperation (const Grounds& query)
bestCost = cost;
}
}
if (bestCost > largestCost_) {
largestCost_ = bestCost;
}
for (unsigned i = 0; i < validOps.size(); i++) {
if (validOps[i] != bestOp) {
delete validOps[i];
@ -699,18 +724,21 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
}
ParfactorList::iterator it = pfList_.begin();
bool foundNotRequired = false;
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false) {
if (Globals::verbosity > 2) {
if (foundNotRequired == false) {
Util::printHeader ("PARFACTORS TO DISCARD");
foundNotRequired = true;
}
(*it)->print();
}
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
if (Constants::DEBUG >= 2) {
Util::printHeader ("REQUIRED PARFACTORS");
pfList_.print();
}
}
@ -750,8 +778,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
}
pfList_.add (newPfs);
}
if (Constants::DEBUG >= 2) {
cout << endl;
if (Globals::verbosity > 2) {
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (unsigned i = 0; i < query.size(); i++) {

View File

@ -116,6 +116,8 @@ class FoveSolver
Params getJointDistributionOf (const Grounds&);
void printSolverFlags (void) const;
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
@ -133,6 +135,8 @@ class FoveSolver
static Parfactors absorve (ObservedFormula&, Parfactor*);
ParfactorList pfList_;
double largestCost_;
};
#endif // HORUS_FOVESOLVER_H

View File

@ -39,6 +39,8 @@ namespace Globals {
extern bool logDomain;
extern unsigned verbosity;
extern InfAlgorithms infAlgorithm;
};
@ -47,12 +49,14 @@ extern InfAlgorithms infAlgorithm;
namespace Constants {
// level of debug information
const unsigned DEBUG = 0;
const unsigned DEBUG = 3;
const bool SHOW_BP_CALCS = false;
const int NO_EVIDENCE = -1;
// number of digits to show when printing a parameter
const unsigned PRECISION = 5;
const unsigned PRECISION = 6;
const bool COLLECT_STATS = false;

View File

@ -174,8 +174,10 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
default:
assert (false);
}
solver->printSolverFlags();
cout << endl;
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
if (queryIds.size() == 0) {
solver->printAllPosterioris();
} else {

View File

@ -64,7 +64,7 @@ int createLiftedNetwork (void)
}
// LiftedUtils::printSymbolDictionary();
if (Constants::DEBUG > 2) {
if (Globals::verbosity > 2) {
Util::printHeader ("INITIAL PARFACTORS");
for (unsigned i = 0; i < parfactors.size(); i++) {
parfactors[i]->print();
@ -73,7 +73,7 @@ int createLiftedNetwork (void)
ParfactorList* pfList = new ParfactorList (parfactors);
if (Constants::DEBUG >= 2) {
if (Globals::verbosity > 2) {
Util::printHeader ("SHATTERED PARFACTORS");
pfList->print();
}
@ -300,6 +300,10 @@ runLiftedSolver (void)
jointList = YAP_TailOfTerm (jointList);
}
FoveSolver solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
if (queryVars.size() == 1) {
results.push_back (solver.getPosterioriOf (queryVars[0]));
} else {
@ -376,7 +380,9 @@ void runVeSolver (
}
// VarElimSolver solver (*mfg);
VarElimSolver solver (*fg); //FIXME
// solver.printSolverFlags();
if (Globals::verbosity > 0 && i == 0) {
solver.printSolverFlags();
}
results.push_back (solver.solveQuery (tasks[i]));
if (fg->isFromBayesNetwork()) {
delete mfg;
@ -410,7 +416,9 @@ void runBpSolver (
cerr << "error: unknow solver" << endl;
abort();
}
// solver->printSolverFlags();
if (Globals::verbosity > 0) {
solver->printSolverFlags();
}
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
@ -508,7 +516,11 @@ setHorusFlag (void)
{
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
string value;
if (key == "accuracy") {
if (key == "verbosity") {
stringstream ss;
ss << (int) YAP_IntOfTerm (YAP_ARG2);
ss >> value;
} else if (key == "accuracy") {
stringstream ss;
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
ss >> value;

View File

@ -11,10 +11,9 @@
namespace Globals {
bool logDomain = false;
//InfAlgs infAlgorithm = InfAlgorithms::VE;
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
unsigned verbosity = 0;
InfAlgorithms infAlgorithm = InfAlgorithms::VE;
};
@ -227,7 +226,11 @@ bool
setHorusFlag (string key, string value)
{
bool returnVal = true;
if (key == "inf_alg") {
if (key == "verbosity") {
stringstream ss;
ss << value;
ss >> Globals::verbosity;
} else if (key == "inf_alg") {
if ( value == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (value == "bp") {

View File

@ -16,6 +16,14 @@ VarElimSolver::~VarElimSolver (void)
Params
VarElimSolver::solveQuery (VarIds queryVids)
{
if (Globals::verbosity > 1) {
cout << "Solving query on " ;
for (unsigned i = 0; i < queryVids.size(); i++) {
if (i != 0) cout << ", " ;
cout << fg.getVarNode (queryVids[i])->label();
}
cout << endl;
}
factorList_.clear();
varFactors_.clear();
elimOrder_.clear();
@ -77,9 +85,19 @@ VarElimSolver::createFactorList (void)
void
VarElimSolver::absorveEvidence (void)
{
if (Globals::verbosity > 2) {
Util::printDashedLine();
cout << "(initial factor list)" << endl;
printActiveFactors();
}
const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) {
if (Globals::verbosity > 1) {
cout << "-> aborving evidence on ";
cout << varNodes[i]->label() << " = " ;
cout << varNodes[i]->getEvidence() << endl;
}
const vector<unsigned>& idxs =
varFactors_.find (varNodes[i]->varId())->second;
for (unsigned j = 0; j < idxs.size(); j++) {
@ -108,12 +126,16 @@ VarElimSolver::findEliminationOrder (const VarIds& vids)
void
VarElimSolver::processFactorList (const VarIds& vids)
{
totalFactorSize_ = 0;
largestFactorSize_ = 0;
for (unsigned i = 0; i < elimOrder_.size(); i++) {
if (Constants::DEBUG >= 3) {
printActiveFactors();
if (Globals::verbosity >= 2) {
if (Globals::verbosity >= 3) {
Util::printDashedLine();
printActiveFactors();
}
cout << "-> summing out " ;
VarNode* vn = fg.getVarNode (elimOrder_[i]);
cout << vn->label() << endl;
cout << fg.getVarNode (elimOrder_[i])->label() << endl;
}
eliminate (elimOrder_[i]);
}
@ -137,6 +159,11 @@ VarElimSolver::processFactorList (const VarIds& vids)
finalFactor->reorderArguments (unobservedVids);
finalFactor->normalize();
factorList_.push_back (finalFactor);
if (Globals::verbosity > 0) {
cout << "total factor size: " << totalFactorSize_ << endl;
cout << "largest factor size: " << largestFactorSize_ << endl;
cout << endl;
}
}
@ -158,6 +185,10 @@ VarElimSolver::eliminate (VarId elimVar)
factorList_[idx] = 0;
}
}
totalFactorSize_ += result->size();
if (result->size() > largestFactorSize_) {
largestFactorSize_ = result->size();
}
if (result != 0 && result->nrArguments() != 1) {
result->sumOut (elimVar);
factorList_.push_back (result);
@ -175,14 +206,11 @@ VarElimSolver::eliminate (VarId elimVar)
void
VarElimSolver::printActiveFactors (void)
{
cout << endl;
Util::printDashedLine();
for (unsigned i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) {
cout << factorList_[i]->getLabel() << " " ;
cout << factorList_[i]->params() << endl;
}
}
cout << endl;
}

View File

@ -35,8 +35,10 @@ class VarElimSolver : public Solver
void printActiveFactors (void);
vector<Factor*> factorList_;
VarIds elimOrder_;
Factors factorList_;
VarIds elimOrder_;
unsigned largestFactorSize_;
unsigned totalFactorSize_;
unordered_map<VarId, vector<unsigned>> varFactors_;
};