Use camel case for constants and enumerators.
All capitals case should be reserved for macros and besides there is no big need to emphasize constness in general.
This commit is contained in:
parent
afd26ed9b4
commit
ef4ebb4d7f
@ -62,7 +62,7 @@ BpLink::toString (void) const
|
|||||||
|
|
||||||
double BeliefProp::accuracy_ = 0.0001;
|
double BeliefProp::accuracy_ = 0.0001;
|
||||||
unsigned BeliefProp::maxIter_ = 1000;
|
unsigned BeliefProp::maxIter_ = 1000;
|
||||||
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
|
MsgSchedule BeliefProp::schedule_ = MsgSchedule::seqFixedSch;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -106,10 +106,10 @@ BeliefProp::printSolverFlags (void) const
|
|||||||
ss << "belief propagation [" ;
|
ss << "belief propagation [" ;
|
||||||
ss << "bp_msg_schedule=" ;
|
ss << "bp_msg_schedule=" ;
|
||||||
switch (schedule_) {
|
switch (schedule_) {
|
||||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||||
}
|
}
|
||||||
ss << ",bp_max_iter=" << Util::toString (maxIter_);
|
ss << ",bp_max_iter=" << Util::toString (maxIter_);
|
||||||
ss << ",bp_accuracy=" << Util::toString (accuracy_);
|
ss << ",bp_accuracy=" << Util::toString (accuracy_);
|
||||||
@ -259,15 +259,15 @@ BeliefProp::runSolver (void)
|
|||||||
+ Util::toString (nIters_));
|
+ Util::toString (nIters_));
|
||||||
}
|
}
|
||||||
switch (schedule_) {
|
switch (schedule_) {
|
||||||
case MsgSchedule::SEQ_RANDOM:
|
case MsgSchedule::seqRandomSch:
|
||||||
std::random_shuffle (links_.begin(), links_.end());
|
std::random_shuffle (links_.begin(), links_.end());
|
||||||
// no break
|
// no break
|
||||||
case MsgSchedule::SEQ_FIXED:
|
case MsgSchedule::seqFixedSch:
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
calculateAndUpdateMessage (links_[i]);
|
calculateAndUpdateMessage (links_[i]);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case MsgSchedule::PARALLEL:
|
case MsgSchedule::parallelSch:
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
calculateMessage (links_[i]);
|
calculateMessage (links_[i]);
|
||||||
}
|
}
|
||||||
@ -275,7 +275,7 @@ BeliefProp::runSolver (void)
|
|||||||
updateMessage(links_[i]);
|
updateMessage(links_[i]);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case MsgSchedule::MAX_RESIDUAL:
|
case MsgSchedule::maxResidualSch:
|
||||||
maxResidualSchedule();
|
maxResidualSchedule();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -380,13 +380,13 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
|||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
if (links[i]->varNode() != dst) {
|
if (links[i]->varNode() != dst) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " message from " << links[i]->varNode()->label();
|
std::cout << " message from " << links[i]->varNode()->label();
|
||||||
std::cout << ": " ;
|
std::cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
reps, std::plus<double>());
|
reps, std::plus<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -395,13 +395,13 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
|||||||
} else {
|
} else {
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
if (links[i]->varNode() != dst) {
|
if (links[i]->varNode() != dst) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " message from " << links[i]->varNode()->label();
|
std::cout << " message from " << links[i]->varNode()->label();
|
||||||
std::cout << ": " ;
|
std::cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
reps, std::multiplies<double>());
|
reps, std::multiplies<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -411,19 +411,19 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
|||||||
Factor result (src->factor().arguments(),
|
Factor result (src->factor().arguments(),
|
||||||
src->factor().ranges(), msgProduct);
|
src->factor().ranges(), msgProduct);
|
||||||
result.multiply (src->factor());
|
result.multiply (src->factor());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " message product: " << msgProduct << std::endl;
|
std::cout << " message product: " << msgProduct << std::endl;
|
||||||
std::cout << " original factor: " << src->factor().params();
|
std::cout << " original factor: " << src->factor().params();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
std::cout << " factor product: " << result.params() << std::endl;
|
std::cout << " factor product: " << result.params() << std::endl;
|
||||||
}
|
}
|
||||||
result.sumOutAllExcept (dst->varId());
|
result.sumOutAllExcept (dst->varId());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " marginalized: " << result.params() << std::endl;
|
std::cout << " marginalized: " << result.params() << std::endl;
|
||||||
}
|
}
|
||||||
link->nextMessage() = result.params();
|
link->nextMessage() = result.params();
|
||||||
LogAware::normalize (link->nextMessage());
|
LogAware::normalize (link->nextMessage());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " curr msg: " << link->message() << std::endl;
|
std::cout << " curr msg: " << link->message() << std::endl;
|
||||||
std::cout << " next msg: " << link->nextMessage() << std::endl;
|
std::cout << " next msg: " << link->nextMessage() << std::endl;
|
||||||
}
|
}
|
||||||
@ -442,7 +442,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
|||||||
} else {
|
} else {
|
||||||
msg.resize (src->range(), LogAware::one());
|
msg.resize (src->range(), LogAware::one());
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << msg;
|
std::cout << msg;
|
||||||
}
|
}
|
||||||
BpLinks::const_iterator it;
|
BpLinks::const_iterator it;
|
||||||
@ -452,7 +452,7 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
|||||||
if (*it != link) {
|
if (*it != link) {
|
||||||
msg += (*it)->message();
|
msg += (*it)->message();
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " x " << (*it)->message();
|
std::cout << " x " << (*it)->message();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -461,12 +461,12 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
|||||||
if (*it != link) {
|
if (*it != link) {
|
||||||
msg *= (*it)->message();
|
msg *= (*it)->message();
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " x " << (*it)->message();
|
std::cout << " x " << (*it)->message();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " = " << msg;
|
std::cout << " = " << msg;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
@ -478,7 +478,7 @@ Params
|
|||||||
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
{
|
{
|
||||||
return GroundSolver::getJointByConditioning (
|
return GroundSolver::getJointByConditioning (
|
||||||
GroundSolverType::BP, fg, jointVarIds);
|
GroundSolverType::bpSolver, fg, jointVarIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -526,7 +526,7 @@ BeliefProp::converged (void)
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool converged = true;
|
bool converged = true;
|
||||||
if (schedule_ == MsgSchedule::MAX_RESIDUAL) {
|
if (schedule_ == MsgSchedule::maxResidualSch) {
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||||
if (maxResidual > accuracy_) {
|
if (maxResidual > accuracy_) {
|
||||||
converged = false;
|
converged = false;
|
||||||
|
@ -12,10 +12,10 @@
|
|||||||
namespace Horus {
|
namespace Horus {
|
||||||
|
|
||||||
enum MsgSchedule {
|
enum MsgSchedule {
|
||||||
SEQ_FIXED,
|
seqFixedSch,
|
||||||
SEQ_RANDOM,
|
seqRandomSch,
|
||||||
PARALLEL,
|
parallelSch,
|
||||||
MAX_RESIDUAL
|
maxResidualSch
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,10 +45,10 @@ CountingBp::printSolverFlags (void) const
|
|||||||
ss << "counting bp [" ;
|
ss << "counting bp [" ;
|
||||||
ss << "bp_msg_schedule=" ;
|
ss << "bp_msg_schedule=" ;
|
||||||
switch (WeightedBp::msgSchedule()) {
|
switch (WeightedBp::msgSchedule()) {
|
||||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||||
}
|
}
|
||||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||||
@ -79,7 +79,7 @@ CountingBp::solveQuery (VarIds queryVids)
|
|||||||
}
|
}
|
||||||
if (idx == facNodes.size()) {
|
if (idx == facNodes.size()) {
|
||||||
res = GroundSolver::getJointByConditioning (
|
res = GroundSolver::getJointByConditioning (
|
||||||
GroundSolverType::CBP, fg, queryVids);
|
GroundSolverType::CbpSolver, fg, queryVids);
|
||||||
} else {
|
} else {
|
||||||
VarIds reprArgs;
|
VarIds reprArgs;
|
||||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
namespace Horus {
|
namespace Horus {
|
||||||
|
|
||||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
ElimHeuristic ElimGraph::elimHeuristic_ = minNeighborsEh;
|
||||||
|
|
||||||
|
|
||||||
ElimGraph::ElimGraph (const std::vector<Factor*>& factors)
|
ElimGraph::ElimGraph (const std::vector<Factor*>& factors)
|
||||||
@ -137,7 +137,7 @@ ElimGraph::getEliminationOrder (
|
|||||||
const Factors& factors,
|
const Factors& factors,
|
||||||
VarIds excludedVids)
|
VarIds excludedVids)
|
||||||
{
|
{
|
||||||
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) {
|
if (elimHeuristic_ == ElimHeuristic::sequentialEh) {
|
||||||
VarIds allVids;
|
VarIds allVids;
|
||||||
Factors::const_iterator first = factors.begin();
|
Factors::const_iterator first = factors.begin();
|
||||||
Factors::const_iterator end = factors.end();
|
Factors::const_iterator end = factors.end();
|
||||||
@ -181,7 +181,7 @@ ElimGraph::getLowestCostNode (void) const
|
|||||||
unsigned minCost = Util::maxUnsigned();
|
unsigned minCost = Util::maxUnsigned();
|
||||||
EGNeighs::const_iterator it;
|
EGNeighs::const_iterator it;
|
||||||
switch (elimHeuristic_) {
|
switch (elimHeuristic_) {
|
||||||
case MIN_NEIGHBORS: {
|
case minNeighborsEh: {
|
||||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||||
unsigned cost = getNeighborsCost (*it);
|
unsigned cost = getNeighborsCost (*it);
|
||||||
if (cost < minCost) {
|
if (cost < minCost) {
|
||||||
@ -190,7 +190,7 @@ ElimGraph::getLowestCostNode (void) const
|
|||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
break;
|
break;
|
||||||
case MIN_WEIGHT: {
|
case minWeightEh: {
|
||||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||||
unsigned cost = getWeightCost (*it);
|
unsigned cost = getWeightCost (*it);
|
||||||
if (cost < minCost) {
|
if (cost < minCost) {
|
||||||
@ -199,7 +199,7 @@ ElimGraph::getLowestCostNode (void) const
|
|||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
break;
|
break;
|
||||||
case MIN_FILL: {
|
case minFillEh: {
|
||||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||||
unsigned cost = getFillCost (*it);
|
unsigned cost = getFillCost (*it);
|
||||||
if (cost < minCost) {
|
if (cost < minCost) {
|
||||||
@ -208,7 +208,7 @@ ElimGraph::getLowestCostNode (void) const
|
|||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
break;
|
break;
|
||||||
case WEIGHTED_MIN_FILL: {
|
case weightedMinFillEh: {
|
||||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||||
unsigned cost = getWeightedFillCost (*it);
|
unsigned cost = getWeightedFillCost (*it);
|
||||||
if (cost < minCost) {
|
if (cost < minCost) {
|
||||||
|
@ -19,11 +19,11 @@ typedef TinySet<EgNode*> EGNeighs;
|
|||||||
|
|
||||||
|
|
||||||
enum ElimHeuristic {
|
enum ElimHeuristic {
|
||||||
SEQUENTIAL,
|
sequentialEh,
|
||||||
MIN_NEIGHBORS,
|
minNeighborsEh,
|
||||||
MIN_WEIGHT,
|
minWeightEh,
|
||||||
MIN_FILL,
|
minFillEh,
|
||||||
WEIGHTED_MIN_FILL
|
weightedMinFillEh
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ class FacNode;
|
|||||||
class VarNode : public Var {
|
class VarNode : public Var {
|
||||||
public:
|
public:
|
||||||
VarNode (VarId varId, unsigned nrStates,
|
VarNode (VarId varId, unsigned nrStates,
|
||||||
int evidence = Constants::NO_EVIDENCE)
|
int evidence = Constants::unobserved)
|
||||||
: Var (varId, nrStates, evidence) { }
|
: Var (varId, nrStates, evidence) { }
|
||||||
|
|
||||||
VarNode (const Var* v) : Var (v) { }
|
VarNode (const Var* v) : Var (v) { }
|
||||||
|
@ -31,7 +31,7 @@ GroundSolver::printAnswer (const VarIds& vids)
|
|||||||
Util::getStateLines (unobservedVars);
|
Util::getStateLines (unobservedVars);
|
||||||
for (size_t i = 0; i < res.size(); i++) {
|
for (size_t i = 0; i < res.size(); i++) {
|
||||||
std::cout << "P(" << stateLines[i] << ") = " ;
|
std::cout << "P(" << stateLines[i] << ") = " ;
|
||||||
std::cout << std::setprecision (Constants::PRECISION) << res[i];
|
std::cout << std::setprecision (Constants::precision) << res[i];
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
@ -66,9 +66,9 @@ GroundSolver::getJointByConditioning (
|
|||||||
|
|
||||||
GroundSolver* solver = 0;
|
GroundSolver* solver = 0;
|
||||||
switch (solverType) {
|
switch (solverType) {
|
||||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
|
||||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
|
||||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
|
||||||
}
|
}
|
||||||
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||||
VarIds observedVids = {jointVars[0]->varId()};
|
VarIds observedVids = {jointVars[0]->varId()};
|
||||||
@ -89,9 +89,9 @@ GroundSolver::getJointByConditioning (
|
|||||||
}
|
}
|
||||||
delete solver;
|
delete solver;
|
||||||
switch (solverType) {
|
switch (solverType) {
|
||||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
|
||||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
|
||||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
|
||||||
}
|
}
|
||||||
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
|
@ -35,16 +35,16 @@ typedef unsigned long long ullong;
|
|||||||
|
|
||||||
|
|
||||||
enum LiftedSolverType {
|
enum LiftedSolverType {
|
||||||
LVE, // generalized counting first-order variable elimination (GC-FOVE)
|
lveSolver, // generalized counting first-order variable elimination (GC-FOveSolver)
|
||||||
LBP, // lifted first-order belief propagation
|
lbpSolver, // lifted first-order belief propagation
|
||||||
LKC // lifted first-order knowledge compilation
|
lkcSolver // lifted first-order knowledge compilation
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
enum GroundSolverType {
|
enum GroundSolverType {
|
||||||
VE, // variable elimination
|
veSolver, // variable elimination
|
||||||
BP, // belief propagation
|
bpSolver, // belief propagation
|
||||||
CBP // counting belief propagation
|
CbpSolver // counting belief propagation
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -64,12 +64,12 @@ extern GroundSolverType groundSolver;
|
|||||||
namespace Constants {
|
namespace Constants {
|
||||||
|
|
||||||
// show message calculation for belief propagation
|
// show message calculation for belief propagation
|
||||||
const bool SHOW_BP_CALCS = false;
|
const bool showBpCalcs = false;
|
||||||
|
|
||||||
const int NO_EVIDENCE = -1;
|
const int unobserved = -1;
|
||||||
|
|
||||||
// number of digits to show when printing a parameter
|
// number of digits to show when printing a parameter
|
||||||
const unsigned PRECISION = 6;
|
const unsigned precision = 6;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ Horus::VarIds readQueryAndEvidence (
|
|||||||
|
|
||||||
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
|
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
|
||||||
|
|
||||||
const std::string USAGE = "usage: ./hcli [solver=hve|bp|cbp] \
|
const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
|
||||||
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
|
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
|
||||||
|
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ main (int argc, const char* argv[])
|
|||||||
{
|
{
|
||||||
if (argc <= 1) {
|
if (argc <= 1) {
|
||||||
std::cerr << "Error: no probabilistic graphical model was given." ;
|
std::cerr << "Error: no probabilistic graphical model was given." ;
|
||||||
std::cerr << std::endl << USAGE << std::endl;
|
std::cerr << std::endl << usage << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
int idx = readHorusFlags (argc, argv);
|
int idx = readHorusFlags (argc, argv);
|
||||||
@ -72,12 +72,12 @@ readHorusFlags (int argc, const char* argv[])
|
|||||||
std::string rightArg = arg.substr (pos + 1);
|
std::string rightArg = arg.substr (pos + 1);
|
||||||
if (leftArg.empty()) {
|
if (leftArg.empty()) {
|
||||||
std::cerr << "Error: missing left argument." << std::endl;
|
std::cerr << "Error: missing left argument." << std::endl;
|
||||||
std::cerr << USAGE << std::endl;
|
std::cerr << usage << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
if (rightArg.empty()) {
|
if (rightArg.empty()) {
|
||||||
std::cerr << "Error: missing right argument." << std::endl;
|
std::cerr << "Error: missing right argument." << std::endl;
|
||||||
std::cerr << USAGE << std::endl;
|
std::cerr << usage << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
Horus::Util::setHorusFlag (leftArg, rightArg);
|
Horus::Util::setHorusFlag (leftArg, rightArg);
|
||||||
@ -136,7 +136,7 @@ readQueryAndEvidence (
|
|||||||
std::string rightArg = arg.substr (pos + 1);
|
std::string rightArg = arg.substr (pos + 1);
|
||||||
if (leftArg.empty()) {
|
if (leftArg.empty()) {
|
||||||
std::cerr << "Error: missing left argument." << std::endl;
|
std::cerr << "Error: missing left argument." << std::endl;
|
||||||
std::cerr << USAGE << std::endl;
|
std::cerr << usage << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
if (Horus::Util::isInteger (leftArg) == false) {
|
if (Horus::Util::isInteger (leftArg) == false) {
|
||||||
@ -153,7 +153,7 @@ readQueryAndEvidence (
|
|||||||
}
|
}
|
||||||
if (rightArg.empty()) {
|
if (rightArg.empty()) {
|
||||||
std::cerr << "Error: missing right argument." << std::endl;
|
std::cerr << "Error: missing right argument." << std::endl;
|
||||||
std::cerr << USAGE << std::endl;
|
std::cerr << usage << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
if (Horus::Util::isInteger (rightArg) == false) {
|
if (Horus::Util::isInteger (rightArg) == false) {
|
||||||
@ -183,13 +183,13 @@ runSolver (
|
|||||||
{
|
{
|
||||||
Horus::GroundSolver* solver = 0;
|
Horus::GroundSolver* solver = 0;
|
||||||
switch (Horus::Globals::groundSolver) {
|
switch (Horus::Globals::groundSolver) {
|
||||||
case Horus::GroundSolverType::VE:
|
case Horus::GroundSolverType::veSolver:
|
||||||
solver = new Horus::VarElim (fg);
|
solver = new Horus::VarElim (fg);
|
||||||
break;
|
break;
|
||||||
case Horus::GroundSolverType::BP:
|
case Horus::GroundSolverType::bpSolver:
|
||||||
solver = new Horus::BeliefProp (fg);
|
solver = new Horus::BeliefProp (fg);
|
||||||
break;
|
break;
|
||||||
case Horus::GroundSolverType::CBP:
|
case Horus::GroundSolverType::CbpSolver:
|
||||||
solver = new Horus::CountingBp (fg);
|
solver = new Horus::CountingBp (fg);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -140,9 +140,9 @@ runLiftedSolver (void)
|
|||||||
|
|
||||||
LiftedSolver* solver = 0;
|
LiftedSolver* solver = 0;
|
||||||
switch (Globals::liftedSolver) {
|
switch (Globals::liftedSolver) {
|
||||||
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
|
case LiftedSolverType::lveSolver: solver = new LiftedVe (pfListCopy); break;
|
||||||
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
|
case LiftedSolverType::lbpSolver: solver = new LiftedBp (pfListCopy); break;
|
||||||
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
|
case LiftedSolverType::lkcSolver: solver = new LiftedKc (pfListCopy); break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Globals::verbosity > 0) {
|
if (Globals::verbosity > 0) {
|
||||||
@ -214,9 +214,9 @@ runGroundSolver (void)
|
|||||||
GroundSolver* solver = 0;
|
GroundSolver* solver = 0;
|
||||||
CountingBp::setFindIdenticalFactorsFlag (false);
|
CountingBp::setFindIdenticalFactorsFlag (false);
|
||||||
switch (Globals::groundSolver) {
|
switch (Globals::groundSolver) {
|
||||||
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
|
case GroundSolverType::veSolver: solver = new VarElim (*mfg); break;
|
||||||
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
|
case GroundSolverType::bpSolver: solver = new BeliefProp (*mfg); break;
|
||||||
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
|
case GroundSolverType::CbpSolver: solver = new CountingBp (*mfg); break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Globals::verbosity > 0) {
|
if (Globals::verbosity > 0) {
|
||||||
|
@ -70,10 +70,10 @@ LiftedBp::printSolverFlags (void) const
|
|||||||
ss << "lifted bp [" ;
|
ss << "lifted bp [" ;
|
||||||
ss << "bp_msg_schedule=" ;
|
ss << "bp_msg_schedule=" ;
|
||||||
switch (WeightedBp::msgSchedule()) {
|
switch (WeightedBp::msgSchedule()) {
|
||||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||||
}
|
}
|
||||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||||
|
@ -815,7 +815,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
|||||||
|
|
||||||
switch (getCircuitNodeType (node)) {
|
switch (getCircuitNodeType (node)) {
|
||||||
|
|
||||||
case CircuitNodeType::OR_NODE: {
|
case CircuitNodeType::orCnt: {
|
||||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||||
@ -828,7 +828,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case CircuitNodeType::AND_NODE: {
|
case CircuitNodeType::andCnt: {
|
||||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||||
@ -837,7 +837,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case CircuitNodeType::SET_OR_NODE: {
|
case CircuitNodeType::setOrCnt: {
|
||||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||||
propagLits = smoothCircuit (*casted->follow());
|
propagLits = smoothCircuit (*casted->follow());
|
||||||
TinySet<std::pair<LiteralId,unsigned>> litSet;
|
TinySet<std::pair<LiteralId,unsigned>> litSet;
|
||||||
@ -875,13 +875,13 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case CircuitNodeType::SET_AND_NODE: {
|
case CircuitNodeType::setAndCnt: {
|
||||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||||
propagLits = smoothCircuit (*casted->follow());
|
propagLits = smoothCircuit (*casted->follow());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case CircuitNodeType::INC_EXC_NODE: {
|
case CircuitNodeType::incExcCnt: {
|
||||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||||
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
|
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
|
||||||
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
|
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
|
||||||
@ -894,7 +894,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case CircuitNodeType::LEAF_NODE: {
|
case CircuitNodeType::leafCnt: {
|
||||||
LeafNode* casted = dynamic_cast<LeafNode*>(node);
|
LeafNode* casted = dynamic_cast<LeafNode*>(node);
|
||||||
propagLits.insert (LitLvTypes (
|
propagLits.insert (LitLvTypes (
|
||||||
casted->clause()->literals()[0].lid(),
|
casted->clause()->literals()[0].lid(),
|
||||||
@ -933,9 +933,9 @@ LiftedCircuit::createSmoothNode (
|
|||||||
Clause* c = lwcnf_->createClause (lid);
|
Clause* c = lwcnf_->createClause (lid);
|
||||||
for (size_t j = 0; j < types.size(); j++) {
|
for (size_t j = 0; j < types.size(); j++) {
|
||||||
LogVar X = c->literals().front().logVars()[j];
|
LogVar X = c->literals().front().logVars()[j];
|
||||||
if (types[j] == LogVarType::POS_LV) {
|
if (types[j] == LogVarType::posLvt) {
|
||||||
c->addPosCountedLogVar (X);
|
c->addPosCountedLogVar (X);
|
||||||
} else if (types[j] == LogVarType::NEG_LV) {
|
} else if (types[j] == LogVarType::negLvt) {
|
||||||
c->addNegCountedLogVar (X);
|
c->addNegCountedLogVar (X);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -960,8 +960,8 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
|
|||||||
if (nrLogVars == 0) {
|
if (nrLogVars == 0) {
|
||||||
// do nothing
|
// do nothing
|
||||||
} else if (nrLogVars == 1) {
|
} else if (nrLogVars == 1) {
|
||||||
res.push_back ({ LogVarType::POS_LV });
|
res.push_back ({ LogVarType::posLvt });
|
||||||
res.push_back ({ LogVarType::NEG_LV });
|
res.push_back ({ LogVarType::negLvt });
|
||||||
} else {
|
} else {
|
||||||
Ranges ranges (nrLogVars, 2);
|
Ranges ranges (nrLogVars, 2);
|
||||||
Indexer indexer (ranges);
|
Indexer indexer (ranges);
|
||||||
@ -969,9 +969,9 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
|
|||||||
LogVarTypes types;
|
LogVarTypes types;
|
||||||
for (size_t i = 0; i < nrLogVars; i++) {
|
for (size_t i = 0; i < nrLogVars; i++) {
|
||||||
if (indexer[i] == 0) {
|
if (indexer[i] == 0) {
|
||||||
types.push_back (LogVarType::POS_LV);
|
types.push_back (LogVarType::posLvt);
|
||||||
} else {
|
} else {
|
||||||
types.push_back (LogVarType::NEG_LV);
|
types.push_back (LogVarType::negLvt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
res.push_back (types);
|
res.push_back (types);
|
||||||
@ -989,13 +989,13 @@ LiftedCircuit::containsTypes (
|
|||||||
const LogVarTypes& typesB) const
|
const LogVarTypes& typesB) const
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < typesA.size(); i++) {
|
for (size_t i = 0; i < typesA.size(); i++) {
|
||||||
if (typesA[i] == LogVarType::FULL_LV) {
|
if (typesA[i] == LogVarType::fullLvt) {
|
||||||
|
|
||||||
} else if (typesA[i] == LogVarType::POS_LV
|
} else if (typesA[i] == LogVarType::posLvt
|
||||||
&& typesB[i] == LogVarType::POS_LV) {
|
&& typesB[i] == LogVarType::posLvt) {
|
||||||
|
|
||||||
} else if (typesA[i] == LogVarType::NEG_LV
|
} else if (typesA[i] == LogVarType::negLvt
|
||||||
&& typesB[i] == LogVarType::NEG_LV) {
|
&& typesB[i] == LogVarType::negLvt) {
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
@ -1009,25 +1009,25 @@ LiftedCircuit::containsTypes (
|
|||||||
CircuitNodeType
|
CircuitNodeType
|
||||||
LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
|
LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
|
||||||
{
|
{
|
||||||
CircuitNodeType type = CircuitNodeType::OR_NODE;
|
CircuitNodeType type = CircuitNodeType::orCnt;
|
||||||
if (dynamic_cast<const OrNode*>(node)) {
|
if (dynamic_cast<const OrNode*>(node)) {
|
||||||
type = CircuitNodeType::OR_NODE;
|
type = CircuitNodeType::orCnt;
|
||||||
} else if (dynamic_cast<const AndNode*>(node)) {
|
} else if (dynamic_cast<const AndNode*>(node)) {
|
||||||
type = CircuitNodeType::AND_NODE;
|
type = CircuitNodeType::andCnt;
|
||||||
} else if (dynamic_cast<const SetOrNode*>(node)) {
|
} else if (dynamic_cast<const SetOrNode*>(node)) {
|
||||||
type = CircuitNodeType::SET_OR_NODE;
|
type = CircuitNodeType::setOrCnt;
|
||||||
} else if (dynamic_cast<const SetAndNode*>(node)) {
|
} else if (dynamic_cast<const SetAndNode*>(node)) {
|
||||||
type = CircuitNodeType::SET_AND_NODE;
|
type = CircuitNodeType::setAndCnt;
|
||||||
} else if (dynamic_cast<const IncExcNode*>(node)) {
|
} else if (dynamic_cast<const IncExcNode*>(node)) {
|
||||||
type = CircuitNodeType::INC_EXC_NODE;
|
type = CircuitNodeType::incExcCnt;
|
||||||
} else if (dynamic_cast<const LeafNode*>(node)) {
|
} else if (dynamic_cast<const LeafNode*>(node)) {
|
||||||
type = CircuitNodeType::LEAF_NODE;
|
type = CircuitNodeType::leafCnt;
|
||||||
} else if (dynamic_cast<const SmoothNode*>(node)) {
|
} else if (dynamic_cast<const SmoothNode*>(node)) {
|
||||||
type = CircuitNodeType::SMOOTH_NODE;
|
type = CircuitNodeType::smoothCnt;
|
||||||
} else if (dynamic_cast<const TrueNode*>(node)) {
|
} else if (dynamic_cast<const TrueNode*>(node)) {
|
||||||
type = CircuitNodeType::TRUE_NODE;
|
type = CircuitNodeType::trueCnt;
|
||||||
} else if (dynamic_cast<const CompilationFailedNode*>(node)) {
|
} else if (dynamic_cast<const CompilationFailedNode*>(node)) {
|
||||||
type = CircuitNodeType::COMPILATION_FAILED_NODE;
|
type = CircuitNodeType::compilationFailedCnt;
|
||||||
} else {
|
} else {
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
@ -1050,7 +1050,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
|||||||
|
|
||||||
switch (getCircuitNodeType (node)) {
|
switch (getCircuitNodeType (node)) {
|
||||||
|
|
||||||
case OR_NODE: {
|
case orCnt: {
|
||||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||||
printClauses (casted, os);
|
printClauses (casted, os);
|
||||||
|
|
||||||
@ -1075,7 +1075,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case AND_NODE: {
|
case andCnt: {
|
||||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||||
printClauses (casted, os);
|
printClauses (casted, os);
|
||||||
|
|
||||||
@ -1100,7 +1100,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case SET_OR_NODE: {
|
case setOrCnt: {
|
||||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||||
printClauses (casted, os);
|
printClauses (casted, os);
|
||||||
|
|
||||||
@ -1119,7 +1119,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case SET_AND_NODE: {
|
case setAndCnt: {
|
||||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||||
printClauses (casted, os);
|
printClauses (casted, os);
|
||||||
|
|
||||||
@ -1138,7 +1138,7 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case INC_EXC_NODE: {
|
case incExcCnt: {
|
||||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||||
printClauses (casted, os);
|
printClauses (casted, os);
|
||||||
|
|
||||||
@ -1169,24 +1169,24 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case LEAF_NODE: {
|
case leafCnt: {
|
||||||
printClauses (node, os, "style=filled,fillcolor=palegreen,");
|
printClauses (node, os, "style=filled,fillcolor=palegreen,");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case SMOOTH_NODE: {
|
case smoothCnt: {
|
||||||
printClauses (node, os, "style=filled,fillcolor=lightblue,");
|
printClauses (node, os, "style=filled,fillcolor=lightblue,");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case TRUE_NODE: {
|
case trueCnt: {
|
||||||
os << escapeNode (node);
|
os << escapeNode (node);
|
||||||
os << " [shape=box,label=\"⊤\"]" ;
|
os << " [shape=box,label=\"⊤\"]" ;
|
||||||
os << std::endl;
|
os << std::endl;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
case COMPILATION_FAILED_NODE: {
|
case compilationFailedCnt: {
|
||||||
printClauses (node, os, "style=filled,fillcolor=salmon,");
|
printClauses (node, os, "style=filled,fillcolor=salmon,");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -1227,9 +1227,9 @@ LiftedCircuit::printClauses (
|
|||||||
Clauses clauses;
|
Clauses clauses;
|
||||||
if (Util::contains (originClausesMap_, node)) {
|
if (Util::contains (originClausesMap_, node)) {
|
||||||
clauses = originClausesMap_[node];
|
clauses = originClausesMap_[node];
|
||||||
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) {
|
} else if (getCircuitNodeType (node) == CircuitNodeType::leafCnt) {
|
||||||
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
|
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
|
||||||
} else if (getCircuitNodeType (node) == CircuitNodeType::SMOOTH_NODE) {
|
} else if (getCircuitNodeType (node) == CircuitNodeType::smoothCnt) {
|
||||||
clauses = (dynamic_cast<SmoothNode*>(node))->clauses();
|
clauses = (dynamic_cast<SmoothNode*>(node))->clauses();
|
||||||
}
|
}
|
||||||
assert (clauses.empty() == false);
|
assert (clauses.empty() == false);
|
||||||
|
@ -14,15 +14,15 @@
|
|||||||
namespace Horus {
|
namespace Horus {
|
||||||
|
|
||||||
enum CircuitNodeType {
|
enum CircuitNodeType {
|
||||||
OR_NODE,
|
orCnt,
|
||||||
AND_NODE,
|
andCnt,
|
||||||
SET_OR_NODE,
|
setOrCnt,
|
||||||
SET_AND_NODE,
|
setAndCnt,
|
||||||
INC_EXC_NODE,
|
incExcCnt,
|
||||||
LEAF_NODE,
|
leafCnt,
|
||||||
SMOOTH_NODE,
|
smoothCnt,
|
||||||
TRUE_NODE,
|
trueCnt,
|
||||||
COMPILATION_FAILED_NODE
|
compilationFailedCnt
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -149,7 +149,7 @@ LiftedOperations::absorveEvidence (
|
|||||||
}
|
}
|
||||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||||
Util::printAsteriskLine();
|
Util::printAsteriskLine();
|
||||||
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
|
std::cout << "AFTER EVIDENCE ABSORveSolverD" << std::endl;
|
||||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||||
std::cout << " -> " << obsFormulas[i] << std::endl;
|
std::cout << " -> " << obsFormulas[i] << std::endl;
|
||||||
}
|
}
|
||||||
|
@ -269,11 +269,11 @@ Clause::logVarTypes (size_t litIdx) const
|
|||||||
const LogVars& lvs = literals_[litIdx].logVars();
|
const LogVars& lvs = literals_[litIdx].logVars();
|
||||||
for (size_t i = 0; i < lvs.size(); i++) {
|
for (size_t i = 0; i < lvs.size(); i++) {
|
||||||
if (posCountedLvs_.contains (lvs[i])) {
|
if (posCountedLvs_.contains (lvs[i])) {
|
||||||
types.push_back (LogVarType::POS_LV);
|
types.push_back (LogVarType::posLvt);
|
||||||
} else if (negCountedLvs_.contains (lvs[i])) {
|
} else if (negCountedLvs_.contains (lvs[i])) {
|
||||||
types.push_back (LogVarType::NEG_LV);
|
types.push_back (LogVarType::negLvt);
|
||||||
} else {
|
} else {
|
||||||
types.push_back (LogVarType::FULL_LV);
|
types.push_back (LogVarType::fullLvt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return types;
|
return types;
|
||||||
@ -384,9 +384,9 @@ operator<< (std::ostream& os, const LitLvTypes& lit)
|
|||||||
os << lit.lid_ << "<" ;
|
os << lit.lid_ << "<" ;
|
||||||
for (size_t i = 0; i < lit.lvTypes_.size(); i++) {
|
for (size_t i = 0; i < lit.lvTypes_.size(); i++) {
|
||||||
switch (lit.lvTypes_[i]) {
|
switch (lit.lvTypes_[i]) {
|
||||||
case LogVarType::FULL_LV: os << "F" ; break;
|
case LogVarType::fullLvt: os << "F" ; break;
|
||||||
case LogVarType::POS_LV: os << "P" ; break;
|
case LogVarType::posLvt: os << "P" ; break;
|
||||||
case LogVarType::NEG_LV: os << "N" ; break;
|
case LogVarType::negLvt: os << "N" ; break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
os << ">" ;
|
os << ">" ;
|
||||||
@ -398,7 +398,7 @@ operator<< (std::ostream& os, const LitLvTypes& lit)
|
|||||||
void
|
void
|
||||||
LitLvTypes::setAllFullLogVars (void)
|
LitLvTypes::setAllFullLogVars (void)
|
||||||
{
|
{
|
||||||
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV);
|
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::fullLvt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,9 +16,9 @@ namespace Horus {
|
|||||||
class ParfactorList;
|
class ParfactorList;
|
||||||
|
|
||||||
enum LogVarType {
|
enum LogVarType {
|
||||||
FULL_LV,
|
fullLvt,
|
||||||
POS_LV,
|
posLvt,
|
||||||
NEG_LV
|
negLvt
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef long LiteralId;
|
typedef long LiteralId;
|
||||||
|
@ -12,9 +12,9 @@ bool logDomain = false;
|
|||||||
|
|
||||||
unsigned verbosity = 0;
|
unsigned verbosity = 0;
|
||||||
|
|
||||||
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
|
LiftedSolverType liftedSolver = LiftedSolverType::lveSolver;
|
||||||
|
|
||||||
GroundSolverType groundSolver = GroundSolverType::VE;
|
GroundSolverType groundSolver = GroundSolverType::veSolver;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,15 +203,15 @@ setHorusFlag (std::string option, std::string value)
|
|||||||
{
|
{
|
||||||
bool returnVal = true;
|
bool returnVal = true;
|
||||||
if (option == "lifted_solver") {
|
if (option == "lifted_solver") {
|
||||||
if (value == "lve") Globals::liftedSolver = LiftedSolverType::LVE;
|
if (value == "lve") Globals::liftedSolver = LiftedSolverType::lveSolver;
|
||||||
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::LBP;
|
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::lbpSolver;
|
||||||
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::LKC;
|
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::lkcSolver;
|
||||||
else returnVal = invalidValue (option, value);
|
else returnVal = invalidValue (option, value);
|
||||||
|
|
||||||
} else if (option == "ground_solver" || option == "solver") {
|
} else if (option == "ground_solver" || option == "solver") {
|
||||||
if (value == "hve") Globals::groundSolver = GroundSolverType::VE;
|
if (value == "hve") Globals::groundSolver = GroundSolverType::veSolver;
|
||||||
else if (value == "bp") Globals::groundSolver = GroundSolverType::BP;
|
else if (value == "bp") Globals::groundSolver = GroundSolverType::bpSolver;
|
||||||
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CBP;
|
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CbpSolver;
|
||||||
else returnVal = invalidValue (option, value);
|
else returnVal = invalidValue (option, value);
|
||||||
|
|
||||||
} else if (option == "verbosity") {
|
} else if (option == "verbosity") {
|
||||||
@ -226,27 +226,27 @@ setHorusFlag (std::string option, std::string value)
|
|||||||
|
|
||||||
} else if (option == "hve_elim_heuristic") {
|
} else if (option == "hve_elim_heuristic") {
|
||||||
if (value == "sequential")
|
if (value == "sequential")
|
||||||
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
|
ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
|
||||||
else if (value == "min_neighbors")
|
else if (value == "min_neighbors")
|
||||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS);
|
ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
|
||||||
else if (value == "min_weight")
|
else if (value == "min_weight")
|
||||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT);
|
ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
|
||||||
else if (value == "min_fill")
|
else if (value == "min_fill")
|
||||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL);
|
ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
|
||||||
else if (value == "weighted_min_fill")
|
else if (value == "weighted_min_fill")
|
||||||
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
|
ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
|
||||||
else
|
else
|
||||||
returnVal = invalidValue (option, value);
|
returnVal = invalidValue (option, value);
|
||||||
|
|
||||||
} else if (option == "bp_msg_schedule") {
|
} else if (option == "bp_msg_schedule") {
|
||||||
if (value == "seq_fixed")
|
if (value == "seq_fixed")
|
||||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
|
BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
|
||||||
else if (value == "seq_random")
|
else if (value == "seq_random")
|
||||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM);
|
BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
|
||||||
else if (value == "parallel")
|
else if (value == "parallel")
|
||||||
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL);
|
BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
|
||||||
else if (value == "max_residual")
|
else if (value == "max_residual")
|
||||||
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL);
|
BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
|
||||||
else
|
else
|
||||||
returnVal = invalidValue (option, value);
|
returnVal = invalidValue (option, value);
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ unsigned nrDigits (int);
|
|||||||
bool isInteger (const std::string&);
|
bool isInteger (const std::string&);
|
||||||
|
|
||||||
std::string parametersToString (
|
std::string parametersToString (
|
||||||
const Params&, unsigned = Constants::PRECISION);
|
const Params&, unsigned = Constants::precision);
|
||||||
|
|
||||||
std::vector<std::string> getStateLines (const Vars&);
|
std::vector<std::string> getStateLines (const Vars&);
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ class Var {
|
|||||||
public:
|
public:
|
||||||
Var (const Var*);
|
Var (const Var*);
|
||||||
|
|
||||||
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
Var (VarId, unsigned, int = Constants::unobserved);
|
||||||
|
|
||||||
virtual ~Var (void) { };
|
virtual ~Var (void) { };
|
||||||
|
|
||||||
@ -79,7 +79,7 @@ class Var {
|
|||||||
inline bool
|
inline bool
|
||||||
Var::hasEvidence (void) const
|
Var::hasEvidence (void) const
|
||||||
{
|
{
|
||||||
return evidence_ != Constants::NO_EVIDENCE;
|
return evidence_ != Constants::unobserved;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,11 +43,11 @@ VarElim::printSolverFlags (void) const
|
|||||||
ss << "variable elimination [" ;
|
ss << "variable elimination [" ;
|
||||||
ss << "elim_heuristic=" ;
|
ss << "elim_heuristic=" ;
|
||||||
switch (ElimGraph::elimHeuristic()) {
|
switch (ElimGraph::elimHeuristic()) {
|
||||||
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break;
|
case ElimHeuristic::sequentialEh: ss << "sequential"; break;
|
||||||
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
case ElimHeuristic::minNeighborsEh: ss << "min_neighbors"; break;
|
||||||
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break;
|
case ElimHeuristic::minWeightEh: ss << "min_weight"; break;
|
||||||
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
|
case ElimHeuristic::minFillEh: ss << "min_fill"; break;
|
||||||
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
case ElimHeuristic::weightedMinFillEh: ss << "weighted_min_fill"; break;
|
||||||
}
|
}
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
|
@ -177,13 +177,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " message from " << links[i]->varNode()->label();
|
std::cout << " message from " << links[i]->varNode()->label();
|
||||||
std::cout << ": " ;
|
std::cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
reps, std::plus<double>());
|
reps, std::plus<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -193,13 +193,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " message from " << links[i]->varNode()->label();
|
std::cout << " message from " << links[i]->varNode()->label();
|
||||||
std::cout << ": " ;
|
std::cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
reps, std::multiplies<double>());
|
reps, std::multiplies<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -214,7 +214,7 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
} else {
|
} else {
|
||||||
result.params() *= src->factor().params();
|
result.params() *= src->factor().params();
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " message product: " ;
|
std::cout << " message product: " ;
|
||||||
std::cout << msgProduct << std::endl;
|
std::cout << msgProduct << std::endl;
|
||||||
std::cout << " original factor: " ;
|
std::cout << " original factor: " ;
|
||||||
@ -223,13 +223,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
std::cout << result.params() << std::endl;
|
std::cout << result.params() << std::endl;
|
||||||
}
|
}
|
||||||
result.sumOutAllExceptIndex (link->index());
|
result.sumOutAllExceptIndex (link->index());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " marginalized: " ;
|
std::cout << " marginalized: " ;
|
||||||
std::cout << result.params() << std::endl;
|
std::cout << result.params() << std::endl;
|
||||||
}
|
}
|
||||||
link->nextMessage() = result.params();
|
link->nextMessage() = result.params();
|
||||||
LogAware::normalize (link->nextMessage());
|
LogAware::normalize (link->nextMessage());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " curr msg: " ;
|
std::cout << " curr msg: " ;
|
||||||
std::cout << link->message() << std::endl;
|
std::cout << link->message() << std::endl;
|
||||||
std::cout << " next msg: " ;
|
std::cout << " next msg: " ;
|
||||||
@ -249,14 +249,14 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
|||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
double value = link->message()[src->getEvidence()];
|
double value = link->message()[src->getEvidence()];
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
msg[src->getEvidence()] = value;
|
msg[src->getEvidence()] = value;
|
||||||
std::cout << msg << "^" << link->weight() << "-1" ;
|
std::cout << msg << "^" << link->weight() << "-1" ;
|
||||||
}
|
}
|
||||||
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
|
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
|
||||||
} else {
|
} else {
|
||||||
msg = link->message();
|
msg = link->message();
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << msg << "^" << link->weight() << "-1" ;
|
std::cout << msg << "^" << link->weight() << "-1" ;
|
||||||
}
|
}
|
||||||
LogAware::pow (msg, link->weight() - 1);
|
LogAware::pow (msg, link->weight() - 1);
|
||||||
@ -274,13 +274,13 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
|||||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||||
msg *= l->powMessage();
|
msg *= l->powMessage();
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " x " << l->nextMessage() << "^" << link->weight();
|
std::cout << " x " << l->nextMessage() << "^" << link->weight();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::showBpCalcs) {
|
||||||
std::cout << " = " << msg;
|
std::cout << " = " << msg;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
|
Reference in New Issue
Block a user