revamp debugging plataform
This commit is contained in:
parent
d86e2c8386
commit
56475cacbc
@ -153,9 +153,8 @@ BpSolver::runSolver (void)
|
||||
nIters_ = 0;
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
nIters_ ++;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
if (Globals::verbosity > 1) {
|
||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||
// cout << endl;
|
||||
}
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
@ -178,12 +177,8 @@ BpSolver::runSolver (void)
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (Globals::verbosity > 0) {
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
@ -191,6 +186,7 @@ BpSolver::runSolver (void)
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
unsigned size = fg_->varNodes().size();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
@ -232,7 +228,7 @@ BpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
@ -266,7 +262,7 @@ BpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
if (Globals::verbosity > 1) {
|
||||
Util::printDashedLine();
|
||||
}
|
||||
}
|
||||
@ -291,13 +287,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||
if (Globals::logDomain) {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " ;
|
||||
}
|
||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
}
|
||||
} else {
|
||||
@ -309,13 +305,13 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||
} else {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " ;
|
||||
}
|
||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
}
|
||||
} else {
|
||||
@ -328,18 +324,18 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
result.multiply (src->factor());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " marginalized: " << result.params() << endl;
|
||||
}
|
||||
link->getNextMessage() = result.params();
|
||||
LogAware::normalize (link->getNextMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " curr msg: " << link->getMessage() << endl;
|
||||
cout << " next msg: " << link->getNextMessage() << endl;
|
||||
}
|
||||
@ -359,7 +355,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
} else {
|
||||
msg.resize (src->range(), LogAware::one());
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << msg;
|
||||
}
|
||||
const SpLinkSet& links = ninf (src)->getLinks();
|
||||
@ -367,7 +363,7 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::add (msg, links[i]->getMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << links[i]->getMessage();
|
||||
}
|
||||
}
|
||||
@ -376,13 +372,13 @@ BpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::multiply (msg, links[i]->getMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << links[i]->getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
@ -472,7 +468,16 @@ BpSolver::converged (void)
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ <= 1) {
|
||||
if (nIters_ == 0) {
|
||||
return false;
|
||||
}
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << endl;
|
||||
}
|
||||
if (nIters_ == 1) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "no residuals" << endl << endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
@ -486,7 +491,7 @@ BpSolver::converged (void)
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
@ -494,7 +499,7 @@ BpSolver::converged (void)
|
||||
if (Constants::DEBUG == 0) break;
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ class BpSolver : public Solver
|
||||
|
||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
@ -140,7 +140,7 @@ class BpSolver : public Solver
|
||||
|
||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
@ -152,7 +152,7 @@ class BpSolver : public Solver
|
||||
void updateMessage (SpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Constants::DEBUG >= 3) {
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
|
@ -24,14 +24,6 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||
Statistics::updateCompressingStatistics (nrGroundVars,
|
||||
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << "uncompressed factor graph:" << endl;
|
||||
cout << " " << fg.nrVarNodes() << " variables " << endl;
|
||||
cout << " " << fg.nrFacNodes() << " factors " << endl;
|
||||
cout << "compressed factor graph:" << endl;
|
||||
cout << " " << fg_->nrVarNodes() << " variables " << endl;
|
||||
cout << " " << fg_->nrFacNodes() << " factors " << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -123,15 +115,25 @@ CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||
|
||||
void
|
||||
CbpSolver::createLinks (void)
|
||||
{
|
||||
{
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "original factor graph has " ;
|
||||
cout << fg.nrVarNodes() << " variables and " ;
|
||||
cout << fg.nrFacNodes() << " factors " << endl;
|
||||
cout << "compressed factor graph has " ;
|
||||
cout << fg_->nrVarNodes() << " variables and " ;
|
||||
cout << fg_->nrFacNodes() << " factors " << endl;
|
||||
cout << endl;
|
||||
}
|
||||
const FacClusters& fcs = cfg_->facClusters();
|
||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||
const VarClusters& vcs = fcs[i]->varClusters();
|
||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||
unsigned count = cfg_->getEdgeCount (fcs[i], vcs[j], j);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << "creating edge " ;
|
||||
cout << fcs[i]->representative()->getLabel() << " -> " ;
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "creating link " ;
|
||||
cout << fcs[i]->representative()->getLabel();
|
||||
cout << " -- " ;
|
||||
cout << vcs[j]->representative()->label();
|
||||
cout << " idx=" << j << ", count=" << count << endl;
|
||||
}
|
||||
@ -139,6 +141,9 @@ CbpSolver::createLinks (void)
|
||||
fcs[i]->representative(), vcs[j]->representative(), j, count));
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -151,7 +156,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
if (Globals::verbosity >= 1) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
@ -159,7 +164,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
@ -170,7 +175,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
if (Globals::verbosity >= 1) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
@ -187,7 +192,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[j]);
|
||||
@ -202,7 +207,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getVariable() != link->getVariable()) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << " calculating " << links[i]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[i]);
|
||||
@ -235,13 +240,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " ;
|
||||
}
|
||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
}
|
||||
} else {
|
||||
@ -254,13 +259,13 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
const CbpSolverLink* cl = static_cast<const CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->getVariable() == dst && cl->index() == link->index())) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " ;
|
||||
}
|
||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
}
|
||||
} else {
|
||||
@ -282,18 +287,18 @@ CbpSolver::calculateFactor2VariableMsg (SpLink* _link)
|
||||
result[i] *= src->factor()[i];
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExceptIndex (link->index());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " marginalized: " << result.params() << endl;
|
||||
}
|
||||
link->getNextMessage() = result.params();
|
||||
LogAware::normalize (link->getNextMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " curr msg: " << link->getMessage() << endl;
|
||||
cout << " next msg: " << link->getNextMessage() << endl;
|
||||
}
|
||||
@ -311,14 +316,14 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
double value = link->getMessage()[src->getEvidence()];
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
msg[src->getEvidence()] = value;
|
||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||
}
|
||||
msg[src->getEvidence()] = LogAware::pow (value, link->nrEdges() - 1);
|
||||
} else {
|
||||
msg = link->getMessage();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << msg << "^" << link->nrEdges() << "-1" ;
|
||||
}
|
||||
LogAware::pow (msg, link->nrEdges() - 1);
|
||||
@ -337,13 +342,13 @@ CbpSolver::getVar2FactorMsg (const SpLink* _link) const
|
||||
CbpSolverLink* cl = static_cast<CbpSolverLink*> (links[i]);
|
||||
if ( ! (cl->getFactor() == dst && cl->index() == link->index())) {
|
||||
Util::multiply (msg, cl->poweredMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << cl->getNextMessage() << "^" << link->nrEdges();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
|
@ -151,14 +151,12 @@ Factor::sumOutIndex (unsigned idx)
|
||||
void
|
||||
Factor::sumOutAllExceptIndex (unsigned idx)
|
||||
{
|
||||
int i = (int)idx;
|
||||
while (args_.size() > i + 1) {
|
||||
while (args_.size() > idx + 1) {
|
||||
sumOutLastVariable();
|
||||
}
|
||||
while (i > 0) {
|
||||
for (unsigned i = 0; i < idx; i++) {
|
||||
sumOutFirstVariable();
|
||||
i -- ;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -542,6 +542,18 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::printSolverFlags (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "fove [" ;
|
||||
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
@ -568,7 +580,7 @@ FoveSolver::absorveEvidence (
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
|
||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
@ -603,20 +615,26 @@ FoveSolver::countNormalize (
|
||||
void
|
||||
FoveSolver::runSolver (const Grounds& query)
|
||||
{
|
||||
largestCost_ = std::log (0);
|
||||
shatterAgainstQuery (query);
|
||||
runWeakBayesBall (query);
|
||||
while (true) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printDashedLine();
|
||||
pfList_.print();
|
||||
LiftedOperator::printValidOps (pfList_, query);
|
||||
if (Globals::verbosity > 3) {
|
||||
LiftedOperator::printValidOps (pfList_, query);
|
||||
}
|
||||
}
|
||||
LiftedOperator* op = getBestOperation (query);
|
||||
if (op == 0) {
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "best operation: " << op->toString() << endl;
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "best operation: " << op->toString();
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
op->apply();
|
||||
delete op;
|
||||
@ -630,6 +648,10 @@ FoveSolver::runSolver (const Grounds& query)
|
||||
++ pfIter;
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "largest cost = " << std::exp (largestCost_) << endl;
|
||||
cout << endl;
|
||||
}
|
||||
(*pfList_.begin())->reorderAccordingGrounds (query);
|
||||
}
|
||||
|
||||
@ -649,6 +671,9 @@ FoveSolver::getBestOperation (const Grounds& query)
|
||||
bestCost = cost;
|
||||
}
|
||||
}
|
||||
if (bestCost > largestCost_) {
|
||||
largestCost_ = bestCost;
|
||||
}
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
if (validOps[i] != bestOp) {
|
||||
delete validOps[i];
|
||||
@ -699,18 +724,21 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
}
|
||||
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
bool foundNotRequired = false;
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false) {
|
||||
if (Globals::verbosity > 2) {
|
||||
if (foundNotRequired == false) {
|
||||
Util::printHeader ("PARFACTORS TO DISCARD");
|
||||
foundNotRequired = true;
|
||||
}
|
||||
(*it)->print();
|
||||
}
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader ("REQUIRED PARFACTORS");
|
||||
pfList_.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -750,8 +778,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||
}
|
||||
pfList_.add (newPfs);
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
|
@ -116,6 +116,8 @@ class FoveSolver
|
||||
|
||||
Params getJointDistributionOf (const Grounds&);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
|
||||
@ -133,6 +135,8 @@ class FoveSolver
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
|
||||
ParfactorList pfList_;
|
||||
|
||||
double largestCost_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FOVESOLVER_H
|
||||
|
@ -39,6 +39,8 @@ namespace Globals {
|
||||
|
||||
extern bool logDomain;
|
||||
|
||||
extern unsigned verbosity;
|
||||
|
||||
extern InfAlgorithms infAlgorithm;
|
||||
|
||||
};
|
||||
@ -47,12 +49,14 @@ extern InfAlgorithms infAlgorithm;
|
||||
namespace Constants {
|
||||
|
||||
// level of debug information
|
||||
const unsigned DEBUG = 0;
|
||||
const unsigned DEBUG = 3;
|
||||
|
||||
const bool SHOW_BP_CALCS = false;
|
||||
|
||||
const int NO_EVIDENCE = -1;
|
||||
|
||||
// number of digits to show when printing a parameter
|
||||
const unsigned PRECISION = 5;
|
||||
const unsigned PRECISION = 6;
|
||||
|
||||
const bool COLLECT_STATS = false;
|
||||
|
||||
|
@ -174,8 +174,10 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
if (queryIds.size() == 0) {
|
||||
solver->printAllPosterioris();
|
||||
} else {
|
||||
|
@ -64,7 +64,7 @@ int createLiftedNetwork (void)
|
||||
}
|
||||
|
||||
// LiftedUtils::printSymbolDictionary();
|
||||
if (Constants::DEBUG > 2) {
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printHeader ("INITIAL PARFACTORS");
|
||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||
parfactors[i]->print();
|
||||
@ -73,7 +73,7 @@ int createLiftedNetwork (void)
|
||||
|
||||
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||
|
||||
if (Constants::DEBUG >= 2) {
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printHeader ("SHATTERED PARFACTORS");
|
||||
pfList->print();
|
||||
}
|
||||
@ -300,6 +300,10 @@ runLiftedSolver (void)
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
FoveSolver solver (pfListCopy);
|
||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||
solver.printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
if (queryVars.size() == 1) {
|
||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
||||
} else {
|
||||
@ -376,7 +380,9 @@ void runVeSolver (
|
||||
}
|
||||
// VarElimSolver solver (*mfg);
|
||||
VarElimSolver solver (*fg); //FIXME
|
||||
// solver.printSolverFlags();
|
||||
if (Globals::verbosity > 0 && i == 0) {
|
||||
solver.printSolverFlags();
|
||||
}
|
||||
results.push_back (solver.solveQuery (tasks[i]));
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
delete mfg;
|
||||
@ -410,7 +416,9 @@ void runBpSolver (
|
||||
cerr << "error: unknow solver" << endl;
|
||||
abort();
|
||||
}
|
||||
// solver->printSolverFlags();
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
}
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
@ -508,7 +516,11 @@ setHorusFlag (void)
|
||||
{
|
||||
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
string value;
|
||||
if (key == "accuracy") {
|
||||
if (key == "verbosity") {
|
||||
stringstream ss;
|
||||
ss << (int) YAP_IntOfTerm (YAP_ARG2);
|
||||
ss >> value;
|
||||
} else if (key == "accuracy") {
|
||||
stringstream ss;
|
||||
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
|
||||
ss >> value;
|
||||
|
@ -11,10 +11,9 @@
|
||||
namespace Globals {
|
||||
bool logDomain = false;
|
||||
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
||||
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
|
||||
unsigned verbosity = 0;
|
||||
|
||||
InfAlgorithms infAlgorithm = InfAlgorithms::VE;
|
||||
};
|
||||
|
||||
|
||||
@ -227,7 +226,11 @@ bool
|
||||
setHorusFlag (string key, string value)
|
||||
{
|
||||
bool returnVal = true;
|
||||
if (key == "inf_alg") {
|
||||
if (key == "verbosity") {
|
||||
stringstream ss;
|
||||
ss << value;
|
||||
ss >> Globals::verbosity;
|
||||
} else if (key == "inf_alg") {
|
||||
if ( value == "ve") {
|
||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (value == "bp") {
|
||||
|
@ -16,6 +16,14 @@ VarElimSolver::~VarElimSolver (void)
|
||||
Params
|
||||
VarElimSolver::solveQuery (VarIds queryVids)
|
||||
{
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "Solving query on " ;
|
||||
for (unsigned i = 0; i < queryVids.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
cout << fg.getVarNode (queryVids[i])->label();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
factorList_.clear();
|
||||
varFactors_.clear();
|
||||
elimOrder_.clear();
|
||||
@ -77,9 +85,19 @@ VarElimSolver::createFactorList (void)
|
||||
void
|
||||
VarElimSolver::absorveEvidence (void)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printDashedLine();
|
||||
cout << "(initial factor list)" << endl;
|
||||
printActiveFactors();
|
||||
}
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "-> aborving evidence on ";
|
||||
cout << varNodes[i]->label() << " = " ;
|
||||
cout << varNodes[i]->getEvidence() << endl;
|
||||
}
|
||||
const vector<unsigned>& idxs =
|
||||
varFactors_.find (varNodes[i]->varId())->second;
|
||||
for (unsigned j = 0; j < idxs.size(); j++) {
|
||||
@ -108,12 +126,16 @@ VarElimSolver::findEliminationOrder (const VarIds& vids)
|
||||
void
|
||||
VarElimSolver::processFactorList (const VarIds& vids)
|
||||
{
|
||||
totalFactorSize_ = 0;
|
||||
largestFactorSize_ = 0;
|
||||
for (unsigned i = 0; i < elimOrder_.size(); i++) {
|
||||
if (Constants::DEBUG >= 3) {
|
||||
printActiveFactors();
|
||||
if (Globals::verbosity >= 2) {
|
||||
if (Globals::verbosity >= 3) {
|
||||
Util::printDashedLine();
|
||||
printActiveFactors();
|
||||
}
|
||||
cout << "-> summing out " ;
|
||||
VarNode* vn = fg.getVarNode (elimOrder_[i]);
|
||||
cout << vn->label() << endl;
|
||||
cout << fg.getVarNode (elimOrder_[i])->label() << endl;
|
||||
}
|
||||
eliminate (elimOrder_[i]);
|
||||
}
|
||||
@ -137,6 +159,11 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
||||
finalFactor->reorderArguments (unobservedVids);
|
||||
finalFactor->normalize();
|
||||
factorList_.push_back (finalFactor);
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "total factor size: " << totalFactorSize_ << endl;
|
||||
cout << "largest factor size: " << largestFactorSize_ << endl;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -158,6 +185,10 @@ VarElimSolver::eliminate (VarId elimVar)
|
||||
factorList_[idx] = 0;
|
||||
}
|
||||
}
|
||||
totalFactorSize_ += result->size();
|
||||
if (result->size() > largestFactorSize_) {
|
||||
largestFactorSize_ = result->size();
|
||||
}
|
||||
if (result != 0 && result->nrArguments() != 1) {
|
||||
result->sumOut (elimVar);
|
||||
factorList_.push_back (result);
|
||||
@ -175,14 +206,11 @@ VarElimSolver::eliminate (VarId elimVar)
|
||||
void
|
||||
VarElimSolver::printActiveFactors (void)
|
||||
{
|
||||
cout << endl;
|
||||
Util::printDashedLine();
|
||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i] != 0) {
|
||||
cout << factorList_[i]->getLabel() << " " ;
|
||||
cout << factorList_[i]->params() << endl;
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
@ -35,8 +35,10 @@ class VarElimSolver : public Solver
|
||||
|
||||
void printActiveFactors (void);
|
||||
|
||||
vector<Factor*> factorList_;
|
||||
VarIds elimOrder_;
|
||||
Factors factorList_;
|
||||
VarIds elimOrder_;
|
||||
unsigned largestFactorSize_;
|
||||
unsigned totalFactorSize_;
|
||||
unordered_map<VarId, vector<unsigned>> varFactors_;
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user