Use camel case for constants and enumerators.

All capitals case should be reserved for macros and besides there is no big need to emphasize constness in general.
This commit is contained in:
Tiago Gomes 2013-02-13 18:54:15 +00:00
parent afd26ed9b4
commit ef4ebb4d7f
21 changed files with 172 additions and 172 deletions

View File

@ -62,7 +62,7 @@ BpLink::toString (void) const
double BeliefProp::accuracy_ = 0.0001; double BeliefProp::accuracy_ = 0.0001;
unsigned BeliefProp::maxIter_ = 1000; unsigned BeliefProp::maxIter_ = 1000;
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED; MsgSchedule BeliefProp::schedule_ = MsgSchedule::seqFixedSch;
@ -106,10 +106,10 @@ BeliefProp::printSolverFlags (void) const
ss << "belief propagation [" ; ss << "belief propagation [" ;
ss << "bp_msg_schedule=" ; ss << "bp_msg_schedule=" ;
switch (schedule_) { switch (schedule_) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
} }
ss << ",bp_max_iter=" << Util::toString (maxIter_); ss << ",bp_max_iter=" << Util::toString (maxIter_);
ss << ",bp_accuracy=" << Util::toString (accuracy_); ss << ",bp_accuracy=" << Util::toString (accuracy_);
@ -259,15 +259,15 @@ BeliefProp::runSolver (void)
+ Util::toString (nIters_)); + Util::toString (nIters_));
} }
switch (schedule_) { switch (schedule_) {
case MsgSchedule::SEQ_RANDOM: case MsgSchedule::seqRandomSch:
std::random_shuffle (links_.begin(), links_.end()); std::random_shuffle (links_.begin(), links_.end());
// no break // no break
case MsgSchedule::SEQ_FIXED: case MsgSchedule::seqFixedSch:
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]); calculateAndUpdateMessage (links_[i]);
} }
break; break;
case MsgSchedule::PARALLEL: case MsgSchedule::parallelSch:
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]); calculateMessage (links_[i]);
} }
@ -275,7 +275,7 @@ BeliefProp::runSolver (void)
updateMessage(links_[i]); updateMessage(links_[i]);
} }
break; break;
case MsgSchedule::MAX_RESIDUAL: case MsgSchedule::maxResidualSch:
maxResidualSchedule(); maxResidualSchedule();
break; break;
} }
@ -380,13 +380,13 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) { if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ; std::cout << ": " ;
} }
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>()); reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
@ -395,13 +395,13 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
} else { } else {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) { if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ; std::cout << ": " ;
} }
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>()); reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
@ -411,19 +411,19 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
Factor result (src->factor().arguments(), Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct); src->factor().ranges(), msgProduct);
result.multiply (src->factor()); result.multiply (src->factor());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " message product: " << msgProduct << std::endl; std::cout << " message product: " << msgProduct << std::endl;
std::cout << " original factor: " << src->factor().params(); std::cout << " original factor: " << src->factor().params();
std::cout << std::endl; std::cout << std::endl;
std::cout << " factor product: " << result.params() << std::endl; std::cout << " factor product: " << result.params() << std::endl;
} }
result.sumOutAllExcept (dst->varId()); result.sumOutAllExcept (dst->varId());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " marginalized: " << result.params() << std::endl; std::cout << " marginalized: " << result.params() << std::endl;
} }
link->nextMessage() = result.params(); link->nextMessage() = result.params();
LogAware::normalize (link->nextMessage()); LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " curr msg: " << link->message() << std::endl; std::cout << " curr msg: " << link->message() << std::endl;
std::cout << " next msg: " << link->nextMessage() << std::endl; std::cout << " next msg: " << link->nextMessage() << std::endl;
} }
@ -442,7 +442,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
} else { } else {
msg.resize (src->range(), LogAware::one()); msg.resize (src->range(), LogAware::one());
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << msg; std::cout << msg;
} }
BpLinks::const_iterator it; BpLinks::const_iterator it;
@ -452,7 +452,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
if (*it != link) { if (*it != link) {
msg += (*it)->message(); msg += (*it)->message();
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " x " << (*it)->message(); std::cout << " x " << (*it)->message();
} }
} }
@ -461,12 +461,12 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
if (*it != link) { if (*it != link) {
msg *= (*it)->message(); msg *= (*it)->message();
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " x " << (*it)->message(); std::cout << " x " << (*it)->message();
} }
} }
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " = " << msg; std::cout << " = " << msg;
} }
return msg; return msg;
@ -478,7 +478,7 @@ Params
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{ {
return GroundSolver::getJointByConditioning ( return GroundSolver::getJointByConditioning (
GroundSolverType::BP, fg, jointVarIds); GroundSolverType::bpSolver, fg, jointVarIds);
} }
@ -526,7 +526,7 @@ BeliefProp::converged (void)
return false; return false;
} }
bool converged = true; bool converged = true;
if (schedule_ == MsgSchedule::MAX_RESIDUAL) { if (schedule_ == MsgSchedule::maxResidualSch) {
double maxResidual = (*(sortedOrder_.begin()))->residual(); double maxResidual = (*(sortedOrder_.begin()))->residual();
if (maxResidual > accuracy_) { if (maxResidual > accuracy_) {
converged = false; converged = false;

View File

@ -12,10 +12,10 @@
namespace Horus { namespace Horus {
enum MsgSchedule { enum MsgSchedule {
SEQ_FIXED, seqFixedSch,
SEQ_RANDOM, seqRandomSch,
PARALLEL, parallelSch,
MAX_RESIDUAL maxResidualSch
}; };

View File

@ -45,10 +45,10 @@ CountingBp::printSolverFlags (void) const
ss << "counting bp [" ; ss << "counting bp [" ;
ss << "bp_msg_schedule=" ; ss << "bp_msg_schedule=" ;
switch (WeightedBp::msgSchedule()) { switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
} }
ss << ",bp_max_iter=" << WeightedBp::maxIterations(); ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy(); ss << ",bp_accuracy=" << WeightedBp::accuracy();
@ -79,7 +79,7 @@ CountingBp::solveQuery (VarIds queryVids)
} }
if (idx == facNodes.size()) { if (idx == facNodes.size()) {
res = GroundSolver::getJointByConditioning ( res = GroundSolver::getJointByConditioning (
GroundSolverType::CBP, fg, queryVids); GroundSolverType::CbpSolver, fg, queryVids);
} else { } else {
VarIds reprArgs; VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {

View File

@ -6,7 +6,7 @@
namespace Horus { namespace Horus {
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS; ElimHeuristic ElimGraph::elimHeuristic_ = minNeighborsEh;
ElimGraph::ElimGraph (const std::vector<Factor*>& factors) ElimGraph::ElimGraph (const std::vector<Factor*>& factors)
@ -137,7 +137,7 @@ ElimGraph::getEliminationOrder (
const Factors& factors, const Factors& factors,
VarIds excludedVids) VarIds excludedVids)
{ {
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) { if (elimHeuristic_ == ElimHeuristic::sequentialEh) {
VarIds allVids; VarIds allVids;
Factors::const_iterator first = factors.begin(); Factors::const_iterator first = factors.begin();
Factors::const_iterator end = factors.end(); Factors::const_iterator end = factors.end();
@ -181,7 +181,7 @@ ElimGraph::getLowestCostNode (void) const
unsigned minCost = Util::maxUnsigned(); unsigned minCost = Util::maxUnsigned();
EGNeighs::const_iterator it; EGNeighs::const_iterator it;
switch (elimHeuristic_) { switch (elimHeuristic_) {
case MIN_NEIGHBORS: { case minNeighborsEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getNeighborsCost (*it); unsigned cost = getNeighborsCost (*it);
if (cost < minCost) { if (cost < minCost) {
@ -190,7 +190,7 @@ ElimGraph::getLowestCostNode (void) const
} }
}} }}
break; break;
case MIN_WEIGHT: { case minWeightEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getWeightCost (*it); unsigned cost = getWeightCost (*it);
if (cost < minCost) { if (cost < minCost) {
@ -199,7 +199,7 @@ ElimGraph::getLowestCostNode (void) const
} }
}} }}
break; break;
case MIN_FILL: { case minFillEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getFillCost (*it); unsigned cost = getFillCost (*it);
if (cost < minCost) { if (cost < minCost) {
@ -208,7 +208,7 @@ ElimGraph::getLowestCostNode (void) const
} }
}} }}
break; break;
case WEIGHTED_MIN_FILL: { case weightedMinFillEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) { for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getWeightedFillCost (*it); unsigned cost = getWeightedFillCost (*it);
if (cost < minCost) { if (cost < minCost) {

View File

@ -19,11 +19,11 @@ typedef TinySet<EgNode*> EGNeighs;
enum ElimHeuristic { enum ElimHeuristic {
SEQUENTIAL, sequentialEh,
MIN_NEIGHBORS, minNeighborsEh,
MIN_WEIGHT, minWeightEh,
MIN_FILL, minFillEh,
WEIGHTED_MIN_FILL weightedMinFillEh
}; };

View File

@ -19,7 +19,7 @@ class FacNode;
class VarNode : public Var { class VarNode : public Var {
public: public:
VarNode (VarId varId, unsigned nrStates, VarNode (VarId varId, unsigned nrStates,
int evidence = Constants::NO_EVIDENCE) int evidence = Constants::unobserved)
: Var (varId, nrStates, evidence) { } : Var (varId, nrStates, evidence) { }
VarNode (const Var* v) : Var (v) { } VarNode (const Var* v) : Var (v) { }

View File

@ -31,7 +31,7 @@ GroundSolver::printAnswer (const VarIds& vids)
Util::getStateLines (unobservedVars); Util::getStateLines (unobservedVars);
for (size_t i = 0; i < res.size(); i++) { for (size_t i = 0; i < res.size(); i++) {
std::cout << "P(" << stateLines[i] << ") = " ; std::cout << "P(" << stateLines[i] << ") = " ;
std::cout << std::setprecision (Constants::PRECISION) << res[i]; std::cout << std::setprecision (Constants::precision) << res[i];
std::cout << std::endl; std::cout << std::endl;
} }
std::cout << std::endl; std::cout << std::endl;
@ -66,9 +66,9 @@ GroundSolver::getJointByConditioning (
GroundSolver* solver = 0; GroundSolver* solver = 0;
switch (solverType) { switch (solverType) {
case GroundSolverType::BP: solver = new BeliefProp (fg); break; case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
case GroundSolverType::CBP: solver = new CountingBp (fg); break; case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
case GroundSolverType::VE: solver = new VarElim (fg); break; case GroundSolverType::veSolver: solver = new VarElim (fg); break;
} }
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]}); Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()}; VarIds observedVids = {jointVars[0]->varId()};
@ -89,9 +89,9 @@ GroundSolver::getJointByConditioning (
} }
delete solver; delete solver;
switch (solverType) { switch (solverType) {
case GroundSolverType::BP: solver = new BeliefProp (fg); break; case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
case GroundSolverType::CBP: solver = new CountingBp (fg); break; case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
case GroundSolverType::VE: solver = new VarElim (fg); break; case GroundSolverType::veSolver: solver = new VarElim (fg); break;
} }
Params beliefs = solver->solveQuery ({jointVarIds[i]}); Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) { for (size_t k = 0; k < beliefs.size(); k++) {

View File

@ -35,16 +35,16 @@ typedef unsigned long long ullong;
enum LiftedSolverType { enum LiftedSolverType {
LVE, // generalized counting first-order variable elimination (GC-FOVE) lveSolver, // generalized counting first-order variable elimination (GC-FOveSolver)
LBP, // lifted first-order belief propagation lbpSolver, // lifted first-order belief propagation
LKC // lifted first-order knowledge compilation lkcSolver // lifted first-order knowledge compilation
}; };
enum GroundSolverType { enum GroundSolverType {
VE, // variable elimination veSolver, // variable elimination
BP, // belief propagation bpSolver, // belief propagation
CBP // counting belief propagation CbpSolver // counting belief propagation
}; };
@ -64,12 +64,12 @@ extern GroundSolverType groundSolver;
namespace Constants { namespace Constants {
// show message calculation for belief propagation // show message calculation for belief propagation
const bool SHOW_BP_CALCS = false; const bool showBpCalcs = false;
const int NO_EVIDENCE = -1; const int unobserved = -1;
// number of digits to show when printing a parameter // number of digits to show when printing a parameter
const unsigned PRECISION = 6; const unsigned precision = 6;
} }

View File

@ -18,7 +18,7 @@ Horus::VarIds readQueryAndEvidence (
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&); void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
const std::string USAGE = "usage: ./hcli [solver=hve|bp|cbp] \ const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ; [<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
@ -27,7 +27,7 @@ main (int argc, const char* argv[])
{ {
if (argc <= 1) { if (argc <= 1) {
std::cerr << "Error: no probabilistic graphical model was given." ; std::cerr << "Error: no probabilistic graphical model was given." ;
std::cerr << std::endl << USAGE << std::endl; std::cerr << std::endl << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
int idx = readHorusFlags (argc, argv); int idx = readHorusFlags (argc, argv);
@ -72,12 +72,12 @@ readHorusFlags (int argc, const char* argv[])
std::string rightArg = arg.substr (pos + 1); std::string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) { if (leftArg.empty()) {
std::cerr << "Error: missing left argument." << std::endl; std::cerr << "Error: missing left argument." << std::endl;
std::cerr << USAGE << std::endl; std::cerr << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
if (rightArg.empty()) { if (rightArg.empty()) {
std::cerr << "Error: missing right argument." << std::endl; std::cerr << "Error: missing right argument." << std::endl;
std::cerr << USAGE << std::endl; std::cerr << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
Horus::Util::setHorusFlag (leftArg, rightArg); Horus::Util::setHorusFlag (leftArg, rightArg);
@ -136,7 +136,7 @@ readQueryAndEvidence (
std::string rightArg = arg.substr (pos + 1); std::string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) { if (leftArg.empty()) {
std::cerr << "Error: missing left argument." << std::endl; std::cerr << "Error: missing left argument." << std::endl;
std::cerr << USAGE << std::endl; std::cerr << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
if (Horus::Util::isInteger (leftArg) == false) { if (Horus::Util::isInteger (leftArg) == false) {
@ -153,7 +153,7 @@ readQueryAndEvidence (
} }
if (rightArg.empty()) { if (rightArg.empty()) {
std::cerr << "Error: missing right argument." << std::endl; std::cerr << "Error: missing right argument." << std::endl;
std::cerr << USAGE << std::endl; std::cerr << usage << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
if (Horus::Util::isInteger (rightArg) == false) { if (Horus::Util::isInteger (rightArg) == false) {
@ -183,13 +183,13 @@ runSolver (
{ {
Horus::GroundSolver* solver = 0; Horus::GroundSolver* solver = 0;
switch (Horus::Globals::groundSolver) { switch (Horus::Globals::groundSolver) {
case Horus::GroundSolverType::VE: case Horus::GroundSolverType::veSolver:
solver = new Horus::VarElim (fg); solver = new Horus::VarElim (fg);
break; break;
case Horus::GroundSolverType::BP: case Horus::GroundSolverType::bpSolver:
solver = new Horus::BeliefProp (fg); solver = new Horus::BeliefProp (fg);
break; break;
case Horus::GroundSolverType::CBP: case Horus::GroundSolverType::CbpSolver:
solver = new Horus::CountingBp (fg); solver = new Horus::CountingBp (fg);
break; break;
default: default:

View File

@ -140,9 +140,9 @@ runLiftedSolver (void)
LiftedSolver* solver = 0; LiftedSolver* solver = 0;
switch (Globals::liftedSolver) { switch (Globals::liftedSolver) {
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break; case LiftedSolverType::lveSolver: solver = new LiftedVe (pfListCopy); break;
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break; case LiftedSolverType::lbpSolver: solver = new LiftedBp (pfListCopy); break;
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break; case LiftedSolverType::lkcSolver: solver = new LiftedKc (pfListCopy); break;
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
@ -214,9 +214,9 @@ runGroundSolver (void)
GroundSolver* solver = 0; GroundSolver* solver = 0;
CountingBp::setFindIdenticalFactorsFlag (false); CountingBp::setFindIdenticalFactorsFlag (false);
switch (Globals::groundSolver) { switch (Globals::groundSolver) {
case GroundSolverType::VE: solver = new VarElim (*mfg); break; case GroundSolverType::veSolver: solver = new VarElim (*mfg); break;
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break; case GroundSolverType::bpSolver: solver = new BeliefProp (*mfg); break;
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break; case GroundSolverType::CbpSolver: solver = new CountingBp (*mfg); break;
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {

View File

@ -70,10 +70,10 @@ LiftedBp::printSolverFlags (void) const
ss << "lifted bp [" ; ss << "lifted bp [" ;
ss << "bp_msg_schedule=" ; ss << "bp_msg_schedule=" ;
switch (WeightedBp::msgSchedule()) { switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break; case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break; case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
} }
ss << ",bp_max_iter=" << WeightedBp::maxIterations(); ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy(); ss << ",bp_accuracy=" << WeightedBp::accuracy();

View File

@ -815,7 +815,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
switch (getCircuitNodeType (node)) { switch (getCircuitNodeType (node)) {
case CircuitNodeType::OR_NODE: { case CircuitNodeType::orCnt: {
OrNode* casted = dynamic_cast<OrNode*>(node); OrNode* casted = dynamic_cast<OrNode*>(node);
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
@ -828,7 +828,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break; break;
} }
case CircuitNodeType::AND_NODE: { case CircuitNodeType::andCnt: {
AndNode* casted = dynamic_cast<AndNode*>(node); AndNode* casted = dynamic_cast<AndNode*>(node);
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch()); LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch()); LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
@ -837,7 +837,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break; break;
} }
case CircuitNodeType::SET_OR_NODE: { case CircuitNodeType::setOrCnt: {
SetOrNode* casted = dynamic_cast<SetOrNode*>(node); SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
propagLits = smoothCircuit (*casted->follow()); propagLits = smoothCircuit (*casted->follow());
TinySet<std::pair<LiteralId,unsigned>> litSet; TinySet<std::pair<LiteralId,unsigned>> litSet;
@ -875,13 +875,13 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break; break;
} }
case CircuitNodeType::SET_AND_NODE: { case CircuitNodeType::setAndCnt: {
SetAndNode* casted = dynamic_cast<SetAndNode*>(node); SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
propagLits = smoothCircuit (*casted->follow()); propagLits = smoothCircuit (*casted->follow());
break; break;
} }
case CircuitNodeType::INC_EXC_NODE: { case CircuitNodeType::incExcCnt: {
IncExcNode* casted = dynamic_cast<IncExcNode*>(node); IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch()); LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch()); LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
@ -894,7 +894,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break; break;
} }
case CircuitNodeType::LEAF_NODE: { case CircuitNodeType::leafCnt: {
LeafNode* casted = dynamic_cast<LeafNode*>(node); LeafNode* casted = dynamic_cast<LeafNode*>(node);
propagLits.insert (LitLvTypes ( propagLits.insert (LitLvTypes (
casted->clause()->literals()[0].lid(), casted->clause()->literals()[0].lid(),
@ -933,9 +933,9 @@ LiftedCircuit::createSmoothNode (
Clause* c = lwcnf_->createClause (lid); Clause* c = lwcnf_->createClause (lid);
for (size_t j = 0; j < types.size(); j++) { for (size_t j = 0; j < types.size(); j++) {
LogVar X = c->literals().front().logVars()[j]; LogVar X = c->literals().front().logVars()[j];
if (types[j] == LogVarType::POS_LV) { if (types[j] == LogVarType::posLvt) {
c->addPosCountedLogVar (X); c->addPosCountedLogVar (X);
} else if (types[j] == LogVarType::NEG_LV) { } else if (types[j] == LogVarType::negLvt) {
c->addNegCountedLogVar (X); c->addNegCountedLogVar (X);
} }
} }
@ -960,8 +960,8 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
if (nrLogVars == 0) { if (nrLogVars == 0) {
// do nothing // do nothing
} else if (nrLogVars == 1) { } else if (nrLogVars == 1) {
res.push_back ({ LogVarType::POS_LV }); res.push_back ({ LogVarType::posLvt });
res.push_back ({ LogVarType::NEG_LV }); res.push_back ({ LogVarType::negLvt });
} else { } else {
Ranges ranges (nrLogVars, 2); Ranges ranges (nrLogVars, 2);
Indexer indexer (ranges); Indexer indexer (ranges);
@ -969,9 +969,9 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
LogVarTypes types; LogVarTypes types;
for (size_t i = 0; i < nrLogVars; i++) { for (size_t i = 0; i < nrLogVars; i++) {
if (indexer[i] == 0) { if (indexer[i] == 0) {
types.push_back (LogVarType::POS_LV); types.push_back (LogVarType::posLvt);
} else { } else {
types.push_back (LogVarType::NEG_LV); types.push_back (LogVarType::negLvt);
} }
} }
res.push_back (types); res.push_back (types);
@ -989,13 +989,13 @@ LiftedCircuit::containsTypes (
const LogVarTypes& typesB) const const LogVarTypes& typesB) const
{ {
for (size_t i = 0; i < typesA.size(); i++) { for (size_t i = 0; i < typesA.size(); i++) {
if (typesA[i] == LogVarType::FULL_LV) { if (typesA[i] == LogVarType::fullLvt) {
} else if (typesA[i] == LogVarType::POS_LV } else if (typesA[i] == LogVarType::posLvt
&& typesB[i] == LogVarType::POS_LV) { && typesB[i] == LogVarType::posLvt) {
} else if (typesA[i] == LogVarType::NEG_LV } else if (typesA[i] == LogVarType::negLvt
&& typesB[i] == LogVarType::NEG_LV) { && typesB[i] == LogVarType::negLvt) {
} else { } else {
return false; return false;
@ -1009,25 +1009,25 @@ LiftedCircuit::containsTypes (
CircuitNodeType CircuitNodeType
LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
{ {
CircuitNodeType type = CircuitNodeType::OR_NODE; CircuitNodeType type = CircuitNodeType::orCnt;
if (dynamic_cast<const OrNode*>(node)) { if (dynamic_cast<const OrNode*>(node)) {
type = CircuitNodeType::OR_NODE; type = CircuitNodeType::orCnt;
} else if (dynamic_cast<const AndNode*>(node)) { } else if (dynamic_cast<const AndNode*>(node)) {
type = CircuitNodeType::AND_NODE; type = CircuitNodeType::andCnt;
} else if (dynamic_cast<const SetOrNode*>(node)) { } else if (dynamic_cast<const SetOrNode*>(node)) {
type = CircuitNodeType::SET_OR_NODE; type = CircuitNodeType::setOrCnt;
} else if (dynamic_cast<const SetAndNode*>(node)) { } else if (dynamic_cast<const SetAndNode*>(node)) {
type = CircuitNodeType::SET_AND_NODE; type = CircuitNodeType::setAndCnt;
} else if (dynamic_cast<const IncExcNode*>(node)) { } else if (dynamic_cast<const IncExcNode*>(node)) {
type = CircuitNodeType::INC_EXC_NODE; type = CircuitNodeType::incExcCnt;
} else if (dynamic_cast<const LeafNode*>(node)) { } else if (dynamic_cast<const LeafNode*>(node)) {
type = CircuitNodeType::LEAF_NODE; type = CircuitNodeType::leafCnt;
} else if (dynamic_cast<const SmoothNode*>(node)) { } else if (dynamic_cast<const SmoothNode*>(node)) {
type = CircuitNodeType::SMOOTH_NODE; type = CircuitNodeType::smoothCnt;
} else if (dynamic_cast<const TrueNode*>(node)) { } else if (dynamic_cast<const TrueNode*>(node)) {
type = CircuitNodeType::TRUE_NODE; type = CircuitNodeType::trueCnt;
} else if (dynamic_cast<const CompilationFailedNode*>(node)) { } else if (dynamic_cast<const CompilationFailedNode*>(node)) {
type = CircuitNodeType::COMPILATION_FAILED_NODE; type = CircuitNodeType::compilationFailedCnt;
} else { } else {
assert (false); assert (false);
} }
@ -1050,7 +1050,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
switch (getCircuitNodeType (node)) { switch (getCircuitNodeType (node)) {
case OR_NODE: { case orCnt: {
OrNode* casted = dynamic_cast<OrNode*>(node); OrNode* casted = dynamic_cast<OrNode*>(node);
printClauses (casted, os); printClauses (casted, os);
@ -1075,7 +1075,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
break; break;
} }
case AND_NODE: { case andCnt: {
AndNode* casted = dynamic_cast<AndNode*>(node); AndNode* casted = dynamic_cast<AndNode*>(node);
printClauses (casted, os); printClauses (casted, os);
@ -1100,7 +1100,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
break; break;
} }
case SET_OR_NODE: { case setOrCnt: {
SetOrNode* casted = dynamic_cast<SetOrNode*>(node); SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
printClauses (casted, os); printClauses (casted, os);
@ -1119,7 +1119,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
break; break;
} }
case SET_AND_NODE: { case setAndCnt: {
SetAndNode* casted = dynamic_cast<SetAndNode*>(node); SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
printClauses (casted, os); printClauses (casted, os);
@ -1138,7 +1138,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
break; break;
} }
case INC_EXC_NODE: { case incExcCnt: {
IncExcNode* casted = dynamic_cast<IncExcNode*>(node); IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
printClauses (casted, os); printClauses (casted, os);
@ -1169,24 +1169,24 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
break; break;
} }
case LEAF_NODE: { case leafCnt: {
printClauses (node, os, "style=filled,fillcolor=palegreen,"); printClauses (node, os, "style=filled,fillcolor=palegreen,");
break; break;
} }
case SMOOTH_NODE: { case smoothCnt: {
printClauses (node, os, "style=filled,fillcolor=lightblue,"); printClauses (node, os, "style=filled,fillcolor=lightblue,");
break; break;
} }
case TRUE_NODE: { case trueCnt: {
os << escapeNode (node); os << escapeNode (node);
os << " [shape=box,label=\"\"]" ; os << " [shape=box,label=\"\"]" ;
os << std::endl; os << std::endl;
break; break;
} }
case COMPILATION_FAILED_NODE: { case compilationFailedCnt: {
printClauses (node, os, "style=filled,fillcolor=salmon,"); printClauses (node, os, "style=filled,fillcolor=salmon,");
break; break;
} }
@ -1227,9 +1227,9 @@ LiftedCircuit::printClauses (
Clauses clauses; Clauses clauses;
if (Util::contains (originClausesMap_, node)) { if (Util::contains (originClausesMap_, node)) {
clauses = originClausesMap_[node]; clauses = originClausesMap_[node];
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) { } else if (getCircuitNodeType (node) == CircuitNodeType::leafCnt) {
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ; clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
} else if (getCircuitNodeType (node) == CircuitNodeType::SMOOTH_NODE) { } else if (getCircuitNodeType (node) == CircuitNodeType::smoothCnt) {
clauses = (dynamic_cast<SmoothNode*>(node))->clauses(); clauses = (dynamic_cast<SmoothNode*>(node))->clauses();
} }
assert (clauses.empty() == false); assert (clauses.empty() == false);

View File

@ -14,15 +14,15 @@
namespace Horus { namespace Horus {
enum CircuitNodeType { enum CircuitNodeType {
OR_NODE, orCnt,
AND_NODE, andCnt,
SET_OR_NODE, setOrCnt,
SET_AND_NODE, setAndCnt,
INC_EXC_NODE, incExcCnt,
LEAF_NODE, leafCnt,
SMOOTH_NODE, smoothCnt,
TRUE_NODE, trueCnt,
COMPILATION_FAILED_NODE compilationFailedCnt
}; };

View File

@ -149,7 +149,7 @@ LiftedOperations::absorveEvidence (
} }
if (Globals::verbosity > 2 && obsFormulas.empty() == false) { if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine(); Util::printAsteriskLine();
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl; std::cout << "AFTER EVIDENCE ABSORveSolverD" << std::endl;
for (size_t i = 0; i < obsFormulas.size(); i++) { for (size_t i = 0; i < obsFormulas.size(); i++) {
std::cout << " -> " << obsFormulas[i] << std::endl; std::cout << " -> " << obsFormulas[i] << std::endl;
} }

View File

@ -269,11 +269,11 @@ Clause::logVarTypes (size_t litIdx) const
const LogVars& lvs = literals_[litIdx].logVars(); const LogVars& lvs = literals_[litIdx].logVars();
for (size_t i = 0; i < lvs.size(); i++) { for (size_t i = 0; i < lvs.size(); i++) {
if (posCountedLvs_.contains (lvs[i])) { if (posCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::POS_LV); types.push_back (LogVarType::posLvt);
} else if (negCountedLvs_.contains (lvs[i])) { } else if (negCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::NEG_LV); types.push_back (LogVarType::negLvt);
} else { } else {
types.push_back (LogVarType::FULL_LV); types.push_back (LogVarType::fullLvt);
} }
} }
return types; return types;
@ -384,9 +384,9 @@ operator<< (std::ostream& os, const LitLvTypes& lit)
os << lit.lid_ << "<" ; os << lit.lid_ << "<" ;
for (size_t i = 0; i < lit.lvTypes_.size(); i++) { for (size_t i = 0; i < lit.lvTypes_.size(); i++) {
switch (lit.lvTypes_[i]) { switch (lit.lvTypes_[i]) {
case LogVarType::FULL_LV: os << "F" ; break; case LogVarType::fullLvt: os << "F" ; break;
case LogVarType::POS_LV: os << "P" ; break; case LogVarType::posLvt: os << "P" ; break;
case LogVarType::NEG_LV: os << "N" ; break; case LogVarType::negLvt: os << "N" ; break;
} }
} }
os << ">" ; os << ">" ;
@ -398,7 +398,7 @@ operator<< (std::ostream& os, const LitLvTypes& lit)
void void
LitLvTypes::setAllFullLogVars (void) LitLvTypes::setAllFullLogVars (void)
{ {
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::fullLvt);
} }

View File

@ -16,9 +16,9 @@ namespace Horus {
class ParfactorList; class ParfactorList;
enum LogVarType { enum LogVarType {
FULL_LV, fullLvt,
POS_LV, posLvt,
NEG_LV negLvt
}; };
typedef long LiteralId; typedef long LiteralId;

View File

@ -12,9 +12,9 @@ bool logDomain = false;
unsigned verbosity = 0; unsigned verbosity = 0;
LiftedSolverType liftedSolver = LiftedSolverType::LVE; LiftedSolverType liftedSolver = LiftedSolverType::lveSolver;
GroundSolverType groundSolver = GroundSolverType::VE; GroundSolverType groundSolver = GroundSolverType::veSolver;
} }
@ -203,15 +203,15 @@ setHorusFlag (std::string option, std::string value)
{ {
bool returnVal = true; bool returnVal = true;
if (option == "lifted_solver") { if (option == "lifted_solver") {
if (value == "lve") Globals::liftedSolver = LiftedSolverType::LVE; if (value == "lve") Globals::liftedSolver = LiftedSolverType::lveSolver;
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::LBP; else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::lbpSolver;
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::LKC; else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::lkcSolver;
else returnVal = invalidValue (option, value); else returnVal = invalidValue (option, value);
} else if (option == "ground_solver" || option == "solver") { } else if (option == "ground_solver" || option == "solver") {
if (value == "hve") Globals::groundSolver = GroundSolverType::VE; if (value == "hve") Globals::groundSolver = GroundSolverType::veSolver;
else if (value == "bp") Globals::groundSolver = GroundSolverType::BP; else if (value == "bp") Globals::groundSolver = GroundSolverType::bpSolver;
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CBP; else if (value == "cbp") Globals::groundSolver = GroundSolverType::CbpSolver;
else returnVal = invalidValue (option, value); else returnVal = invalidValue (option, value);
} else if (option == "verbosity") { } else if (option == "verbosity") {
@ -226,27 +226,27 @@ setHorusFlag (std::string option, std::string value)
} else if (option == "hve_elim_heuristic") { } else if (option == "hve_elim_heuristic") {
if (value == "sequential") if (value == "sequential")
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL); ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
else if (value == "min_neighbors") else if (value == "min_neighbors")
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS); ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
else if (value == "min_weight") else if (value == "min_weight")
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT); ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
else if (value == "min_fill") else if (value == "min_fill")
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL); ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
else if (value == "weighted_min_fill") else if (value == "weighted_min_fill")
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL); ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
else else
returnVal = invalidValue (option, value); returnVal = invalidValue (option, value);
} else if (option == "bp_msg_schedule") { } else if (option == "bp_msg_schedule") {
if (value == "seq_fixed") if (value == "seq_fixed")
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED); BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
else if (value == "seq_random") else if (value == "seq_random")
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM); BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
else if (value == "parallel") else if (value == "parallel")
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL); BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
else if (value == "max_residual") else if (value == "max_residual")
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL); BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
else else
returnVal = invalidValue (option, value); returnVal = invalidValue (option, value);

View File

@ -87,7 +87,7 @@ unsigned nrDigits (int);
bool isInteger (const std::string&); bool isInteger (const std::string&);
std::string parametersToString ( std::string parametersToString (
const Params&, unsigned = Constants::PRECISION); const Params&, unsigned = Constants::precision);
std::vector<std::string> getStateLines (const Vars&); std::vector<std::string> getStateLines (const Vars&);

View File

@ -24,7 +24,7 @@ class Var {
public: public:
Var (const Var*); Var (const Var*);
Var (VarId, unsigned, int = Constants::NO_EVIDENCE); Var (VarId, unsigned, int = Constants::unobserved);
virtual ~Var (void) { }; virtual ~Var (void) { };
@ -79,7 +79,7 @@ class Var {
inline bool inline bool
Var::hasEvidence (void) const Var::hasEvidence (void) const
{ {
return evidence_ != Constants::NO_EVIDENCE; return evidence_ != Constants::unobserved;
} }

View File

@ -43,11 +43,11 @@ VarElim::printSolverFlags (void) const
ss << "variable elimination [" ; ss << "variable elimination [" ;
ss << "elim_heuristic=" ; ss << "elim_heuristic=" ;
switch (ElimGraph::elimHeuristic()) { switch (ElimGraph::elimHeuristic()) {
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break; case ElimHeuristic::sequentialEh: ss << "sequential"; break;
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break; case ElimHeuristic::minNeighborsEh: ss << "min_neighbors"; break;
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break; case ElimHeuristic::minWeightEh: ss << "min_weight"; break;
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break; case ElimHeuristic::minFillEh: ss << "min_fill"; break;
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break; case ElimHeuristic::weightedMinFillEh: ss << "weighted_min_fill"; break;
} }
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;

View File

@ -177,13 +177,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]); const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (l->varNode() == dst && l->index() == link->index())) { if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ; std::cout << ": " ;
} }
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>()); reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
@ -193,13 +193,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]); const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (l->varNode() == dst && l->index() == link->index())) { if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ; std::cout << ": " ;
} }
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>()); reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
@ -214,7 +214,7 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
} else { } else {
result.params() *= src->factor().params(); result.params() *= src->factor().params();
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " message product: " ; std::cout << " message product: " ;
std::cout << msgProduct << std::endl; std::cout << msgProduct << std::endl;
std::cout << " original factor: " ; std::cout << " original factor: " ;
@ -223,13 +223,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
std::cout << result.params() << std::endl; std::cout << result.params() << std::endl;
} }
result.sumOutAllExceptIndex (link->index()); result.sumOutAllExceptIndex (link->index());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " marginalized: " ; std::cout << " marginalized: " ;
std::cout << result.params() << std::endl; std::cout << result.params() << std::endl;
} }
link->nextMessage() = result.params(); link->nextMessage() = result.params();
LogAware::normalize (link->nextMessage()); LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " curr msg: " ; std::cout << " curr msg: " ;
std::cout << link->message() << std::endl; std::cout << link->message() << std::endl;
std::cout << " next msg: " ; std::cout << " next msg: " ;
@ -249,14 +249,14 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
double value = link->message()[src->getEvidence()]; double value = link->message()[src->getEvidence()];
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
msg[src->getEvidence()] = value; msg[src->getEvidence()] = value;
std::cout << msg << "^" << link->weight() << "-1" ; std::cout << msg << "^" << link->weight() << "-1" ;
} }
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1); msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
} else { } else {
msg = link->message(); msg = link->message();
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << msg << "^" << link->weight() << "-1" ; std::cout << msg << "^" << link->weight() << "-1" ;
} }
LogAware::pow (msg, link->weight() - 1); LogAware::pow (msg, link->weight() - 1);
@ -274,13 +274,13 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
if ( ! (l->facNode() == dst && l->index() == link->index())) { if ( ! (l->facNode() == dst && l->index() == link->index())) {
msg *= l->powMessage(); msg *= l->powMessage();
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " x " << l->nextMessage() << "^" << link->weight(); std::cout << " x " << l->nextMessage() << "^" << link->weight();
} }
} }
} }
} }
if (Constants::SHOW_BP_CALCS) { if (Constants::showBpCalcs) {
std::cout << " = " << msg; std::cout << " = " << msg;
} }
return msg; return msg;