diff --git a/packages/CLPBN/clpbn/bp/BpSolver.cpp b/packages/CLPBN/clpbn/bp/BpSolver.cpp index 3271371e2..da1e7dd41 100644 --- a/packages/CLPBN/clpbn/bp/BpSolver.cpp +++ b/packages/CLPBN/clpbn/bp/BpSolver.cpp @@ -153,9 +153,8 @@ BpSolver::runSolver (void) nIters_ = 0; while (!converged() && nIters_ < BpOptions::maxIter) { nIters_ ++; - if (Constants::DEBUG >= 2) { + if (Globals::verbosity > 1) { Util::printHeader (string ("Iteration ") + Util::toString (nIters_)); - // cout << endl; } switch (BpOptions::schedule) { case BpOptions::Schedule::SEQ_RANDOM: @@ -178,12 +177,8 @@ BpSolver::runSolver (void) maxResidualSchedule(); break; } - if (Constants::DEBUG >= 2) { - cout << endl; - } } - if (Constants::DEBUG >= 2) { - cout << endl; + if (Globals::verbosity > 0) { if (nIters_ < BpOptions::maxIter) { cout << "Sum-Product converged in " ; cout << nIters_ << " iterations" << endl; @@ -191,6 +186,7 @@ BpSolver::runSolver (void) cout << "The maximum number of iterations was hit, terminating..." ; cout << endl; } + cout << endl; } unsigned size = fg_->varNodes().size(); if (Constants::COLLECT_STATS) { @@ -232,7 +228,7 @@ BpSolver::maxResidualSchedule (void) } for (unsigned c = 0; c < links_.size(); c++) { - if (Constants::DEBUG >= 2) { + if (Globals::verbosity > 1) { cout << "current residuals:" << endl; for (SortedOrder::iterator it = sortedOrder_.begin(); it != sortedOrder_.end(); it ++) { @@ -266,7 +262,7 @@ BpSolver::maxResidualSchedule (void) } } } - if (Constants::DEBUG >= 2) { + if (Globals::verbosity > 1) { Util::printDashedLine(); } } @@ -291,13 +287,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link) if (Globals::logDomain) { for (int i = links.size() - 1; i >= 0; i--) { if (links[i]->getVariable() != dst) { - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " message from " << links[i]->getVariable()->label(); cout << ": " ; } Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions); repetitions *= links[i]->getVariable()->range(); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << endl; } } else { @@ -309,13 +305,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link) } else { for (int i = links.size() - 1; i >= 0; i--) { if (links[i]->getVariable() != dst) { - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " message from " << links[i]->getVariable()->label(); cout << ": " ; } Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions); repetitions *= links[i]->getVariable()->range(); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << endl; } } else { @@ -328,18 +324,18 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link) Factor result (src->factor().arguments(), src->factor().ranges(), msgProduct); result.multiply (src->factor()); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " message product: " << msgProduct << endl; cout << " original factor: " << src->factor().params() << endl; cout << " factor product: " << result.params() << endl; } result.sumOutAllExcept (dst->varId()); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " marginalized: " << result.params() << endl; } link->getNextMessage() = result.params(); LogAware::normalize (link->getNextMessage()); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " curr msg: " << link->getMessage() << endl; cout << " next msg: " << link->getNextMessage() << endl; } @@ -359,7 +355,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const } else { msg.resize (src->range(), LogAware::one()); } - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << msg; } const SpLinkSet& links = ninf (src)->getLinks(); @@ -367,7 +363,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const for (unsigned i = 0; i < links.size(); i++) { if (links[i]->getFactor() != dst) { Util::add (msg, links[i]->getMessage()); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " x " << links[i]->getMessage(); } } @@ -376,13 +372,13 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const for (unsigned i = 0; i < links.size(); i++) { if (links[i]->getFactor() != dst) { Util::multiply (msg, links[i]->getMessage()); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " x " << links[i]->getMessage(); } } } } - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " = " << msg; } return msg; @@ -472,7 +468,16 @@ BpSolver::converged (void) if (links_.size() == 0) { return true; } - if (nIters_ <= 1) { + if (nIters_ == 0) { + return false; + } + if (Globals::verbosity > 2) { + cout << endl; + } + if (nIters_ == 1) { + if (Globals::verbosity > 1) { + cout << "no residuals" << endl << endl; + } return false; } bool converged = true; @@ -486,7 +491,7 @@ BpSolver::converged (void) } else { for (unsigned i = 0; i < links_.size(); i++) { double residual = links_[i]->getResidual(); - if (Constants::DEBUG >= 2) { + if (Globals::verbosity > 1) { cout << links_[i]->toString() + " residual = " << residual << endl; } if (residual > BpOptions::accuracy) { @@ -494,7 +499,7 @@ BpSolver::converged (void) if (Constants::DEBUG == 0) break; } } - if (Constants::DEBUG >= 2) { + if (Globals::verbosity > 1) { cout << endl; } } diff --git a/packages/CLPBN/clpbn/bp/BpSolver.h b/packages/CLPBN/clpbn/bp/BpSolver.h index 688b6bb15..f52370da7 100644 --- a/packages/CLPBN/clpbn/bp/BpSolver.h +++ b/packages/CLPBN/clpbn/bp/BpSolver.h @@ -128,7 +128,7 @@ class BpSolver : public Solver void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true) { - if (Constants::DEBUG >= 3) { + if (Globals::verbosity > 2) { cout << "calculating & updating " << link->toString() << endl; } calculateFactor2VariableMsg (link); @@ -140,7 +140,7 @@ class BpSolver : public Solver void calculateMessage (SpLink* link, bool calcResidual = true) { - if (Constants::DEBUG >= 3) { + if (Globals::verbosity > 2) { cout << "calculating " << link->toString() << endl; } calculateFactor2VariableMsg (link); @@ -152,7 +152,7 @@ class BpSolver : public Solver void updateMessage (SpLink* link) { link->updateMessage(); - if (Constants::DEBUG >= 3) { + if (Globals::verbosity > 2) { cout << "updating " << link->toString() << endl; } } diff --git a/packages/CLPBN/clpbn/bp/CbpSolver.cpp b/packages/CLPBN/clpbn/bp/CbpSolver.cpp index bf8bfbe11..0989dc465 100644 --- a/packages/CLPBN/clpbn/bp/CbpSolver.cpp +++ b/packages/CLPBN/clpbn/bp/CbpSolver.cpp @@ -24,14 +24,6 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg) Statistics::updateCompressingStatistics (nrGroundVars, nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless); } - if (Constants::DEBUG >= 5) { - cout << "uncompressed factor graph:" << endl; - cout << " " << fg.nrVarNodes() << " variables " << endl; - cout << " " << fg.nrFacNodes() << " factors " << endl; - cout << "compressed factor graph:" << endl; - cout << " " << fg_->nrVarNodes() << " variables " << endl; - cout << " " << fg_->nrFacNodes() << " factors " << endl; - } } @@ -123,15 +115,25 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids) void CbpSolver::createLinks (void) -{ +{ + if (Globals::verbosity > 0) { + cout << "original factor graph has " ; + cout << fg.nrVarNodes() << " variables and " ; + cout << fg.nrFacNodes() << " factors " << endl; + cout << "compressed factor graph has " ; + cout << fg_->nrVarNodes() << " variables and " ; + cout << fg_->nrFacNodes() << " factors " << endl; + cout << endl; + } const FacClusters& fcs = cfg_->facClusters(); for (unsigned i = 0; i < fcs.size(); i++) { const VarClusters& vcs = fcs[i]->varClusters(); for (unsigned j = 0; j < vcs.size(); j++) { unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j); - if (Constants::DEBUG >= 5) { - cout << "creating edge " ; - cout << fcs[i]->representative()->getLabel() << " -> " ; + if (Globals::verbosity > 1) { + cout << "creating link " ; + cout << fcs[i]->representative()->getLabel(); + cout << " -- " ; cout << vcs[j]->representative()->label(); cout << " idx=" << j << ", count=" << count << endl; } @@ -139,6 +141,9 @@ CbpSolver::createLinks (void) fcs[i]->representative(), vcs[j]->representative(), j, count)); } } + if (Globals::verbosity > 1) { + cout << endl; + } } @@ -151,7 +156,7 @@ CbpSolver::maxResidualSchedule (void) calculateMessage (links_[i]); SortedOrder::iterator it = sortedOrder_.insert (links_[i]); linkMap_.insert (make_pair (links_[i], it)); - if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) { + if (Globals::verbosity >= 1) { cout << "calculating " << links_[i]->toString() << endl; } } @@ -159,7 +164,7 @@ CbpSolver::maxResidualSchedule (void) } for (unsigned c = 0; c < links_.size(); c++) { - if (Constants::DEBUG >= 2) { + if (Globals::verbosity > 1) { cout << endl << "current residuals:" << endl; for (SortedOrder::iterator it = sortedOrder_.begin(); it != sortedOrder_.end(); it ++) { @@ -170,7 +175,7 @@ CbpSolver::maxResidualSchedule (void) SortedOrder::iterator it = sortedOrder_.begin(); SpLink* link = *it; - if (Constants::DEBUG >= 2) { + if (Globals::verbosity >= 1) { cout << "updating " << (*sortedOrder_.begin())->toString() << endl; } if (link->getResidual() < BpOptions::accuracy) { @@ -187,7 +192,7 @@ CbpSolver::maxResidualSchedule (void) const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); for (unsigned j = 0; j < links.size(); j++) { if (links[j]->getVariable() != link->getVariable()) { - if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) { + if (Globals::verbosity > 1) { cout << " calculating " << links[j]->toString() << endl; } calculateMessage (links[j]); @@ -202,7 +207,7 @@ CbpSolver::maxResidualSchedule (void) const SpLinkSet& links = ninf(link->getFactor())->getLinks(); for (unsigned i = 0; i < links.size(); i++) { if (links[i]->getVariable() != link->getVariable()) { - if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) { + if (Globals::verbosity > 1) { cout << " calculating " << links[i]->toString() << endl; } calculateMessage (links[i]); @@ -235,13 +240,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link) for (int i = links.size() - 1; i >= 0; i--) { const CbpSolverLink* cl = static_cast (links[i]); if ( ! (cl->getVariable() == dst && cl->index() == link->index())) { - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " message from " << links[i]->getVariable()->label(); cout << ": " ; } Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions); repetitions *= links[i]->getVariable()->range(); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << endl; } } else { @@ -254,13 +259,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link) for (int i = links.size() - 1; i >= 0; i--) { const CbpSolverLink* cl = static_cast (links[i]); if ( ! (cl->getVariable() == dst && cl->index() == link->index())) { - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " message from " << links[i]->getVariable()->label(); cout << ": " ; } Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions); repetitions *= links[i]->getVariable()->range(); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << endl; } } else { @@ -282,18 +287,18 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link) result[i] *= src->factor()[i]; } } - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " message product: " << msgProduct << endl; cout << " original factor: " << src->factor().params() << endl; cout << " factor product: " << result.params() << endl; } result.sumOutAllExceptIndex (link->index()); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " marginalized: " << result.params() << endl; } link->getNextMessage() = result.params(); LogAware::normalize (link->getNextMessage()); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " curr msg: " << link->getMessage() << endl; cout << " next msg: " << link->getNextMessage() << endl; } @@ -311,14 +316,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const if (src->hasEvidence()) { msg.resize (src->range(), LogAware::noEvidence()); double value = link->getMessage()[src->getEvidence()]; - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { msg[src->getEvidence()] = value; cout << msg << "^" << link->nrEdges() << "-1" ; } msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1); } else { msg = link->getMessage(); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << msg << "^" << link->nrEdges() << "-1" ; } LogAware::pow (msg, link->nrEdges() - 1); @@ -337,13 +342,13 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const CbpSolverLink* cl = static_cast (links[i]); if ( ! (cl->getFactor() == dst && cl->index() == link->index())) { Util::multiply (msg, cl->poweredMessage()); - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " x " << cl->getNextMessage() << "^" << link->nrEdges(); } } } } - if (Constants::DEBUG >= 5) { + if (Constants::SHOW_BP_CALCS) { cout << " = " << msg; } return msg; diff --git a/packages/CLPBN/clpbn/bp/Factor.cpp b/packages/CLPBN/clpbn/bp/Factor.cpp index 7f8b7e611..011303496 100644 --- a/packages/CLPBN/clpbn/bp/Factor.cpp +++ b/packages/CLPBN/clpbn/bp/Factor.cpp @@ -151,14 +151,12 @@ Factor::sumOutIndex (unsigned idx) void Factor::sumOutAllExceptIndex (unsigned idx) { - int i = (int)idx; - while (args_.size() > i + 1) { + while (args_.size() > idx + 1) { sumOutLastVariable(); } - while (i > 0) { + for (unsigned i = 0; i < idx; i++) { sumOutFirstVariable(); - i -- ; - } + } } diff --git a/packages/CLPBN/clpbn/bp/FoveSolver.cpp b/packages/CLPBN/clpbn/bp/FoveSolver.cpp index 4205e62c5..a82774217 100644 --- a/packages/CLPBN/clpbn/bp/FoveSolver.cpp +++ b/packages/CLPBN/clpbn/bp/FoveSolver.cpp @@ -542,6 +542,18 @@ FoveSolver::getJointDistributionOf (const Grounds& query) +void +FoveSolver::printSolverFlags (void) const +{ + stringstream ss; + ss << "fove [" ; + ss << "log_domain=" << Util::toString (Globals::logDomain); + ss << "]" ; + cout << ss.str() << endl; +} + + + void FoveSolver::absorveEvidence ( ParfactorList& pfList, @@ -568,7 +580,7 @@ FoveSolver::absorveEvidence ( } pfList.add (newPfs); } - if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) { + if (Globals::verbosity > 2 && obsFormulas.empty() == false) { Util::printAsteriskLine(); cout << "AFTER EVIDENCE ABSORVED" << endl; for (unsigned i = 0; i < obsFormulas.size(); i++) { @@ -603,20 +615,26 @@ FoveSolver::countNormalize ( void FoveSolver::runSolver (const Grounds& query) { + largestCost_ = std::log (0); shatterAgainstQuery (query); runWeakBayesBall (query); while (true) { - if (Constants::DEBUG >= 2) { + if (Globals::verbosity > 2) { Util::printDashedLine(); pfList_.print(); - LiftedOperator::printValidOps (pfList_, query); + if (Globals::verbosity > 3) { + LiftedOperator::printValidOps (pfList_, query); + } } LiftedOperator* op = getBestOperation (query); if (op == 0) { break; } - if (Constants::DEBUG >= 2) { - cout << "best operation: " << op->toString() << endl; + if (Globals::verbosity > 1) { + cout << "best operation: " << op->toString(); + if (Globals::verbosity > 2) { + cout << endl; + } } op->apply(); delete op; @@ -630,6 +648,10 @@ FoveSolver::runSolver (const Grounds& query) ++ pfIter; } } + if (Globals::verbosity > 0) { + cout << "largest cost = " << std::exp (largestCost_) << endl; + cout << endl; + } (*pfList_.begin())->reorderAccordingGrounds (query); } @@ -649,6 +671,9 @@ FoveSolver::getBestOperation (const Grounds& query) bestCost = cost; } } + if (bestCost > largestCost_) { + largestCost_ = bestCost; + } for (unsigned i = 0; i < validOps.size(); i++) { if (validOps[i] != bestOp) { delete validOps[i]; @@ -699,18 +724,21 @@ FoveSolver::runWeakBayesBall (const Grounds& query) } ParfactorList::iterator it = pfList_.begin(); + bool foundNotRequired = false; while (it != pfList_.end()) { if (Util::contains (requiredPfs, *it) == false) { + if (Globals::verbosity > 2) { + if (foundNotRequired == false) { + Util::printHeader ("PARFACTORS TO DISCARD"); + foundNotRequired = true; + } + (*it)->print(); + } it = pfList_.removeAndDelete (it); } else { ++ it; } } - - if (Constants::DEBUG >= 2) { - Util::printHeader ("REQUIRED PARFACTORS"); - pfList_.print(); - } } @@ -750,8 +778,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query) } pfList_.add (newPfs); } - if (Constants::DEBUG >= 2) { - cout << endl; + if (Globals::verbosity > 2) { Util::printAsteriskLine(); cout << "SHATTERED AGAINST THE QUERY" << endl; for (unsigned i = 0; i < query.size(); i++) { diff --git a/packages/CLPBN/clpbn/bp/FoveSolver.h b/packages/CLPBN/clpbn/bp/FoveSolver.h index 5b39aac8b..ef52fae7c 100644 --- a/packages/CLPBN/clpbn/bp/FoveSolver.h +++ b/packages/CLPBN/clpbn/bp/FoveSolver.h @@ -116,6 +116,8 @@ class FoveSolver Params getJointDistributionOf (const Grounds&); + void printSolverFlags (void) const; + static void absorveEvidence ( ParfactorList& pfList, ObservedFormulas& obsFormulas); @@ -133,6 +135,8 @@ class FoveSolver static Parfactors absorve (ObservedFormula&, Parfactor*); ParfactorList pfList_; + + double largestCost_; }; #endif // HORUS_FOVESOLVER_H diff --git a/packages/CLPBN/clpbn/bp/Horus.h b/packages/CLPBN/clpbn/bp/Horus.h index 302887899..b4da50b16 100644 --- a/packages/CLPBN/clpbn/bp/Horus.h +++ b/packages/CLPBN/clpbn/bp/Horus.h @@ -39,6 +39,8 @@ namespace Globals { extern bool logDomain; +extern unsigned verbosity; + extern InfAlgorithms infAlgorithm; }; @@ -47,12 +49,14 @@ extern InfAlgorithms infAlgorithm; namespace Constants { // level of debug information -const unsigned DEBUG = 0; +const unsigned DEBUG = 3; + +const bool SHOW_BP_CALCS = false; const int NO_EVIDENCE = -1; // number of digits to show when printing a parameter -const unsigned PRECISION = 5; +const unsigned PRECISION = 6; const bool COLLECT_STATS = false; diff --git a/packages/CLPBN/clpbn/bp/HorusCli.cpp b/packages/CLPBN/clpbn/bp/HorusCli.cpp index 282e69a2f..00bbeefbd 100644 --- a/packages/CLPBN/clpbn/bp/HorusCli.cpp +++ b/packages/CLPBN/clpbn/bp/HorusCli.cpp @@ -174,8 +174,10 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds) default: assert (false); } - solver->printSolverFlags(); - cout << endl; + if (Globals::verbosity > 0) { + solver->printSolverFlags(); + cout << endl; + } if (queryIds.size() == 0) { solver->printAllPosterioris(); } else { diff --git a/packages/CLPBN/clpbn/bp/HorusYap.cpp b/packages/CLPBN/clpbn/bp/HorusYap.cpp index 28bbe780d..763dfaab9 100644 --- a/packages/CLPBN/clpbn/bp/HorusYap.cpp +++ b/packages/CLPBN/clpbn/bp/HorusYap.cpp @@ -64,7 +64,7 @@ int createLiftedNetwork (void) } // LiftedUtils::printSymbolDictionary(); - if (Constants::DEBUG > 2) { + if (Globals::verbosity > 2) { Util::printHeader ("INITIAL PARFACTORS"); for (unsigned i = 0; i < parfactors.size(); i++) { parfactors[i]->print(); @@ -73,7 +73,7 @@ int createLiftedNetwork (void) ParfactorList* pfList = new ParfactorList (parfactors); - if (Constants::DEBUG >= 2) { + if (Globals::verbosity > 2) { Util::printHeader ("SHATTERED PARFACTORS"); pfList->print(); } @@ -300,6 +300,10 @@ runLiftedSolver (void) jointList = YAP_TailOfTerm (jointList); } FoveSolver solver (pfListCopy); + if (Globals::verbosity > 0 && taskList == YAP_ARG2) { + solver.printSolverFlags(); + cout << endl; + } if (queryVars.size() == 1) { results.push_back (solver.getPosterioriOf (queryVars[0])); } else { @@ -376,7 +380,9 @@ void runVeSolver ( } // VarElimSolver solver (*mfg); VarElimSolver solver (*fg); //FIXME - // solver.printSolverFlags(); + if (Globals::verbosity > 0 && i == 0) { + solver.printSolverFlags(); + } results.push_back (solver.solveQuery (tasks[i])); if (fg->isFromBayesNetwork()) { delete mfg; @@ -410,7 +416,9 @@ void runBpSolver ( cerr << "error: unknow solver" << endl; abort(); } - // solver->printSolverFlags(); + if (Globals::verbosity > 0) { + solver->printSolverFlags(); + } results.reserve (tasks.size()); for (unsigned i = 0; i < tasks.size(); i++) { results.push_back (solver->solveQuery (tasks[i])); @@ -508,7 +516,11 @@ setHorusFlag (void) { string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1))); string value; - if (key == "accuracy") { + if (key == "verbosity") { + stringstream ss; + ss << (int) YAP_IntOfTerm (YAP_ARG2); + ss >> value; + } else if (key == "accuracy") { stringstream ss; ss << (float) YAP_FloatOfTerm (YAP_ARG2); ss >> value; diff --git a/packages/CLPBN/clpbn/bp/Util.cpp b/packages/CLPBN/clpbn/bp/Util.cpp index 97e23de7f..f131512e2 100644 --- a/packages/CLPBN/clpbn/bp/Util.cpp +++ b/packages/CLPBN/clpbn/bp/Util.cpp @@ -11,10 +11,9 @@ namespace Globals { bool logDomain = false; -//InfAlgs infAlgorithm = InfAlgorithms::VE; -//InfAlgs infAlgorithm = InfAlgorithms::BN_BP; -//InfAlgs infAlgorithm = InfAlgorithms::FG_BP; -InfAlgorithms infAlgorithm = InfAlgorithms::CBP; +unsigned verbosity = 0; + +InfAlgorithms infAlgorithm = InfAlgorithms::VE; }; @@ -227,7 +226,11 @@ bool setHorusFlag (string key, string value) { bool returnVal = true; - if (key == "inf_alg") { + if (key == "verbosity") { + stringstream ss; + ss << value; + ss >> Globals::verbosity; + } else if (key == "inf_alg") { if ( value == "ve") { Globals::infAlgorithm = InfAlgorithms::VE; } else if (value == "bp") { diff --git a/packages/CLPBN/clpbn/bp/VarElimSolver.cpp b/packages/CLPBN/clpbn/bp/VarElimSolver.cpp index 55399609b..70634e91e 100644 --- a/packages/CLPBN/clpbn/bp/VarElimSolver.cpp +++ b/packages/CLPBN/clpbn/bp/VarElimSolver.cpp @@ -16,6 +16,14 @@ VarElimSolver::~VarElimSolver (void) Params VarElimSolver::solveQuery (VarIds queryVids) { + if (Globals::verbosity > 1) { + cout << "Solving query on " ; + for (unsigned i = 0; i < queryVids.size(); i++) { + if (i != 0) cout << ", " ; + cout << fg.getVarNode (queryVids[i])->label(); + } + cout << endl; + } factorList_.clear(); varFactors_.clear(); elimOrder_.clear(); @@ -77,9 +85,19 @@ VarElimSolver::createFactorList (void) void VarElimSolver::absorveEvidence (void) { + if (Globals::verbosity > 2) { + Util::printDashedLine(); + cout << "(initial factor list)" << endl; + printActiveFactors(); + } const VarNodes& varNodes = fg.varNodes(); for (unsigned i = 0; i < varNodes.size(); i++) { if (varNodes[i]->hasEvidence()) { + if (Globals::verbosity > 1) { + cout << "-> aborving evidence on "; + cout << varNodes[i]->label() << " = " ; + cout << varNodes[i]->getEvidence() << endl; + } const vector& idxs = varFactors_.find (varNodes[i]->varId())->second; for (unsigned j = 0; j < idxs.size(); j++) { @@ -108,12 +126,16 @@ VarElimSolver::findEliminationOrder (const VarIds& vids) void VarElimSolver::processFactorList (const VarIds& vids) { + totalFactorSize_ = 0; + largestFactorSize_ = 0; for (unsigned i = 0; i < elimOrder_.size(); i++) { - if (Constants::DEBUG >= 3) { - printActiveFactors(); + if (Globals::verbosity >= 2) { + if (Globals::verbosity >= 3) { + Util::printDashedLine(); + printActiveFactors(); + } cout << "-> summing out " ; - VarNode* vn = fg.getVarNode (elimOrder_[i]); - cout << vn->label() << endl; + cout << fg.getVarNode (elimOrder_[i])->label() << endl; } eliminate (elimOrder_[i]); } @@ -137,6 +159,11 @@ VarElimSolver::processFactorList (const VarIds& vids) finalFactor->reorderArguments (unobservedVids); finalFactor->normalize(); factorList_.push_back (finalFactor); + if (Globals::verbosity > 0) { + cout << "total factor size: " << totalFactorSize_ << endl; + cout << "largest factor size: " << largestFactorSize_ << endl; + cout << endl; + } } @@ -158,6 +185,10 @@ VarElimSolver::eliminate (VarId elimVar) factorList_[idx] = 0; } } + totalFactorSize_ += result->size(); + if (result->size() > largestFactorSize_) { + largestFactorSize_ = result->size(); + } if (result != 0 && result->nrArguments() != 1) { result->sumOut (elimVar); factorList_.push_back (result); @@ -175,14 +206,11 @@ VarElimSolver::eliminate (VarId elimVar) void VarElimSolver::printActiveFactors (void) { - cout << endl; - Util::printDashedLine(); for (unsigned i = 0; i < factorList_.size(); i++) { if (factorList_[i] != 0) { cout << factorList_[i]->getLabel() << " " ; cout << factorList_[i]->params() << endl; } } - cout << endl; } diff --git a/packages/CLPBN/clpbn/bp/VarElimSolver.h b/packages/CLPBN/clpbn/bp/VarElimSolver.h index ccdaefdb6..d5ed441a9 100644 --- a/packages/CLPBN/clpbn/bp/VarElimSolver.h +++ b/packages/CLPBN/clpbn/bp/VarElimSolver.h @@ -35,8 +35,10 @@ class VarElimSolver : public Solver void printActiveFactors (void); - vector factorList_; - VarIds elimOrder_; + Factors factorList_; + VarIds elimOrder_; + unsigned largestFactorSize_; + unsigned totalFactorSize_; unordered_map> varFactors_; };