Use camel case for constants and enumerators.
All capitals case should be reserved for macros and besides there is no big need to emphasize constness in general.
This commit is contained in:
parent
afd26ed9b4
commit
ef4ebb4d7f
@ -62,7 +62,7 @@ BpLink::toString (void) const
|
||||
|
||||
double BeliefProp::accuracy_ = 0.0001;
|
||||
unsigned BeliefProp::maxIter_ = 1000;
|
||||
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
|
||||
MsgSchedule BeliefProp::schedule_ = MsgSchedule::seqFixedSch;
|
||||
|
||||
|
||||
|
||||
@ -106,10 +106,10 @@ BeliefProp::printSolverFlags (void) const
|
||||
ss << "belief propagation [" ;
|
||||
ss << "bp_msg_schedule=" ;
|
||||
switch (schedule_) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",bp_max_iter=" << Util::toString (maxIter_);
|
||||
ss << ",bp_accuracy=" << Util::toString (accuracy_);
|
||||
@ -259,15 +259,15 @@ BeliefProp::runSolver (void)
|
||||
+ Util::toString (nIters_));
|
||||
}
|
||||
switch (schedule_) {
|
||||
case MsgSchedule::SEQ_RANDOM:
|
||||
case MsgSchedule::seqRandomSch:
|
||||
std::random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
case MsgSchedule::SEQ_FIXED:
|
||||
case MsgSchedule::seqFixedSch:
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
case MsgSchedule::PARALLEL:
|
||||
case MsgSchedule::parallelSch:
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
@ -275,7 +275,7 @@ BeliefProp::runSolver (void)
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
case MsgSchedule::MAX_RESIDUAL:
|
||||
case MsgSchedule::maxResidualSch:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
@ -380,13 +380,13 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
if (links[i]->varNode() != dst) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::plus<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
@ -395,13 +395,13 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
} else {
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
if (links[i]->varNode() != dst) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::multiplies<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
@ -411,19 +411,19 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
result.multiply (src->factor());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message product: " << msgProduct << std::endl;
|
||||
std::cout << " original factor: " << src->factor().params();
|
||||
std::cout << std::endl;
|
||||
std::cout << " factor product: " << result.params() << std::endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " marginalized: " << result.params() << std::endl;
|
||||
}
|
||||
link->nextMessage() = result.params();
|
||||
LogAware::normalize (link->nextMessage());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " curr msg: " << link->message() << std::endl;
|
||||
std::cout << " next msg: " << link->nextMessage() << std::endl;
|
||||
}
|
||||
@ -442,7 +442,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
} else {
|
||||
msg.resize (src->range(), LogAware::one());
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << msg;
|
||||
}
|
||||
BpLinks::const_iterator it;
|
||||
@ -452,7 +452,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
if (*it != link) {
|
||||
msg += (*it)->message();
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " x " << (*it)->message();
|
||||
}
|
||||
}
|
||||
@ -461,12 +461,12 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
if (*it != link) {
|
||||
msg *= (*it)->message();
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " x " << (*it)->message();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
@ -478,7 +478,7 @@ Params
|
||||
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
return GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::BP, fg, jointVarIds);
|
||||
GroundSolverType::bpSolver, fg, jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
@ -526,7 +526,7 @@ BeliefProp::converged (void)
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (schedule_ == MsgSchedule::MAX_RESIDUAL) {
|
||||
if (schedule_ == MsgSchedule::maxResidualSch) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||
if (maxResidual > accuracy_) {
|
||||
converged = false;
|
||||
|
@ -12,10 +12,10 @@
|
||||
namespace Horus {
|
||||
|
||||
enum MsgSchedule {
|
||||
SEQ_FIXED,
|
||||
SEQ_RANDOM,
|
||||
PARALLEL,
|
||||
MAX_RESIDUAL
|
||||
seqFixedSch,
|
||||
seqRandomSch,
|
||||
parallelSch,
|
||||
maxResidualSch
|
||||
};
|
||||
|
||||
|
||||
|
@ -45,10 +45,10 @@ CountingBp::printSolverFlags (void) const
|
||||
ss << "counting bp [" ;
|
||||
ss << "bp_msg_schedule=" ;
|
||||
switch (WeightedBp::msgSchedule()) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||
@ -79,7 +79,7 @@ CountingBp::solveQuery (VarIds queryVids)
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
res = GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::CBP, fg, queryVids);
|
||||
GroundSolverType::CbpSolver, fg, queryVids);
|
||||
} else {
|
||||
VarIds reprArgs;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
|
@ -6,7 +6,7 @@
|
||||
|
||||
namespace Horus {
|
||||
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = minNeighborsEh;
|
||||
|
||||
|
||||
ElimGraph::ElimGraph (const std::vector<Factor*>& factors)
|
||||
@ -137,7 +137,7 @@ ElimGraph::getEliminationOrder (
|
||||
const Factors& factors,
|
||||
VarIds excludedVids)
|
||||
{
|
||||
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) {
|
||||
if (elimHeuristic_ == ElimHeuristic::sequentialEh) {
|
||||
VarIds allVids;
|
||||
Factors::const_iterator first = factors.begin();
|
||||
Factors::const_iterator end = factors.end();
|
||||
@ -181,7 +181,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
unsigned minCost = Util::maxUnsigned();
|
||||
EGNeighs::const_iterator it;
|
||||
switch (elimHeuristic_) {
|
||||
case MIN_NEIGHBORS: {
|
||||
case minNeighborsEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getNeighborsCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -190,7 +190,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
}
|
||||
}}
|
||||
break;
|
||||
case MIN_WEIGHT: {
|
||||
case minWeightEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getWeightCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -199,7 +199,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
}
|
||||
}}
|
||||
break;
|
||||
case MIN_FILL: {
|
||||
case minFillEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getFillCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -208,7 +208,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
}
|
||||
}}
|
||||
break;
|
||||
case WEIGHTED_MIN_FILL: {
|
||||
case weightedMinFillEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getWeightedFillCost (*it);
|
||||
if (cost < minCost) {
|
||||
|
@ -19,11 +19,11 @@ typedef TinySet<EgNode*> EGNeighs;
|
||||
|
||||
|
||||
enum ElimHeuristic {
|
||||
SEQUENTIAL,
|
||||
MIN_NEIGHBORS,
|
||||
MIN_WEIGHT,
|
||||
MIN_FILL,
|
||||
WEIGHTED_MIN_FILL
|
||||
sequentialEh,
|
||||
minNeighborsEh,
|
||||
minWeightEh,
|
||||
minFillEh,
|
||||
weightedMinFillEh
|
||||
};
|
||||
|
||||
|
||||
|
@ -19,7 +19,7 @@ class FacNode;
|
||||
class VarNode : public Var {
|
||||
public:
|
||||
VarNode (VarId varId, unsigned nrStates,
|
||||
int evidence = Constants::NO_EVIDENCE)
|
||||
int evidence = Constants::unobserved)
|
||||
: Var (varId, nrStates, evidence) { }
|
||||
|
||||
VarNode (const Var* v) : Var (v) { }
|
||||
|
@ -31,7 +31,7 @@ GroundSolver::printAnswer (const VarIds& vids)
|
||||
Util::getStateLines (unobservedVars);
|
||||
for (size_t i = 0; i < res.size(); i++) {
|
||||
std::cout << "P(" << stateLines[i] << ") = " ;
|
||||
std::cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||
std::cout << std::setprecision (Constants::precision) << res[i];
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
@ -66,9 +66,9 @@ GroundSolver::getJointByConditioning (
|
||||
|
||||
GroundSolver* solver = 0;
|
||||
switch (solverType) {
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
@ -89,9 +89,9 @@ GroundSolver::getJointByConditioning (
|
||||
}
|
||||
delete solver;
|
||||
switch (solverType) {
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
|
@ -35,16 +35,16 @@ typedef unsigned long long ullong;
|
||||
|
||||
|
||||
enum LiftedSolverType {
|
||||
LVE, // generalized counting first-order variable elimination (GC-FOVE)
|
||||
LBP, // lifted first-order belief propagation
|
||||
LKC // lifted first-order knowledge compilation
|
||||
lveSolver, // generalized counting first-order variable elimination (GC-FOveSolver)
|
||||
lbpSolver, // lifted first-order belief propagation
|
||||
lkcSolver // lifted first-order knowledge compilation
|
||||
};
|
||||
|
||||
|
||||
enum GroundSolverType {
|
||||
VE, // variable elimination
|
||||
BP, // belief propagation
|
||||
CBP // counting belief propagation
|
||||
veSolver, // variable elimination
|
||||
bpSolver, // belief propagation
|
||||
CbpSolver // counting belief propagation
|
||||
};
|
||||
|
||||
|
||||
@ -64,12 +64,12 @@ extern GroundSolverType groundSolver;
|
||||
namespace Constants {
|
||||
|
||||
// show message calculation for belief propagation
|
||||
const bool SHOW_BP_CALCS = false;
|
||||
const bool showBpCalcs = false;
|
||||
|
||||
const int NO_EVIDENCE = -1;
|
||||
const int unobserved = -1;
|
||||
|
||||
// number of digits to show when printing a parameter
|
||||
const unsigned PRECISION = 6;
|
||||
const unsigned precision = 6;
|
||||
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,7 @@ Horus::VarIds readQueryAndEvidence (
|
||||
|
||||
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
|
||||
|
||||
const std::string USAGE = "usage: ./hcli [solver=hve|bp|cbp] \
|
||||
const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
|
||||
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ main (int argc, const char* argv[])
|
||||
{
|
||||
if (argc <= 1) {
|
||||
std::cerr << "Error: no probabilistic graphical model was given." ;
|
||||
std::cerr << std::endl << USAGE << std::endl;
|
||||
std::cerr << std::endl << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
int idx = readHorusFlags (argc, argv);
|
||||
@ -72,12 +72,12 @@ readHorusFlags (int argc, const char* argv[])
|
||||
std::string rightArg = arg.substr (pos + 1);
|
||||
if (leftArg.empty()) {
|
||||
std::cerr << "Error: missing left argument." << std::endl;
|
||||
std::cerr << USAGE << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (rightArg.empty()) {
|
||||
std::cerr << "Error: missing right argument." << std::endl;
|
||||
std::cerr << USAGE << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
Horus::Util::setHorusFlag (leftArg, rightArg);
|
||||
@ -136,7 +136,7 @@ readQueryAndEvidence (
|
||||
std::string rightArg = arg.substr (pos + 1);
|
||||
if (leftArg.empty()) {
|
||||
std::cerr << "Error: missing left argument." << std::endl;
|
||||
std::cerr << USAGE << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (Horus::Util::isInteger (leftArg) == false) {
|
||||
@ -153,7 +153,7 @@ readQueryAndEvidence (
|
||||
}
|
||||
if (rightArg.empty()) {
|
||||
std::cerr << "Error: missing right argument." << std::endl;
|
||||
std::cerr << USAGE << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (Horus::Util::isInteger (rightArg) == false) {
|
||||
@ -183,13 +183,13 @@ runSolver (
|
||||
{
|
||||
Horus::GroundSolver* solver = 0;
|
||||
switch (Horus::Globals::groundSolver) {
|
||||
case Horus::GroundSolverType::VE:
|
||||
case Horus::GroundSolverType::veSolver:
|
||||
solver = new Horus::VarElim (fg);
|
||||
break;
|
||||
case Horus::GroundSolverType::BP:
|
||||
case Horus::GroundSolverType::bpSolver:
|
||||
solver = new Horus::BeliefProp (fg);
|
||||
break;
|
||||
case Horus::GroundSolverType::CBP:
|
||||
case Horus::GroundSolverType::CbpSolver:
|
||||
solver = new Horus::CountingBp (fg);
|
||||
break;
|
||||
default:
|
||||
|
@ -140,9 +140,9 @@ runLiftedSolver (void)
|
||||
|
||||
LiftedSolver* solver = 0;
|
||||
switch (Globals::liftedSolver) {
|
||||
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
|
||||
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
|
||||
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
|
||||
case LiftedSolverType::lveSolver: solver = new LiftedVe (pfListCopy); break;
|
||||
case LiftedSolverType::lbpSolver: solver = new LiftedBp (pfListCopy); break;
|
||||
case LiftedSolverType::lkcSolver: solver = new LiftedKc (pfListCopy); break;
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
@ -214,9 +214,9 @@ runGroundSolver (void)
|
||||
GroundSolver* solver = 0;
|
||||
CountingBp::setFindIdenticalFactorsFlag (false);
|
||||
switch (Globals::groundSolver) {
|
||||
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
|
||||
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
|
||||
case GroundSolverType::veSolver: solver = new VarElim (*mfg); break;
|
||||
case GroundSolverType::bpSolver: solver = new BeliefProp (*mfg); break;
|
||||
case GroundSolverType::CbpSolver: solver = new CountingBp (*mfg); break;
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
|
@ -70,10 +70,10 @@ LiftedBp::printSolverFlags (void) const
|
||||
ss << "lifted bp [" ;
|
||||
ss << "bp_msg_schedule=" ;
|
||||
switch (WeightedBp::msgSchedule()) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||
|
@ -815,7 +815,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
|
||||
switch (getCircuitNodeType (node)) {
|
||||
|
||||
case CircuitNodeType::OR_NODE: {
|
||||
case CircuitNodeType::orCnt: {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||
@ -828,7 +828,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::AND_NODE: {
|
||||
case CircuitNodeType::andCnt: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||
@ -837,7 +837,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::SET_OR_NODE: {
|
||||
case CircuitNodeType::setOrCnt: {
|
||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||
propagLits = smoothCircuit (*casted->follow());
|
||||
TinySet<std::pair<LiteralId,unsigned>> litSet;
|
||||
@ -875,13 +875,13 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::SET_AND_NODE: {
|
||||
case CircuitNodeType::setAndCnt: {
|
||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||
propagLits = smoothCircuit (*casted->follow());
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::INC_EXC_NODE: {
|
||||
case CircuitNodeType::incExcCnt: {
|
||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
|
||||
@ -894,7 +894,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::LEAF_NODE: {
|
||||
case CircuitNodeType::leafCnt: {
|
||||
LeafNode* casted = dynamic_cast<LeafNode*>(node);
|
||||
propagLits.insert (LitLvTypes (
|
||||
casted->clause()->literals()[0].lid(),
|
||||
@ -933,9 +933,9 @@ LiftedCircuit::createSmoothNode (
|
||||
Clause* c = lwcnf_->createClause (lid);
|
||||
for (size_t j = 0; j < types.size(); j++) {
|
||||
LogVar X = c->literals().front().logVars()[j];
|
||||
if (types[j] == LogVarType::POS_LV) {
|
||||
if (types[j] == LogVarType::posLvt) {
|
||||
c->addPosCountedLogVar (X);
|
||||
} else if (types[j] == LogVarType::NEG_LV) {
|
||||
} else if (types[j] == LogVarType::negLvt) {
|
||||
c->addNegCountedLogVar (X);
|
||||
}
|
||||
}
|
||||
@ -960,8 +960,8 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
|
||||
if (nrLogVars == 0) {
|
||||
// do nothing
|
||||
} else if (nrLogVars == 1) {
|
||||
res.push_back ({ LogVarType::POS_LV });
|
||||
res.push_back ({ LogVarType::NEG_LV });
|
||||
res.push_back ({ LogVarType::posLvt });
|
||||
res.push_back ({ LogVarType::negLvt });
|
||||
} else {
|
||||
Ranges ranges (nrLogVars, 2);
|
||||
Indexer indexer (ranges);
|
||||
@ -969,9 +969,9 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
|
||||
LogVarTypes types;
|
||||
for (size_t i = 0; i < nrLogVars; i++) {
|
||||
if (indexer[i] == 0) {
|
||||
types.push_back (LogVarType::POS_LV);
|
||||
types.push_back (LogVarType::posLvt);
|
||||
} else {
|
||||
types.push_back (LogVarType::NEG_LV);
|
||||
types.push_back (LogVarType::negLvt);
|
||||
}
|
||||
}
|
||||
res.push_back (types);
|
||||
@ -989,13 +989,13 @@ LiftedCircuit::containsTypes (
|
||||
const LogVarTypes& typesB) const
|
||||
{
|
||||
for (size_t i = 0; i < typesA.size(); i++) {
|
||||
if (typesA[i] == LogVarType::FULL_LV) {
|
||||
if (typesA[i] == LogVarType::fullLvt) {
|
||||
|
||||
} else if (typesA[i] == LogVarType::POS_LV
|
||||
&& typesB[i] == LogVarType::POS_LV) {
|
||||
} else if (typesA[i] == LogVarType::posLvt
|
||||
&& typesB[i] == LogVarType::posLvt) {
|
||||
|
||||
} else if (typesA[i] == LogVarType::NEG_LV
|
||||
&& typesB[i] == LogVarType::NEG_LV) {
|
||||
} else if (typesA[i] == LogVarType::negLvt
|
||||
&& typesB[i] == LogVarType::negLvt) {
|
||||
|
||||
} else {
|
||||
return false;
|
||||
@ -1009,25 +1009,25 @@ LiftedCircuit::containsTypes (
|
||||
CircuitNodeType
|
||||
LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
|
||||
{
|
||||
CircuitNodeType type = CircuitNodeType::OR_NODE;
|
||||
CircuitNodeType type = CircuitNodeType::orCnt;
|
||||
if (dynamic_cast<const OrNode*>(node)) {
|
||||
type = CircuitNodeType::OR_NODE;
|
||||
type = CircuitNodeType::orCnt;
|
||||
} else if (dynamic_cast<const AndNode*>(node)) {
|
||||
type = CircuitNodeType::AND_NODE;
|
||||
type = CircuitNodeType::andCnt;
|
||||
} else if (dynamic_cast<const SetOrNode*>(node)) {
|
||||
type = CircuitNodeType::SET_OR_NODE;
|
||||
type = CircuitNodeType::setOrCnt;
|
||||
} else if (dynamic_cast<const SetAndNode*>(node)) {
|
||||
type = CircuitNodeType::SET_AND_NODE;
|
||||
type = CircuitNodeType::setAndCnt;
|
||||
} else if (dynamic_cast<const IncExcNode*>(node)) {
|
||||
type = CircuitNodeType::INC_EXC_NODE;
|
||||
type = CircuitNodeType::incExcCnt;
|
||||
} else if (dynamic_cast<const LeafNode*>(node)) {
|
||||
type = CircuitNodeType::LEAF_NODE;
|
||||
type = CircuitNodeType::leafCnt;
|
||||
} else if (dynamic_cast<const SmoothNode*>(node)) {
|
||||
type = CircuitNodeType::SMOOTH_NODE;
|
||||
type = CircuitNodeType::smoothCnt;
|
||||
} else if (dynamic_cast<const TrueNode*>(node)) {
|
||||
type = CircuitNodeType::TRUE_NODE;
|
||||
type = CircuitNodeType::trueCnt;
|
||||
} else if (dynamic_cast<const CompilationFailedNode*>(node)) {
|
||||
type = CircuitNodeType::COMPILATION_FAILED_NODE;
|
||||
type = CircuitNodeType::compilationFailedCnt;
|
||||
} else {
|
||||
assert (false);
|
||||
}
|
||||
@ -1050,7 +1050,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
||||
|
||||
switch (getCircuitNodeType (node)) {
|
||||
|
||||
case OR_NODE: {
|
||||
case orCnt: {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
@ -1075,7 +1075,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
||||
break;
|
||||
}
|
||||
|
||||
case AND_NODE: {
|
||||
case andCnt: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
@ -1100,7 +1100,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
||||
break;
|
||||
}
|
||||
|
||||
case SET_OR_NODE: {
|
||||
case setOrCnt: {
|
||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
@ -1119,7 +1119,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
||||
break;
|
||||
}
|
||||
|
||||
case SET_AND_NODE: {
|
||||
case setAndCnt: {
|
||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
@ -1138,7 +1138,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
||||
break;
|
||||
}
|
||||
|
||||
case INC_EXC_NODE: {
|
||||
case incExcCnt: {
|
||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
@ -1169,24 +1169,24 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
||||
break;
|
||||
}
|
||||
|
||||
case LEAF_NODE: {
|
||||
case leafCnt: {
|
||||
printClauses (node, os, "style=filled,fillcolor=palegreen,");
|
||||
break;
|
||||
}
|
||||
|
||||
case SMOOTH_NODE: {
|
||||
case smoothCnt: {
|
||||
printClauses (node, os, "style=filled,fillcolor=lightblue,");
|
||||
break;
|
||||
}
|
||||
|
||||
case TRUE_NODE: {
|
||||
case trueCnt: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"⊤\"]" ;
|
||||
os << std::endl;
|
||||
break;
|
||||
}
|
||||
|
||||
case COMPILATION_FAILED_NODE: {
|
||||
case compilationFailedCnt: {
|
||||
printClauses (node, os, "style=filled,fillcolor=salmon,");
|
||||
break;
|
||||
}
|
||||
@ -1227,9 +1227,9 @@ LiftedCircuit::printClauses (
|
||||
Clauses clauses;
|
||||
if (Util::contains (originClausesMap_, node)) {
|
||||
clauses = originClausesMap_[node];
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) {
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::leafCnt) {
|
||||
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::SMOOTH_NODE) {
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::smoothCnt) {
|
||||
clauses = (dynamic_cast<SmoothNode*>(node))->clauses();
|
||||
}
|
||||
assert (clauses.empty() == false);
|
||||
|
@ -14,15 +14,15 @@
|
||||
namespace Horus {
|
||||
|
||||
enum CircuitNodeType {
|
||||
OR_NODE,
|
||||
AND_NODE,
|
||||
SET_OR_NODE,
|
||||
SET_AND_NODE,
|
||||
INC_EXC_NODE,
|
||||
LEAF_NODE,
|
||||
SMOOTH_NODE,
|
||||
TRUE_NODE,
|
||||
COMPILATION_FAILED_NODE
|
||||
orCnt,
|
||||
andCnt,
|
||||
setOrCnt,
|
||||
setAndCnt,
|
||||
incExcCnt,
|
||||
leafCnt,
|
||||
smoothCnt,
|
||||
trueCnt,
|
||||
compilationFailedCnt
|
||||
};
|
||||
|
||||
|
||||
|
@ -149,7 +149,7 @@ LiftedOperations::absorveEvidence (
|
||||
}
|
||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
|
||||
std::cout << "AFTER EVIDENCE ABSORveSolverD" << std::endl;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
std::cout << " -> " << obsFormulas[i] << std::endl;
|
||||
}
|
||||
|
@ -269,11 +269,11 @@ Clause::logVarTypes (size_t litIdx) const
|
||||
const LogVars& lvs = literals_[litIdx].logVars();
|
||||
for (size_t i = 0; i < lvs.size(); i++) {
|
||||
if (posCountedLvs_.contains (lvs[i])) {
|
||||
types.push_back (LogVarType::POS_LV);
|
||||
types.push_back (LogVarType::posLvt);
|
||||
} else if (negCountedLvs_.contains (lvs[i])) {
|
||||
types.push_back (LogVarType::NEG_LV);
|
||||
types.push_back (LogVarType::negLvt);
|
||||
} else {
|
||||
types.push_back (LogVarType::FULL_LV);
|
||||
types.push_back (LogVarType::fullLvt);
|
||||
}
|
||||
}
|
||||
return types;
|
||||
@ -384,9 +384,9 @@ operator<< (std::ostream& os, const LitLvTypes& lit)
|
||||
os << lit.lid_ << "<" ;
|
||||
for (size_t i = 0; i < lit.lvTypes_.size(); i++) {
|
||||
switch (lit.lvTypes_[i]) {
|
||||
case LogVarType::FULL_LV: os << "F" ; break;
|
||||
case LogVarType::POS_LV: os << "P" ; break;
|
||||
case LogVarType::NEG_LV: os << "N" ; break;
|
||||
case LogVarType::fullLvt: os << "F" ; break;
|
||||
case LogVarType::posLvt: os << "P" ; break;
|
||||
case LogVarType::negLvt: os << "N" ; break;
|
||||
}
|
||||
}
|
||||
os << ">" ;
|
||||
@ -398,7 +398,7 @@ operator<< (std::ostream& os, const LitLvTypes& lit)
|
||||
void
|
||||
LitLvTypes::setAllFullLogVars (void)
|
||||
{
|
||||
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV);
|
||||
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::fullLvt);
|
||||
}
|
||||
|
||||
|
||||
|
@ -16,9 +16,9 @@ namespace Horus {
|
||||
class ParfactorList;
|
||||
|
||||
enum LogVarType {
|
||||
FULL_LV,
|
||||
POS_LV,
|
||||
NEG_LV
|
||||
fullLvt,
|
||||
posLvt,
|
||||
negLvt
|
||||
};
|
||||
|
||||
typedef long LiteralId;
|
||||
|
@ -12,9 +12,9 @@ bool logDomain = false;
|
||||
|
||||
unsigned verbosity = 0;
|
||||
|
||||
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
|
||||
LiftedSolverType liftedSolver = LiftedSolverType::lveSolver;
|
||||
|
||||
GroundSolverType groundSolver = GroundSolverType::VE;
|
||||
GroundSolverType groundSolver = GroundSolverType::veSolver;
|
||||
|
||||
}
|
||||
|
||||
@ -203,15 +203,15 @@ setHorusFlag (std::string option, std::string value)
|
||||
{
|
||||
bool returnVal = true;
|
||||
if (option == "lifted_solver") {
|
||||
if (value == "lve") Globals::liftedSolver = LiftedSolverType::LVE;
|
||||
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::LBP;
|
||||
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::LKC;
|
||||
if (value == "lve") Globals::liftedSolver = LiftedSolverType::lveSolver;
|
||||
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::lbpSolver;
|
||||
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::lkcSolver;
|
||||
else returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "ground_solver" || option == "solver") {
|
||||
if (value == "hve") Globals::groundSolver = GroundSolverType::VE;
|
||||
else if (value == "bp") Globals::groundSolver = GroundSolverType::BP;
|
||||
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CBP;
|
||||
if (value == "hve") Globals::groundSolver = GroundSolverType::veSolver;
|
||||
else if (value == "bp") Globals::groundSolver = GroundSolverType::bpSolver;
|
||||
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CbpSolver;
|
||||
else returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "verbosity") {
|
||||
@ -226,27 +226,27 @@ setHorusFlag (std::string option, std::string value)
|
||||
|
||||
} else if (option == "hve_elim_heuristic") {
|
||||
if (value == "sequential")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
|
||||
else if (value == "min_neighbors")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
|
||||
else if (value == "min_weight")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
|
||||
else if (value == "min_fill")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
|
||||
else if (value == "weighted_min_fill")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "bp_msg_schedule") {
|
||||
if (value == "seq_fixed")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
|
||||
else if (value == "seq_random")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
|
||||
else if (value == "parallel")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
|
||||
else if (value == "max_residual")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
|
@ -87,7 +87,7 @@ unsigned nrDigits (int);
|
||||
bool isInteger (const std::string&);
|
||||
|
||||
std::string parametersToString (
|
||||
const Params&, unsigned = Constants::PRECISION);
|
||||
const Params&, unsigned = Constants::precision);
|
||||
|
||||
std::vector<std::string> getStateLines (const Vars&);
|
||||
|
||||
|
@ -24,7 +24,7 @@ class Var {
|
||||
public:
|
||||
Var (const Var*);
|
||||
|
||||
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||
Var (VarId, unsigned, int = Constants::unobserved);
|
||||
|
||||
virtual ~Var (void) { };
|
||||
|
||||
@ -79,7 +79,7 @@ class Var {
|
||||
inline bool
|
||||
Var::hasEvidence (void) const
|
||||
{
|
||||
return evidence_ != Constants::NO_EVIDENCE;
|
||||
return evidence_ != Constants::unobserved;
|
||||
}
|
||||
|
||||
|
||||
|
@ -43,11 +43,11 @@ VarElim::printSolverFlags (void) const
|
||||
ss << "variable elimination [" ;
|
||||
ss << "elim_heuristic=" ;
|
||||
switch (ElimGraph::elimHeuristic()) {
|
||||
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break;
|
||||
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
||||
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break;
|
||||
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
|
||||
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
||||
case ElimHeuristic::sequentialEh: ss << "sequential"; break;
|
||||
case ElimHeuristic::minNeighborsEh: ss << "min_neighbors"; break;
|
||||
case ElimHeuristic::minWeightEh: ss << "min_weight"; break;
|
||||
case ElimHeuristic::minFillEh: ss << "min_fill"; break;
|
||||
case ElimHeuristic::weightedMinFillEh: ss << "weighted_min_fill"; break;
|
||||
}
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
|
@ -177,13 +177,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::plus<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
@ -193,13 +193,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::multiplies<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
@ -214,7 +214,7 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||
} else {
|
||||
result.params() *= src->factor().params();
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message product: " ;
|
||||
std::cout << msgProduct << std::endl;
|
||||
std::cout << " original factor: " ;
|
||||
@ -223,13 +223,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||
std::cout << result.params() << std::endl;
|
||||
}
|
||||
result.sumOutAllExceptIndex (link->index());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " marginalized: " ;
|
||||
std::cout << result.params() << std::endl;
|
||||
}
|
||||
link->nextMessage() = result.params();
|
||||
LogAware::normalize (link->nextMessage());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " curr msg: " ;
|
||||
std::cout << link->message() << std::endl;
|
||||
std::cout << " next msg: " ;
|
||||
@ -249,14 +249,14 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
double value = link->message()[src->getEvidence()];
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
msg[src->getEvidence()] = value;
|
||||
std::cout << msg << "^" << link->weight() << "-1" ;
|
||||
}
|
||||
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
|
||||
} else {
|
||||
msg = link->message();
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << msg << "^" << link->weight() << "-1" ;
|
||||
}
|
||||
LogAware::pow (msg, link->weight() - 1);
|
||||
@ -274,13 +274,13 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||
msg *= l->powMessage();
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " x " << l->nextMessage() << "^" << link->weight();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
|
Reference in New Issue
Block a user