Revert "Improve namespace names"

This reverts commit 973df43fe0.

On a second thought, namespaces are close to classes in the sense that both encapsulate data, so they should both use Pascal case notation.
This commit is contained in:
Tiago Gomes 2013-02-08 21:01:53 +00:00
parent 264ef7a067
commit 6a1a209ee3
33 changed files with 381 additions and 382 deletions

View File

@ -79,8 +79,8 @@ BayesBall::constructGraph (FactorGraph* fg) const
} else if (n->hasEvidence() && n->isVisited()) { } else if (n->hasEvidence() && n->isVisited()) {
VarIds varIds = { facNodes[i]->factor().argument (0) }; VarIds varIds = { facNodes[i]->factor().argument (0) };
Ranges ranges = { facNodes[i]->factor().range (0) }; Ranges ranges = { facNodes[i]->factor().range (0) };
Params params (ranges[0], log_aware::noEvidence()); Params params (ranges[0], LogAware::noEvidence());
params[n->getEvidence()] = log_aware::withEvidence(); params[n->getEvidence()] = LogAware::withEvidence();
fg->addFactor (Factor (varIds, ranges, params)); fg->addFactor (Factor (varIds, ranges, params));
} }
} }

View File

@ -13,7 +13,7 @@ namespace horus {
void void
BayesBallGraph::addNode (BBNode* n) BayesBallGraph::addNode (BBNode* n)
{ {
assert (util::contains (varMap_, n->varId()) == false); assert (Util::contains (varMap_, n->varId()) == false);
nodes_.push_back (n); nodes_.push_back (n);
varMap_[n->varId()] = n; varMap_[n->varId()] = n;
} }

View File

@ -15,8 +15,8 @@ BpLink::BpLink (FacNode* fn, VarNode* vn)
{ {
fac_ = fn; fac_ = fn;
var_ = vn; var_ = vn;
v1_.resize (vn->range(), log_aware::log (1.0 / vn->range())); v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), log_aware::log (1.0 / vn->range())); v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_; currMsg_ = &v1_;
nextMsg_ = &v2_; nextMsg_ = &v2_;
residual_ = 0.0; residual_ = 0.0;
@ -35,7 +35,7 @@ BpLink::clearResidual (void)
void void
BpLink::updateResidual (void) BpLink::updateResidual (void)
{ {
residual_ = log_aware::getMaxNorm (v1_, v2_); residual_ = LogAware::getMaxNorm (v1_, v2_);
} }
@ -111,9 +111,9 @@ BeliefProp::printSolverFlags (void) const
case MsgSchedule::PARALLEL: ss << "parallel"; break; case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break; case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
} }
ss << ",bp_max_iter=" << util::toString (maxIter_); ss << ",bp_max_iter=" << Util::toString (maxIter_);
ss << ",bp_accuracy=" << util::toString (accuracy_); ss << ",bp_accuracy=" << Util::toString (accuracy_);
ss << ",log_domain=" << util::toString (globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
} }
@ -130,22 +130,22 @@ BeliefProp::getPosterioriOf (VarId vid)
VarNode* var = fg.getVarNode (vid); VarNode* var = fg.getVarNode (vid);
Params probs; Params probs;
if (var->hasEvidence()) { if (var->hasEvidence()) {
probs.resize (var->range(), log_aware::noEvidence()); probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = log_aware::withEvidence(); probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->range(), log_aware::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const BpLinks& links = ninf(var)->getLinks(); const BpLinks& links = ninf(var)->getLinks();
if (globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
probs += links[i]->message(); probs += links[i]->message();
} }
log_aware::normalize (probs); LogAware::normalize (probs);
util::exp (probs); Util::exp (probs);
} else { } else {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
probs *= links[i]->message(); probs *= links[i]->message();
} }
log_aware::normalize (probs); LogAware::normalize (probs);
} }
} }
return probs; return probs;
@ -196,8 +196,8 @@ BeliefProp::getFactorJoint (
res.reorderArguments (jointVarIds); res.reorderArguments (jointVarIds);
res.normalize(); res.normalize();
Params jointDist = res.params(); Params jointDist = res.params();
if (globals::logDomain) { if (Globals::logDomain) {
util::exp (jointDist); Util::exp (jointDist);
} }
return jointDist; return jointDist;
} }
@ -207,7 +207,7 @@ BeliefProp::getFactorJoint (
void void
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual) BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
{ {
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
std::cout << "calculating & updating " << link->toString(); std::cout << "calculating & updating " << link->toString();
std::cout << std::endl; std::cout << std::endl;
} }
@ -223,7 +223,7 @@ BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
void void
BeliefProp::calculateMessage (BpLink* link, bool calcResidual) BeliefProp::calculateMessage (BpLink* link, bool calcResidual)
{ {
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
std::cout << "calculating " << link->toString(); std::cout << "calculating " << link->toString();
std::cout << std::endl; std::cout << std::endl;
} }
@ -239,7 +239,7 @@ void
BeliefProp::updateMessage (BpLink* link) BeliefProp::updateMessage (BpLink* link)
{ {
link->updateMessage(); link->updateMessage();
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
std::cout << "updating " << link->toString(); std::cout << "updating " << link->toString();
std::cout << std::endl; std::cout << std::endl;
} }
@ -254,9 +254,9 @@ BeliefProp::runSolver (void)
nIters_ = 0; nIters_ = 0;
while (!converged() && nIters_ < maxIter_) { while (!converged() && nIters_ < maxIter_) {
nIters_ ++; nIters_ ++;
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
util::printHeader (std::string ("Iteration ") Util::printHeader (std::string ("Iteration ")
+ util::toString (nIters_)); + Util::toString (nIters_));
} }
switch (schedule_) { switch (schedule_) {
case MsgSchedule::SEQ_RANDOM: case MsgSchedule::SEQ_RANDOM:
@ -280,7 +280,7 @@ BeliefProp::runSolver (void)
break; break;
} }
} }
if (globals::verbosity > 0) { if (Globals::verbosity > 0) {
if (nIters_ < maxIter_) { if (nIters_ < maxIter_) {
std::cout << "Belief propagation converged in " ; std::cout << "Belief propagation converged in " ;
std::cout << nIters_ << " iterations" << std::endl; std::cout << nIters_ << " iterations" << std::endl;
@ -322,7 +322,7 @@ BeliefProp::maxResidualSchedule (void)
} }
for (size_t c = 0; c < links_.size(); c++) { for (size_t c = 0; c < links_.size(); c++) {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << "current residuals:" << std::endl; std::cout << "current residuals:" << std::endl;
for (SortedOrder::iterator it = sortedOrder_.begin(); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) { it != sortedOrder_.end(); ++it) {
@ -358,8 +358,8 @@ BeliefProp::maxResidualSchedule (void)
} }
} }
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
util::printDashedLine(); Util::printDashedLine();
} }
} }
} }
@ -375,18 +375,18 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
// calculate the product of messages that were sent // calculate the product of messages that were sent
// to factor `src', except from var `dst' // to factor `src', except from var `dst'
unsigned reps = 1; unsigned reps = 1;
unsigned msgSize = util::sizeExpected (src->factor().ranges()); unsigned msgSize = Util::sizeExpected (src->factor().ranges());
Params msgProduct (msgSize, log_aware::multIdenty()); Params msgProduct (msgSize, LogAware::multIdenty());
if (globals::logDomain) { if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) { if (links[i]->varNode() != dst) {
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ; std::cout << ": " ;
} }
util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>()); reps, std::plus<double>());
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
@ -395,13 +395,13 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
} else { } else {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) { if (links[i]->varNode() != dst) {
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ; std::cout << ": " ;
} }
util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>()); reps, std::multiplies<double>());
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
@ -411,19 +411,19 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
Factor result (src->factor().arguments(), Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct); src->factor().ranges(), msgProduct);
result.multiply (src->factor()); result.multiply (src->factor());
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " message product: " << msgProduct << std::endl; std::cout << " message product: " << msgProduct << std::endl;
std::cout << " original factor: " << src->factor().params(); std::cout << " original factor: " << src->factor().params();
std::cout << std::endl; std::cout << std::endl;
std::cout << " factor product: " << result.params() << std::endl; std::cout << " factor product: " << result.params() << std::endl;
} }
result.sumOutAllExcept (dst->varId()); result.sumOutAllExcept (dst->varId());
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " marginalized: " << result.params() << std::endl; std::cout << " marginalized: " << result.params() << std::endl;
} }
link->nextMessage() = result.params(); link->nextMessage() = result.params();
log_aware::normalize (link->nextMessage()); LogAware::normalize (link->nextMessage());
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " curr msg: " << link->message() << std::endl; std::cout << " curr msg: " << link->message() << std::endl;
std::cout << " next msg: " << link->nextMessage() << std::endl; std::cout << " next msg: " << link->nextMessage() << std::endl;
} }
@ -437,22 +437,22 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
const VarNode* src = link->varNode(); const VarNode* src = link->varNode();
Params msg; Params msg;
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), log_aware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
msg[src->getEvidence()] = log_aware::withEvidence(); msg[src->getEvidence()] = LogAware::withEvidence();
} else { } else {
msg.resize (src->range(), log_aware::one()); msg.resize (src->range(), LogAware::one());
} }
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << msg; std::cout << msg;
} }
BpLinks::const_iterator it; BpLinks::const_iterator it;
const BpLinks& links = ninf (src)->getLinks(); const BpLinks& links = ninf (src)->getLinks();
if (globals::logDomain) { if (Globals::logDomain) {
for (it = links.begin(); it != links.end(); ++it) { for (it = links.begin(); it != links.end(); ++it) {
if (*it != link) { if (*it != link) {
msg += (*it)->message(); msg += (*it)->message();
} }
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " x " << (*it)->message(); std::cout << " x " << (*it)->message();
} }
} }
@ -461,12 +461,12 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
if (*it != link) { if (*it != link) {
msg *= (*it)->message(); msg *= (*it)->message();
} }
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " x " << (*it)->message(); std::cout << " x " << (*it)->message();
} }
} }
} }
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " = " << msg; std::cout << " = " << msg;
} }
return msg; return msg;
@ -516,11 +516,11 @@ BeliefProp::converged (void)
if (nIters_ == 0) { if (nIters_ == 0) {
return false; return false;
} }
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
std::cout << std::endl; std::cout << std::endl;
} }
if (nIters_ == 1) { if (nIters_ == 1) {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << "no residuals" << std::endl << std::endl; std::cout << "no residuals" << std::endl << std::endl;
} }
return false; return false;
@ -536,18 +536,18 @@ BeliefProp::converged (void)
} else { } else {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
double residual = links_[i]->residual(); double residual = links_[i]->residual();
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << links_[i]->toString() + " residual = " << residual; std::cout << links_[i]->toString() + " residual = " << residual;
std::cout << std::endl; std::cout << std::endl;
} }
if (residual > accuracy_) { if (residual > accuracy_) {
converged = false; converged = false;
if (globals::verbosity < 2) { if (Globals::verbosity < 2) {
break; break;
} }
} }
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << std::endl; std::cout << std::endl;
} }
} }

View File

@ -278,7 +278,7 @@ void
ConstraintTree::moveToTop (const LogVars& lvs) ConstraintTree::moveToTop (const LogVars& lvs)
{ {
for (size_t i = 0; i < lvs.size(); i++) { for (size_t i = 0; i < lvs.size(); i++) {
size_t pos = util::indexOf (logVars_, lvs[i]); size_t pos = Util::indexOf (logVars_, lvs[i]);
assert (pos != logVars_.size()); assert (pos != logVars_.size());
for (size_t j = pos; j-- > i; ) { for (size_t j = pos; j-- > i; ) {
swapLogVar (logVars_[j]); swapLogVar (logVars_[j]);
@ -292,7 +292,7 @@ void
ConstraintTree::moveToBottom (const LogVars& lvs) ConstraintTree::moveToBottom (const LogVars& lvs)
{ {
for (size_t i = lvs.size(); i-- > 0; ) { for (size_t i = lvs.size(); i-- > 0; ) {
size_t pos = util::indexOf (logVars_, lvs[i]); size_t pos = Util::indexOf (logVars_, lvs[i]);
assert (pos != logVars_.size()); assert (pos != logVars_.size());
size_t stop = logVars_.size() - (lvs.size() - i - 1); size_t stop = logVars_.size() - (lvs.size() - i - 1);
for (size_t j = pos; j < stop - 1; j++) { for (size_t j = pos; j < stop - 1; j++) {
@ -329,7 +329,7 @@ ConstraintTree::join (ConstraintTree* ct, bool oneTwoOne)
if (intersect.empty()) { if (intersect.empty()) {
// cartesian product // cartesian product
appendOnBottom (root_, ct->root()->childs()); appendOnBottom (root_, ct->root()->childs());
util::addToVector (logVars_, ct->logVars_); Util::addToVector (logVars_, ct->logVars_);
logVarSet_ |= ct->logVarSet_; logVarSet_ |= ct->logVarSet_;
} else { } else {
moveToTop (intersect.elements()); moveToTop (intersect.elements());
@ -350,7 +350,7 @@ ConstraintTree::join (ConstraintTree* ct, bool oneTwoOne)
LogVars newLvs (ct->logVars().begin() + intersect.size(), LogVars newLvs (ct->logVars().begin() + intersect.size(),
ct->logVars().end()); ct->logVars().end());
util::addToVector (logVars_, newLvs); Util::addToVector (logVars_, newLvs);
logVarSet_ |= LogVarSet (newLvs); logVarSet_ |= LogVarSet (newLvs);
} }
} }
@ -360,7 +360,7 @@ ConstraintTree::join (ConstraintTree* ct, bool oneTwoOne)
unsigned unsigned
ConstraintTree::getLevel (LogVar X) const ConstraintTree::getLevel (LogVar X) const
{ {
unsigned level = util::indexOf (logVars_, X); unsigned level = Util::indexOf (logVars_, X);
level += 1; // root is in level 0, first logVar is in level 1 level += 1; // root is in level 0, first logVar is in level 1
return level; return level;
} }
@ -496,7 +496,7 @@ ConstraintTree::tupleSet (const LogVars& originalLvs)
{ {
LogVars uniqueLvs; LogVars uniqueLvs;
for (size_t i = 0; i < originalLvs.size(); i++) { for (size_t i = 0; i < originalLvs.size(); i++) {
if (util::contains (uniqueLvs, originalLvs[i]) == false) { if (Util::contains (uniqueLvs, originalLvs[i]) == false) {
uniqueLvs.push_back (originalLvs[i]); uniqueLvs.push_back (originalLvs[i]);
} }
} }
@ -510,7 +510,7 @@ ConstraintTree::tupleSet (const LogVars& originalLvs)
std::vector<size_t> indexes; std::vector<size_t> indexes;
indexes.reserve (originalLvs.size()); indexes.reserve (originalLvs.size());
for (size_t i = 0; i < originalLvs.size(); i++) { for (size_t i = 0; i < originalLvs.size(); i++) {
indexes.push_back (util::indexOf (uniqueLvs, originalLvs[i])); indexes.push_back (Util::indexOf (uniqueLvs, originalLvs[i]));
} }
Tuples tuples2; Tuples tuples2;
tuples2.reserve (tuples.size()); tuples2.reserve (tuples.size());
@ -1030,7 +1030,7 @@ ConstraintTree::appendOnBottom (CTNode* n, const CTChilds& childs)
void void
ConstraintTree::swapLogVar (LogVar X) ConstraintTree::swapLogVar (LogVar X)
{ {
size_t pos = util::indexOf (logVars_, X); size_t pos = Util::indexOf (logVars_, X);
assert (pos != logVars_.size()); assert (pos != logVars_.size());
const CTNodes& nodes = getNodesAtLevel (pos); const CTNodes& nodes = getNodesAtLevel (pos);
for (CTNodes::const_iterator nodeIt = nodes.begin(); for (CTNodes::const_iterator nodeIt = nodes.begin();

View File

@ -52,8 +52,8 @@ CountingBp::printSolverFlags (void) const
} }
ss << ",bp_max_iter=" << WeightedBp::maxIterations(); ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy(); ss << ",bp_accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << util::toString (globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",fif=" << util::toString (CountingBp::fif_); ss << ",fif=" << Util::toString (CountingBp::fif_);
ss << "]" ; ss << "]" ;
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
} }
@ -103,18 +103,18 @@ CountingBp::findIdenticalFactors()
return; return;
} }
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (util::maxUnsigned()); facNodes[i]->factor().setDistId (Util::maxUnsigned());
} }
unsigned groupCount = 1; unsigned groupCount = 1;
for (size_t i = 0; i < facNodes.size() - 1; i++) { for (size_t i = 0; i < facNodes.size() - 1; i++) {
Factor& f1 = facNodes[i]->factor(); Factor& f1 = facNodes[i]->factor();
if (f1.distId() != util::maxUnsigned()) { if (f1.distId() != Util::maxUnsigned()) {
continue; continue;
} }
f1.setDistId (groupCount); f1.setDistId (groupCount);
for (size_t j = i + 1; j < facNodes.size(); j++) { for (size_t j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor(); Factor& f2 = facNodes[j]->factor();
if (f2.distId() != util::maxUnsigned()) { if (f2.distId() != Util::maxUnsigned()) {
continue; continue;
} }
if (f1.size() == f2.size() && if (f1.size() == f2.size() &&
@ -303,7 +303,7 @@ CountingBp::getSignature (const FacNode* facNode)
VarId VarId
CountingBp::getRepresentative (VarId vid) CountingBp::getRepresentative (VarId vid)
{ {
assert (util::contains (varClusterMap_, vid)); assert (Util::contains (varClusterMap_, vid));
VarCluster* vc = varClusterMap_.find (vid)->second; VarCluster* vc = varClusterMap_.find (vid)->second;
return vc->representative()->varId(); return vc->representative()->varId();
} }
@ -314,7 +314,7 @@ FacNode*
CountingBp::getRepresentative (FacNode* fn) CountingBp::getRepresentative (FacNode* fn)
{ {
for (size_t i = 0; i < facClusters_.size(); i++) { for (size_t i = 0; i < facClusters_.size(); i++) {
if (util::contains (facClusters_[i]->members(), fn)) { if (Util::contains (facClusters_[i]->members(), fn)) {
return facClusters_[i]->representative(); return facClusters_[i]->representative();
} }
} }

View File

@ -55,7 +55,7 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
VarIds elimOrder; VarIds elimOrder;
unmarked_.reserve (nodes_.size()); unmarked_.reserve (nodes_.size());
for (size_t i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
if (util::contains (excludedVids, nodes_[i]->varId()) == false) { if (Util::contains (excludedVids, nodes_[i]->varId()) == false) {
unmarked_.insert (nodes_[i]); unmarked_.insert (nodes_[i]);
} }
} }
@ -142,7 +142,7 @@ ElimGraph::getEliminationOrder (
Factors::const_iterator first = factors.begin(); Factors::const_iterator first = factors.begin();
Factors::const_iterator end = factors.end(); Factors::const_iterator end = factors.end();
for (; first != end; ++first) { for (; first != end; ++first) {
util::addToVector (allVids, (*first)->arguments()); Util::addToVector (allVids, (*first)->arguments());
} }
TinySet<VarId> elimOrder (allVids); TinySet<VarId> elimOrder (allVids);
elimOrder -= TinySet<VarId> (excludedVids); elimOrder -= TinySet<VarId> (excludedVids);
@ -178,7 +178,7 @@ EgNode*
ElimGraph::getLowestCostNode (void) const ElimGraph::getLowestCostNode (void) const
{ {
EgNode* bestNode = 0; EgNode* bestNode = 0;
unsigned minCost = util::maxUnsigned(); unsigned minCost = Util::maxUnsigned();
EGNeighs::const_iterator it; EGNeighs::const_iterator it;
switch (elimHeuristic_) { switch (elimHeuristic_) {
case MIN_NEIGHBORS: { case MIN_NEIGHBORS: {

View File

@ -27,7 +27,7 @@ Factor::Factor (
ranges_ = ranges; ranges_ = ranges;
params_ = params; params_ = params;
distId_ = distId; distId_ = distId;
assert (params_.size() == util::sizeExpected (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -43,7 +43,7 @@ Factor::Factor (
} }
params_ = params; params_ = params;
distId_ = distId; distId_ = distId;
assert (params_.size() == util::sizeExpected (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -131,7 +131,7 @@ Factor::print (void) const
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
vars.push_back (new Var (args_[i], ranges_[i])); vars.push_back (new Var (args_[i], ranges_[i]));
} }
std::vector<std::string> jointStrings = util::getStateLines (vars); std::vector<std::string> jointStrings = Util::getStateLines (vars);
for (size_t i = 0; i < params_.size(); i++) { for (size_t i = 0; i < params_.size(); i++) {
// cout << "[" << distId_ << "] " ; // cout << "[" << distId_ << "] " ;
std::cout << "f(" << jointStrings[i] << ")" ; std::cout << "f(" << jointStrings[i] << ")" ;
@ -149,11 +149,11 @@ void
Factor::sumOutFirstVariable (void) Factor::sumOutFirstVariable (void)
{ {
size_t sep = params_.size() / 2; size_t sep = params_.size() / 2;
if (globals::logDomain) { if (Globals::logDomain) {
std::transform ( std::transform (
params_.begin(), params_.begin() + sep, params_.begin(), params_.begin() + sep,
params_.begin() + sep, params_.begin(), params_.begin() + sep, params_.begin(),
util::logSum); Util::logSum);
} else { } else {
std::transform ( std::transform (
@ -174,10 +174,10 @@ Factor::sumOutLastVariable (void)
Params::iterator first1 = params_.begin(); Params::iterator first1 = params_.begin();
Params::iterator first2 = params_.begin(); Params::iterator first2 = params_.begin();
Params::iterator last = params_.end(); Params::iterator last = params_.end();
if (globals::logDomain) { if (Globals::logDomain) {
while (first2 != last) { while (first2 != last) {
// the arguments can be swaped, but that is ok // the arguments can be swaped, but that is ok
*first1++ = util::logSum (*first2++, *first2++); *first1++ = Util::logSum (*first2++, *first2++);
} }
} else { } else {
while (first2 != last) { while (first2 != last) {
@ -206,13 +206,13 @@ Factor::sumOutArgs (const std::vector<bool>& mask)
ranges_.push_back (ranges_[i]); ranges_.push_back (ranges_[i]);
} }
} }
Params newps (new_size, log_aware::addIdenty()); Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin(); Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end(); Params::const_iterator last = params_.end();
MapIndexer indexer (oldRanges, mask); MapIndexer indexer (oldRanges, mask);
if (globals::logDomain) { if (Globals::logDomain) {
while (first != last) { while (first != last) {
newps[indexer] = util::logSum (newps[indexer], *first++); newps[indexer] = Util::logSum (newps[indexer], *first++);
++ indexer; ++ indexer;
} }
} else { } else {

View File

@ -34,7 +34,7 @@ class TFactor
void setDistId (unsigned id) { distId_ = id; } void setDistId (unsigned id) { distId_ = id; }
void normalize (void) { log_aware::normalize (params_); } void normalize (void) { LogAware::normalize (params_); }
void randomize (void); void randomize (void);
@ -91,7 +91,7 @@ template <typename T> inline void
TFactor<T>::setParams (const Params& newParams) TFactor<T>::setParams (const Params& newParams)
{ {
params_ = newParams; params_ = newParams;
assert (params_.size() == util::sizeExpected (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -99,7 +99,7 @@ TFactor<T>::setParams (const Params& newParams)
template <typename T> inline size_t template <typename T> inline size_t
TFactor<T>::indexOf (const T& t) const TFactor<T>::indexOf (const T& t) const
{ {
return util::indexOf (args_, t); return Util::indexOf (args_, t);
} }
@ -136,7 +136,7 @@ TFactor<T>::multiply (TFactor<T>& g)
{ {
if (args_ == g.arguments()) { if (args_ == g.arguments()) {
// optimization // optimization
globals::logDomain Globals::logDomain
? params_ += g.params() ? params_ += g.params()
: params_ *= g.params(); : params_ *= g.params();
return; return;
@ -163,7 +163,7 @@ TFactor<T>::multiply (TFactor<T>& g)
extend (range_prod); extend (range_prod);
Params::iterator it = params_.begin(); Params::iterator it = params_.begin();
MapIndexer indexer (args_, ranges_, g_args, g_ranges); MapIndexer indexer (args_, ranges_, g_args, g_ranges);
if (globals::logDomain) { if (Globals::logDomain) {
for (; indexer.valid(); ++it, ++indexer) { for (; indexer.valid(); ++it, ++indexer) {
*it += g_params[indexer]; *it += g_params[indexer];
} }
@ -183,13 +183,13 @@ TFactor<T>::sumOutIndex (size_t idx)
assert (idx < args_.size()); assert (idx < args_.size());
assert (args_.size() > 1); assert (args_.size() > 1);
size_t new_size = params_.size() / ranges_[idx]; size_t new_size = params_.size() / ranges_[idx];
Params newps (new_size, log_aware::addIdenty()); Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin(); Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end(); Params::const_iterator last = params_.end();
MapIndexer indexer (ranges_, idx); MapIndexer indexer (ranges_, idx);
if (globals::logDomain) { if (Globals::logDomain) {
for (; first != last; ++indexer) { for (; first != last; ++indexer) {
newps[indexer] = util::logSum (newps[indexer], *first++); newps[indexer] = Util::logSum (newps[indexer], *first++);
} }
} else { } else {
for (; first != last; ++indexer) { for (; first != last; ++indexer) {
@ -255,7 +255,7 @@ TFactor<T>::reorderArguments (const std::vector<T>& new_args)
template <typename T> inline bool template <typename T> inline bool
TFactor<T>::contains (const T& arg) const TFactor<T>::contains (const T& arg) const
{ {
return util::contains (args_, arg); return Util::contains (args_, arg);
} }
@ -310,7 +310,7 @@ TFactor<T>::cartesianProduct (
Params::const_iterator first1 = backup.begin(); Params::const_iterator first1 = backup.begin();
Params::const_iterator last1 = backup.end(); Params::const_iterator last1 = backup.end();
Params::const_iterator tmp; Params::const_iterator tmp;
if (globals::logDomain) { if (Globals::logDomain) {
for (; first1 != last1; ++first1) { for (; first1 != last1; ++first1) {
for (tmp = first2; tmp != last2; ++tmp) { for (tmp = first2; tmp != last2; ++tmp) {
params_.push_back ((*first1) + (*tmp)); params_.push_back ((*first1) + (*tmp));
@ -335,10 +335,10 @@ class Factor : public TFactor<VarId>
Factor (const Factor&); Factor (const Factor&);
Factor (const VarIds&, const Ranges&, const Params&, Factor (const VarIds&, const Ranges&, const Params&,
unsigned = util::maxUnsigned()); unsigned = Util::maxUnsigned());
Factor (const Vars&, const Params&, Factor (const Vars&, const Params&,
unsigned = util::maxUnsigned()); unsigned = Util::maxUnsigned());
void sumOut (VarId); void sumOut (VarId);

View File

@ -106,9 +106,9 @@ FactorGraph::readFromUaiFormat (const char* fileName)
for (unsigned i = 0; i < nrFactors; i++) { for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is); ignoreLines (is);
is >> nrParams; is >> nrParams;
if (nrParams != util::sizeExpected (allRanges[i])) { if (nrParams != Util::sizeExpected (allRanges[i])) {
std::cerr << "Error: invalid number of parameters for factor nº " << i ; std::cerr << "Error: invalid number of parameters for factor nº " << i ;
std::cerr << ", " << util::sizeExpected (allRanges[i]); std::cerr << ", " << Util::sizeExpected (allRanges[i]);
std::cerr << " expected, " << nrParams << " given." << std::endl; std::cerr << " expected, " << nrParams << " given." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
@ -116,8 +116,8 @@ FactorGraph::readFromUaiFormat (const char* fileName)
for (unsigned j = 0; j < nrParams; j++) { for (unsigned j = 0; j < nrParams; j++) {
is >> params[j]; is >> params[j];
} }
if (globals::logDomain) { if (Globals::logDomain) {
util::log (params); Util::log (params);
} }
Factor f (allVarIds[i], allRanges[i], params); Factor f (allVarIds[i], allRanges[i], params);
if (bayesFactors_ && allVarIds[i].size() > 1) { if (bayesFactors_ && allVarIds[i].size() > 1) {
@ -171,7 +171,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
ignoreLines (is); ignoreLines (is);
unsigned nNonzeros; unsigned nNonzeros;
is >> nNonzeros; is >> nNonzeros;
Params params (util::sizeExpected (ranges), 0); Params params (Util::sizeExpected (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) { for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is); ignoreLines (is);
unsigned index; unsigned index;
@ -181,8 +181,8 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
is >> val; is >> val;
params[index] = val; params[index] = val;
} }
if (globals::logDomain) { if (Globals::logDomain) {
util::log (params); Util::log (params);
} }
std::reverse (vids.begin(), vids.end()); std::reverse (vids.begin(), vids.end());
Factor f (vids, ranges, params); Factor f (vids, ranges, params);
@ -306,13 +306,13 @@ FactorGraph::exportToLibDai (const char* fileName) const
for (size_t i = 0; i < facNodes_.size(); i++) { for (size_t i = 0; i < facNodes_.size(); i++) {
Factor f (facNodes_[i]->factor()); Factor f (facNodes_[i]->factor());
out << f.nrArguments() << std::endl; out << f.nrArguments() << std::endl;
out << util::elementsToString (f.arguments()) << std::endl; out << Util::elementsToString (f.arguments()) << std::endl;
out << util::elementsToString (f.ranges()) << std::endl; out << Util::elementsToString (f.ranges()) << std::endl;
VarIds args = f.arguments(); VarIds args = f.arguments();
std::reverse (args.begin(), args.end()); std::reverse (args.begin(), args.end());
f.reorderArguments (args); f.reorderArguments (args);
if (globals::logDomain) { if (Globals::logDomain) {
util::exp (f.params()); Util::exp (f.params());
} }
out << f.size() << std::endl; out << f.size() << std::endl;
for (size_t j = 0; j < f.size(); j++) { for (size_t j = 0; j < f.size(); j++) {
@ -347,7 +347,7 @@ FactorGraph::exportToUai (const char* fileName) const
if (bayesFactors_) { if (bayesFactors_) {
std::swap (args.front(), args.back()); std::swap (args.front(), args.back());
} }
out << args.size() << " " << util::elementsToString (args); out << args.size() << " " << Util::elementsToString (args);
out << std::endl; out << std::endl;
} }
out << std::endl; out << std::endl;
@ -359,11 +359,11 @@ FactorGraph::exportToUai (const char* fileName) const
f.reorderArguments (args); f.reorderArguments (args);
} }
Params params = f.params(); Params params = f.params();
if (globals::logDomain) { if (Globals::logDomain) {
util::exp (params); Util::exp (params);
} }
out << params.size() << std::endl << " " ; out << params.size() << std::endl << " " ;
out << util::elementsToString (params); out << Util::elementsToString (params);
out << std::endl << std::endl; out << std::endl << std::endl;
} }
out.close(); out.close();

View File

@ -20,7 +20,7 @@ class VarNode : public Var
{ {
public: public:
VarNode (VarId varId, unsigned nrStates, VarNode (VarId varId, unsigned nrStates,
int evidence = constants::NO_EVIDENCE) int evidence = Constants::NO_EVIDENCE)
: Var (varId, nrStates, evidence) { } : Var (varId, nrStates, evidence) { }
VarNode (const Var* v) : Var (v) { } VarNode (const Var* v) : Var (v) { }

View File

@ -28,10 +28,10 @@ GroundSolver::printAnswer (const VarIds& vids)
if (unobservedVids.empty() == false) { if (unobservedVids.empty() == false) {
Params res = solveQuery (unobservedVids); Params res = solveQuery (unobservedVids);
std::vector<std::string> stateLines = std::vector<std::string> stateLines =
util::getStateLines (unobservedVars); Util::getStateLines (unobservedVars);
for (size_t i = 0; i < res.size(); i++) { for (size_t i = 0; i < res.size(); i++) {
std::cout << "P(" << stateLines[i] << ") = " ; std::cout << "P(" << stateLines[i] << ") = " ;
std::cout << std::setprecision (constants::PRECISION) << res[i]; std::cout << std::setprecision (Constants::PRECISION) << res[i];
std::cout << std::endl; std::cout << std::endl;
} }
std::cout << std::endl; std::cout << std::endl;

View File

@ -80,7 +80,7 @@ HistogramSet::getHistograms (unsigned N, unsigned R)
unsigned unsigned
HistogramSet::nrHistograms (unsigned N, unsigned R) HistogramSet::nrHistograms (unsigned N, unsigned R)
{ {
return util::nrCombinations (N + R - 1, R - 1); return Util::nrCombinations (N + R - 1, R - 1);
} }
@ -102,17 +102,17 @@ std::vector<double>
HistogramSet::getNumAssigns (unsigned N, unsigned R) HistogramSet::getNumAssigns (unsigned N, unsigned R)
{ {
HistogramSet hs (N, R); HistogramSet hs (N, R);
double N_fac = util::logFactorial (N); double N_fac = Util::logFactorial (N);
unsigned H = hs.nrHistograms(); unsigned H = hs.nrHistograms();
std::vector<double> numAssigns; std::vector<double> numAssigns;
numAssigns.reserve (H); numAssigns.reserve (H);
for (unsigned h = 0; h < H; h++) { for (unsigned h = 0; h < H; h++) {
double prod = 0.0; double prod = 0.0;
for (unsigned r = 0; r < R; r++) { for (unsigned r = 0; r < R; r++) {
prod += util::logFactorial (hs[r]); prod += Util::logFactorial (hs[r]);
} }
double res = N_fac - prod; double res = N_fac - prod;
numAssigns.push_back (globals::logDomain ? res : std::exp (res)); numAssigns.push_back (Globals::logDomain ? res : std::exp (res));
hs.nextHistogram(); hs.nextHistogram();
} }
return numAssigns; return numAssigns;

View File

@ -50,7 +50,7 @@ enum GroundSolverType
}; };
namespace globals { namespace Globals {
extern bool logDomain; extern bool logDomain;
@ -63,7 +63,7 @@ extern GroundSolverType groundSolver;
} }
namespace constants { namespace Constants {
// show message calculation for belief propagation // show message calculation for belief propagation
const bool SHOW_BP_CALCS = false; const bool SHOW_BP_CALCS = false;
@ -73,7 +73,7 @@ const int NO_EVIDENCE = -1;
// number of digits to show when printing a parameter // number of digits to show when printing a parameter
const unsigned PRECISION = 6; const unsigned PRECISION = 6;
} // namespace constants }
} // namespace horus } // namespace horus

View File

@ -47,7 +47,7 @@ main (int argc, const char* argv[])
if (horus::FactorGraph::printFactorGraph()) { if (horus::FactorGraph::printFactorGraph()) {
fg.print(); fg.print();
} }
if (horus::globals::verbosity > 0) { if (horus::Globals::verbosity > 0) {
std::cout << "factor graph contains " ; std::cout << "factor graph contains " ;
std::cout << fg.nrVarNodes() << " variables and " ; std::cout << fg.nrVarNodes() << " variables and " ;
std::cout << fg.nrFacNodes() << " factors " << std::endl; std::cout << fg.nrFacNodes() << " factors " << std::endl;
@ -80,7 +80,7 @@ readHorusFlags (int argc, const char* argv[])
std::cerr << USAGE << std::endl; std::cerr << USAGE << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
horus::util::setHorusFlag (leftArg, rightArg); horus::Util::setHorusFlag (leftArg, rightArg);
} }
return i + 1; return i + 1;
} }
@ -116,13 +116,13 @@ readQueryAndEvidence (
for (int i = start; i < argc; i++) { for (int i = start; i < argc; i++) {
const std::string& arg = argv[i]; const std::string& arg = argv[i];
if (arg.find ('=') == std::string::npos) { if (arg.find ('=') == std::string::npos) {
if (horus::util::isInteger (arg) == false) { if (horus::Util::isInteger (arg) == false) {
std::cerr << "Error: `" << arg << "' " ; std::cerr << "Error: `" << arg << "' " ;
std::cerr << "is not a variable id." ; std::cerr << "is not a variable id." ;
std::cerr << std::endl; std::cerr << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
horus::VarId vid = horus::util::stringToUnsigned (arg); horus::VarId vid = horus::Util::stringToUnsigned (arg);
horus::VarNode* queryVar = fg.getVarNode (vid); horus::VarNode* queryVar = fg.getVarNode (vid);
if (queryVar == false) { if (queryVar == false) {
std::cerr << "Error: unknow variable with id " ; std::cerr << "Error: unknow variable with id " ;
@ -139,12 +139,12 @@ readQueryAndEvidence (
std::cerr << USAGE << std::endl; std::cerr << USAGE << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
if (horus::util::isInteger (leftArg) == false) { if (horus::Util::isInteger (leftArg) == false) {
std::cerr << "Error: `" << leftArg << "' " ; std::cerr << "Error: `" << leftArg << "' " ;
std::cerr << "is not a variable id." << std::endl; std::cerr << "is not a variable id." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
horus::VarId vid = horus::util::stringToUnsigned (leftArg); horus::VarId vid = horus::Util::stringToUnsigned (leftArg);
horus::VarNode* observedVar = fg.getVarNode (vid); horus::VarNode* observedVar = fg.getVarNode (vid);
if (observedVar == false) { if (observedVar == false) {
std::cerr << "Error: unknow variable with id " ; std::cerr << "Error: unknow variable with id " ;
@ -156,12 +156,12 @@ readQueryAndEvidence (
std::cerr << USAGE << std::endl; std::cerr << USAGE << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
if (horus::util::isInteger (rightArg) == false) { if (horus::Util::isInteger (rightArg) == false) {
std::cerr << "Error: `" << rightArg << "' " ; std::cerr << "Error: `" << rightArg << "' " ;
std::cerr << "is not a state index." << std::endl; std::cerr << "is not a state index." << std::endl;
exit (EXIT_FAILURE); exit (EXIT_FAILURE);
} }
unsigned stateIdx = horus::util::stringToUnsigned (rightArg); unsigned stateIdx = horus::Util::stringToUnsigned (rightArg);
if (observedVar->isValidState (stateIdx) == false) { if (observedVar->isValidState (stateIdx) == false) {
std::cerr << "Error: `" << stateIdx << "' " ; std::cerr << "Error: `" << stateIdx << "' " ;
std::cerr << "is not a valid state index for variable with id " ; std::cerr << "is not a valid state index for variable with id " ;
@ -182,7 +182,7 @@ runSolver (
const horus::VarIds& queryIds) const horus::VarIds& queryIds)
{ {
horus::GroundSolver* solver = 0; horus::GroundSolver* solver = 0;
switch (horus::globals::groundSolver) { switch (horus::Globals::groundSolver) {
case horus::GroundSolverType::VE: case horus::GroundSolverType::VE:
solver = new horus::VarElim (fg); solver = new horus::VarElim (fg);
break; break;
@ -195,7 +195,7 @@ runSolver (
default: default:
assert (false); assert (false);
} }
if (horus::globals::verbosity > 0) { if (horus::Globals::verbosity > 0) {
solver->printSolverFlags(); solver->printSolverFlags();
std::cout << std::endl; std::cout << std::endl;
} }

View File

@ -48,8 +48,8 @@ createLiftedNetwork (void)
} }
// LiftedUtils::printSymbolDictionary(); // LiftedUtils::printSymbolDictionary();
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
util::printHeader ("INITIAL PARFACTORS"); Util::printHeader ("INITIAL PARFACTORS");
for (size_t i = 0; i < parfactors.size(); i++) { for (size_t i = 0; i < parfactors.size(); i++) {
parfactors[i]->print(); parfactors[i]->print();
std::cout << std::endl; std::cout << std::endl;
@ -58,8 +58,8 @@ createLiftedNetwork (void)
ParfactorList* pfList = new ParfactorList (parfactors); ParfactorList* pfList = new ParfactorList (parfactors);
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
util::printHeader ("SHATTERED PARFACTORS"); Util::printHeader ("SHATTERED PARFACTORS");
pfList->print(); pfList->print();
} }
@ -120,7 +120,7 @@ createGroundNetwork (void)
if (FactorGraph::printFactorGraph()) { if (FactorGraph::printFactorGraph()) {
fg->print(); fg->print();
} }
if (globals::verbosity > 0) { if (Globals::verbosity > 0) {
std::cout << "factor graph contains " ; std::cout << "factor graph contains " ;
std::cout << fg->nrVarNodes() << " variables and " ; std::cout << fg->nrVarNodes() << " variables and " ;
std::cout << fg->nrFacNodes() << " factors " << std::endl; std::cout << fg->nrFacNodes() << " factors " << std::endl;
@ -139,13 +139,13 @@ runLiftedSolver (void)
LiftedOperations::absorveEvidence (pfListCopy, *network->second); LiftedOperations::absorveEvidence (pfListCopy, *network->second);
LiftedSolver* solver = 0; LiftedSolver* solver = 0;
switch (globals::liftedSolver) { switch (Globals::liftedSolver) {
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break; case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break; case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break; case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
} }
if (globals::verbosity > 0) { if (Globals::verbosity > 0) {
solver->printSolverFlags(); solver->printSolverFlags();
std::cout << std::endl; std::cout << std::endl;
} }
@ -205,7 +205,7 @@ runGroundSolver (void)
if (fg->bayesianFactors()) { if (fg->bayesianFactors()) {
std::set<VarId> vids; std::set<VarId> vids;
for (size_t i = 0; i < tasks.size(); i++) { for (size_t i = 0; i < tasks.size(); i++) {
util::addToSet (vids, tasks[i]); Util::addToSet (vids, tasks[i]);
} }
mfg = BayesBall::getMinimalFactorGraph ( mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(), vids.end())); *fg, VarIds (vids.begin(), vids.end()));
@ -213,13 +213,13 @@ runGroundSolver (void)
GroundSolver* solver = 0; GroundSolver* solver = 0;
CountingBp::setFindIdenticalFactorsFlag (false); CountingBp::setFindIdenticalFactorsFlag (false);
switch (globals::groundSolver) { switch (Globals::groundSolver) {
case GroundSolverType::VE: solver = new VarElim (*mfg); break; case GroundSolverType::VE: solver = new VarElim (*mfg); break;
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break; case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break; case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
} }
if (globals::verbosity > 0) { if (Globals::verbosity > 0) {
solver->printSolverFlags(); solver->printSolverFlags();
std::cout << std::endl; std::cout << std::endl;
} }
@ -251,14 +251,14 @@ setParfactorsParams (void)
while (distIdsList != YAP_TermNil()) { while (distIdsList != YAP_TermNil()) {
unsigned distId = (unsigned) YAP_IntOfTerm ( unsigned distId = (unsigned) YAP_IntOfTerm (
YAP_HeadOfTerm (distIdsList)); YAP_HeadOfTerm (distIdsList));
assert (util::contains (paramsMap, distId) == false); assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList)); paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
distIdsList = YAP_TailOfTerm (distIdsList); distIdsList = YAP_TailOfTerm (distIdsList);
paramsList = YAP_TailOfTerm (paramsList); paramsList = YAP_TailOfTerm (paramsList);
} }
ParfactorList::iterator it = pfList->begin(); ParfactorList::iterator it = pfList->begin();
while (it != pfList->end()) { while (it != pfList->end()) {
assert (util::contains (paramsMap, (*it)->distId())); assert (Util::contains (paramsMap, (*it)->distId()));
(*it)->setParams (paramsMap[(*it)->distId()]); (*it)->setParams (paramsMap[(*it)->distId()]);
++ it; ++ it;
} }
@ -277,7 +277,7 @@ setFactorsParams (void)
while (distIdsList != YAP_TermNil()) { while (distIdsList != YAP_TermNil()) {
unsigned distId = (unsigned) YAP_IntOfTerm ( unsigned distId = (unsigned) YAP_IntOfTerm (
YAP_HeadOfTerm (distIdsList)); YAP_HeadOfTerm (distIdsList));
assert (util::contains (paramsMap, distId) == false); assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList)); paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
distIdsList = YAP_TailOfTerm (distIdsList); distIdsList = YAP_TailOfTerm (distIdsList);
paramsList = YAP_TailOfTerm (paramsList); paramsList = YAP_TailOfTerm (paramsList);
@ -285,7 +285,7 @@ setFactorsParams (void)
const FacNodes& facNodes = fg->facNodes(); const FacNodes& facNodes = fg->facNodes();
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId(); unsigned distId = facNodes[i]->factor().distId();
assert (util::contains (paramsMap, distId)); assert (Util::contains (paramsMap, distId));
facNodes[i]->factor().setParams (paramsMap[distId]); facNodes[i]->factor().setParams (paramsMap[distId]);
} }
return TRUE; return TRUE;
@ -343,7 +343,7 @@ setHorusFlag (void)
} else { } else {
value = ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2))); value = ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
} }
return util::setHorusFlag (option, value); return Util::setHorusFlag (option, value);
} }
@ -519,8 +519,8 @@ readParameters (YAP_Term paramL)
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL))); params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
paramL = YAP_TailOfTerm (paramL); paramL = YAP_TailOfTerm (paramL);
} }
if (globals::logDomain) { if (Globals::logDomain) {
util::log (params); Util::log (params);
} }
return params; return params;
} }

View File

@ -56,7 +56,7 @@ class Indexer
inline inline
Indexer::Indexer (const Ranges& ranges, bool calcOffsets) Indexer::Indexer (const Ranges& ranges, bool calcOffsets)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges), : index_(0), indices_(ranges.size(), 0), ranges_(ranges),
size_(util::sizeExpected (ranges)) size_(Util::sizeExpected (ranges))
{ {
if (calcOffsets) { if (calcOffsets) {
calculateOffsets(); calculateOffsets();
@ -290,7 +290,7 @@ MapIndexer::MapIndexer (
} }
offsets_.reserve (allArgs.size()); offsets_.reserve (allArgs.size());
for (size_t i = 0; i < allArgs.size(); i++) { for (size_t i = 0; i < allArgs.size(); i++) {
size_t idx = util::indexOf (wantedArgs, allArgs[i]); size_t idx = Util::indexOf (wantedArgs, allArgs[i]);
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0); offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
} }
} }

View File

@ -77,7 +77,7 @@ LiftedBp::printSolverFlags (void) const
} }
ss << ",bp_max_iter=" << WeightedBp::maxIterations(); ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy(); ss << ",bp_accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << util::toString (globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
} }
@ -90,8 +90,8 @@ LiftedBp::refineParfactors (void)
pfList_ = parfactorList; pfList_ = parfactorList;
while (iterate() == false); while (iterate() == false);
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
util::printHeader ("AFTER REFINEMENT"); Util::printHeader ("AFTER REFINEMENT");
pfList_.print(); pfList_.print();
} }
} }
@ -187,7 +187,7 @@ LiftedBp::rangeOfGround (const Ground& gr)
} }
++ it; ++ it;
} }
return util::maxUnsigned(); return Util::maxUnsigned();
} }

View File

@ -22,7 +22,7 @@ OrNode::weight (void) const
{ {
double lw = leftBranch_->weight(); double lw = leftBranch_->weight();
double rw = rightBranch_->weight(); double rw = rightBranch_->weight();
return globals::logDomain ? util::logSum (lw, rw) : lw + rw; return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw;
} }
@ -40,7 +40,7 @@ AndNode::weight (void) const
{ {
double lw = leftBranch_->weight(); double lw = leftBranch_->weight();
double rw = rightBranch_->weight(); double rw = rightBranch_->weight();
return globals::logDomain ? lw + rw : lw * rw; return Globals::logDomain ? lw + rw : lw * rw;
} }
@ -60,17 +60,17 @@ SetOrNode::~SetOrNode (void)
double double
SetOrNode::weight (void) const SetOrNode::weight (void) const
{ {
double weightSum = log_aware::addIdenty(); double weightSum = LogAware::addIdenty();
for (unsigned i = 0; i < nrGroundings_ + 1; i++) { for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
nrPos_ = nrGroundings_ - i; nrPos_ = nrGroundings_ - i;
nrNeg_ = i; nrNeg_ = i;
if (globals::logDomain) { if (Globals::logDomain) {
double nrCombs = util::nrCombinations (nrGroundings_, i); double nrCombs = Util::nrCombinations (nrGroundings_, i);
double w = follow_->weight(); double w = follow_->weight();
weightSum = util::logSum (weightSum, std::log (nrCombs) + w); weightSum = Util::logSum (weightSum, std::log (nrCombs) + w);
} else { } else {
double w = follow_->weight(); double w = follow_->weight();
weightSum += util::nrCombinations (nrGroundings_, i) * w; weightSum += Util::nrCombinations (nrGroundings_, i) * w;
} }
} }
nrPos_ = -1; nrPos_ = -1;
@ -90,7 +90,7 @@ SetAndNode::~SetAndNode (void)
double double
SetAndNode::weight (void) const SetAndNode::weight (void) const
{ {
return log_aware::pow (follow_->weight(), nrGroundings_); return LogAware::pow (follow_->weight(), nrGroundings_);
} }
@ -108,8 +108,8 @@ double
IncExcNode::weight (void) const IncExcNode::weight (void) const
{ {
double w = 0.0; double w = 0.0;
if (globals::logDomain) { if (Globals::logDomain) {
w = util::logSum (plus1Branch_->weight(), plus2Branch_->weight()); w = Util::logSum (plus1Branch_->weight(), plus2Branch_->weight());
w = std::log (std::exp (w) - std::exp (minusBranch_->weight())); w = std::log (std::exp (w) - std::exp (minusBranch_->weight()));
} else { } else {
w = plus1Branch_->weight() + plus2Branch_->weight(); w = plus1Branch_->weight() + plus2Branch_->weight();
@ -160,7 +160,7 @@ LeafNode::weight (void) const
nrGroundings *= std::pow (SetOrNode::nrNegatives(), nrGroundings *= std::pow (SetOrNode::nrNegatives(),
clause_->nrNegCountedLogVars()); clause_->nrNegCountedLogVars());
} }
return log_aware::pow (weight, nrGroundings); return LogAware::pow (weight, nrGroundings);
} }
@ -176,7 +176,7 @@ double
SmoothNode::weight (void) const SmoothNode::weight (void) const
{ {
Clauses cs = clauses(); Clauses cs = clauses();
double totalWeight = log_aware::multIdenty(); double totalWeight = LogAware::multIdenty();
for (size_t i = 0; i < cs.size(); i++) { for (size_t i = 0; i < cs.size(); i++) {
double posWeight = lwcnf_.posWeight (cs[i]->literals()[0].lid()); double posWeight = lwcnf_.posWeight (cs[i]->literals()[0].lid());
double negWeight = lwcnf_.negWeight (cs[i]->literals()[0].lid()); double negWeight = lwcnf_.negWeight (cs[i]->literals()[0].lid());
@ -196,8 +196,8 @@ SmoothNode::weight (void) const
nrGroundings *= std::pow (SetOrNode::nrNegatives(), nrGroundings *= std::pow (SetOrNode::nrNegatives(),
cs[i]->nrNegCountedLogVars()); cs[i]->nrNegCountedLogVars());
} }
if (globals::logDomain) { if (Globals::logDomain) {
totalWeight += util::logSum (posWeight, negWeight) * nrGroundings; totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings;
} else { } else {
totalWeight *= std::pow (posWeight + negWeight, nrGroundings); totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
} }
@ -210,7 +210,7 @@ SmoothNode::weight (void) const
double double
TrueNode::weight (void) const TrueNode::weight (void) const
{ {
return log_aware::multIdenty(); return LogAware::multIdenty();
} }
@ -235,9 +235,9 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
if (compilationSucceeded_) { if (compilationSucceeded_) {
smoothCircuit (root_); smoothCircuit (root_);
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
if (compilationSucceeded_) { if (compilationSucceeded_) {
double wmc = log_aware::exp (getWeightedModelCount()); double wmc = LogAware::exp (getWeightedModelCount());
std::cout << "Weighted model count = " << wmc; std::cout << "Weighted model count = " << wmc;
std::cout << std::endl << std::endl; std::cout << std::endl << std::endl;
} }
@ -302,7 +302,7 @@ LiftedCircuit::compile (
Clauses& clauses) Clauses& clauses)
{ {
if (compilationSucceeded_ == false if (compilationSucceeded_ == false
&& globals::verbosity <= 1) { && Globals::verbosity <= 1) {
return; return;
} }
@ -341,7 +341,7 @@ LiftedCircuit::compile (
} }
*follow = new CompilationFailedNode(); *follow = new CompilationFailedNode();
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[*follow] = clauses; originClausesMap_[*follow] = clauses;
explanationMap_[*follow] = "" ; explanationMap_[*follow] = "" ;
} }
@ -355,7 +355,7 @@ LiftedCircuit::tryUnitPropagation (
CircuitNode** follow, CircuitNode** follow,
Clauses& clauses) Clauses& clauses)
{ {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
backupClauses_ = Clause::copyClauses (clauses); backupClauses_ = Clause::copyClauses (clauses);
} }
for (size_t i = 0; i < clauses.size(); i++) { for (size_t i = 0; i < clauses.size(); i++) {
@ -392,7 +392,7 @@ LiftedCircuit::tryUnitPropagation (
} }
AndNode* andNode = new AndNode(); AndNode* andNode = new AndNode();
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[andNode] = backupClauses_; originClausesMap_[andNode] = backupClauses_;
std::stringstream explanation; std::stringstream explanation;
explanation << " UP on " << clauses[i]->literals()[0]; explanation << " UP on " << clauses[i]->literals()[0];
@ -406,7 +406,7 @@ LiftedCircuit::tryUnitPropagation (
return true; return true;
} }
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
Clause::deleteClauses (backupClauses_); Clause::deleteClauses (backupClauses_);
} }
return false; return false;
@ -422,7 +422,7 @@ LiftedCircuit::tryIndependence (
if (clauses.size() == 1) { if (clauses.size() == 1) {
return false; return false;
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
backupClauses_ = Clause::copyClauses (clauses); backupClauses_ = Clause::copyClauses (clauses);
} }
Clauses depClauses = { clauses[0] }; Clauses depClauses = { clauses[0] };
@ -441,7 +441,7 @@ LiftedCircuit::tryIndependence (
} }
if (indepClauses.empty() == false) { if (indepClauses.empty() == false) {
AndNode* andNode = new AndNode (); AndNode* andNode = new AndNode ();
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[andNode] = backupClauses_; originClausesMap_[andNode] = backupClauses_;
explanationMap_[andNode] = " Independence" ; explanationMap_[andNode] = " Independence" ;
} }
@ -450,7 +450,7 @@ LiftedCircuit::tryIndependence (
(*follow) = andNode; (*follow) = andNode;
return true; return true;
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
Clause::deleteClauses (backupClauses_); Clause::deleteClauses (backupClauses_);
} }
return false; return false;
@ -463,7 +463,7 @@ LiftedCircuit::tryShannonDecomp (
CircuitNode** follow, CircuitNode** follow,
Clauses& clauses) Clauses& clauses)
{ {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
backupClauses_ = Clause::copyClauses (clauses); backupClauses_ = Clause::copyClauses (clauses);
} }
for (size_t i = 0; i < clauses.size(); i++) { for (size_t i = 0; i < clauses.size(); i++) {
@ -481,7 +481,7 @@ LiftedCircuit::tryShannonDecomp (
otherClauses.push_back (c2); otherClauses.push_back (c2);
OrNode* orNode = new OrNode(); OrNode* orNode = new OrNode();
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[orNode] = backupClauses_; originClausesMap_[orNode] = backupClauses_;
std::stringstream explanation; std::stringstream explanation;
explanation << " SD on " << literals[j]; explanation << " SD on " << literals[j];
@ -495,7 +495,7 @@ LiftedCircuit::tryShannonDecomp (
} }
} }
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
Clause::deleteClauses (backupClauses_); Clause::deleteClauses (backupClauses_);
} }
return false; return false;
@ -508,7 +508,7 @@ LiftedCircuit::tryInclusionExclusion (
CircuitNode** follow, CircuitNode** follow,
Clauses& clauses) Clauses& clauses)
{ {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
backupClauses_ = Clause::copyClauses (clauses); backupClauses_ = Clause::copyClauses (clauses);
} }
for (size_t i = 0; i < clauses.size(); i++) { for (size_t i = 0; i < clauses.size(); i++) {
@ -561,7 +561,7 @@ LiftedCircuit::tryInclusionExclusion (
clauses.push_back (c2); clauses.push_back (c2);
IncExcNode* ieNode = new IncExcNode(); IncExcNode* ieNode = new IncExcNode();
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[ieNode] = backupClauses_; originClausesMap_[ieNode] = backupClauses_;
std::stringstream explanation; std::stringstream explanation;
explanation << " IncExc on clause nº " << i + 1; explanation << " IncExc on clause nº " << i + 1;
@ -574,7 +574,7 @@ LiftedCircuit::tryInclusionExclusion (
return true; return true;
} }
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
Clause::deleteClauses (backupClauses_); Clause::deleteClauses (backupClauses_);
} }
return false; return false;
@ -589,7 +589,7 @@ LiftedCircuit::tryIndepPartialGrounding (
{ {
// assumes that all literals have logical variables // assumes that all literals have logical variables
// else, shannon decomp was possible // else, shannon decomp was possible
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
backupClauses_ = Clause::copyClauses (clauses); backupClauses_ = Clause::copyClauses (clauses);
} }
LogVars rootLogVars; LogVars rootLogVars;
@ -603,7 +603,7 @@ LiftedCircuit::tryIndepPartialGrounding (
clauses[j]->addIpgLogVar (rootLogVars[j]); clauses[j]->addIpgLogVar (rootLogVars[j]);
} }
SetAndNode* setAndNode = new SetAndNode (ct.size()); SetAndNode* setAndNode = new SetAndNode (ct.size());
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[setAndNode] = backupClauses_; originClausesMap_[setAndNode] = backupClauses_;
explanationMap_[setAndNode] = " IPG" ; explanationMap_[setAndNode] = " IPG" ;
} }
@ -612,7 +612,7 @@ LiftedCircuit::tryIndepPartialGrounding (
return true; return true;
} }
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
Clause::deleteClauses (backupClauses_); Clause::deleteClauses (backupClauses_);
} }
return false; return false;
@ -674,7 +674,7 @@ LiftedCircuit::tryAtomCounting (
return false; return false;
} }
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
backupClauses_ = Clause::copyClauses (clauses); backupClauses_ = Clause::copyClauses (clauses);
} }
for (size_t i = 0; i < clauses.size(); i++) { for (size_t i = 0; i < clauses.size(); i++) {
@ -686,7 +686,7 @@ LiftedCircuit::tryAtomCounting (
unsigned nrGroundings = clauses[i]->constr().projectedCopy ( unsigned nrGroundings = clauses[i]->constr().projectedCopy (
literals[j].logVars()).size(); literals[j].logVars()).size();
SetOrNode* setOrNode = new SetOrNode (nrGroundings); SetOrNode* setOrNode = new SetOrNode (nrGroundings);
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[setOrNode] = backupClauses_; originClausesMap_[setOrNode] = backupClauses_;
explanationMap_[setOrNode] = " AC" ; explanationMap_[setOrNode] = " AC" ;
} }
@ -707,7 +707,7 @@ LiftedCircuit::tryAtomCounting (
} }
} }
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
Clause::deleteClauses (backupClauses_); Clause::deleteClauses (backupClauses_);
} }
return false; return false;
@ -916,7 +916,7 @@ LiftedCircuit::createSmoothNode (
CircuitNode** prev) CircuitNode** prev)
{ {
if (missingLits.empty() == false) { if (missingLits.empty() == false) {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::unordered_map<CircuitNode*, Clauses>::iterator it std::unordered_map<CircuitNode*, Clauses>::iterator it
= originClausesMap_.find (*prev); = originClausesMap_.find (*prev);
if (it != originClausesMap_.end()) { if (it != originClausesMap_.end()) {
@ -944,7 +944,7 @@ LiftedCircuit::createSmoothNode (
} }
SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_); SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_);
*prev = new AndNode (smoothNode, *prev); *prev = new AndNode (smoothNode, *prev);
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
originClausesMap_[*prev] = backupClauses_; originClausesMap_[*prev] = backupClauses_;
explanationMap_[*prev] = " Smoothing" ; explanationMap_[*prev] = " Smoothing" ;
} }
@ -1211,7 +1211,7 @@ LiftedCircuit::escapeNode (const CircuitNode* node) const
std::string std::string
LiftedCircuit::getExplanationString (CircuitNode* node) LiftedCircuit::getExplanationString (CircuitNode* node)
{ {
return util::contains (explanationMap_, node) return Util::contains (explanationMap_, node)
? explanationMap_[node] ? explanationMap_[node]
: "" ; : "" ;
} }
@ -1225,7 +1225,7 @@ LiftedCircuit::printClauses (
std::string extraOptions) std::string extraOptions)
{ {
Clauses clauses; Clauses clauses;
if (util::contains (originClausesMap_, node)) { if (Util::contains (originClausesMap_, node)) {
clauses = originClausesMap_[node]; clauses = originClausesMap_[node];
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) { } else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) {
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ; clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
@ -1288,20 +1288,20 @@ LiftedKc::solveQuery (const Grounds& query)
std::vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]); std::vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
for (size_t j = 0; j < litIds.size(); j++) { for (size_t j = 0; j < litIds.size(); j++) {
if (indexer[i] == j) { if (indexer[i] == j) {
lwcnf_->addWeight (litIds[j], log_aware::one(), lwcnf_->addWeight (litIds[j], LogAware::one(),
log_aware::one()); LogAware::one());
} else { } else {
lwcnf_->addWeight (litIds[j], log_aware::zero(), lwcnf_->addWeight (litIds[j], LogAware::zero(),
log_aware::one()); LogAware::one());
} }
} }
} }
params.push_back (circuit_->getWeightedModelCount()); params.push_back (circuit_->getWeightedModelCount());
++ indexer; ++ indexer;
} }
log_aware::normalize (params); LogAware::normalize (params);
if (globals::logDomain) { if (Globals::logDomain) {
util::exp (params); Util::exp (params);
} }
return params; return params;
} }
@ -1313,7 +1313,7 @@ LiftedKc::printSolverFlags (void) const
{ {
std::stringstream ss; std::stringstream ss;
ss << "lifted kc [" ; ss << "lifted kc [" ;
ss << "log_domain=" << util::toString (globals::logDomain); ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
} }

View File

@ -47,13 +47,13 @@ LiftedOperations::shatterAgainstQuery (
} }
pfList.add (newPfs); pfList.add (newPfs);
} }
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
util::printAsteriskLine(); Util::printAsteriskLine();
std::cout << "SHATTERED AGAINST THE QUERY" << std::endl; std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
for (size_t i = 0; i < query.size(); i++) { for (size_t i = 0; i < query.size(); i++) {
std::cout << " -> " << query[i] << std::endl; std::cout << " -> " << query[i] << std::endl;
} }
util::printAsteriskLine(); Util::printAsteriskLine();
pfList.print(); pfList.print();
} }
} }
@ -85,11 +85,11 @@ LiftedOperations::runWeakBayesBall (
PrvGroup group = todo.front(); PrvGroup group = todo.front();
ParfactorList::iterator it = pfList.begin(); ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) { while (it != pfList.end()) {
if (util::contains (requiredPfs, *it) == false && if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) { (*it)->containsGroup (group)) {
std::vector<PrvGroup> groups = (*it)->getAllGroups(); std::vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) { for (size_t i = 0; i < groups.size(); i++) {
if (util::contains (done, groups[i]) == false) { if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]); todo.push (groups[i]);
done.insert (groups[i]); done.insert (groups[i]);
} }
@ -104,10 +104,10 @@ LiftedOperations::runWeakBayesBall (
ParfactorList::iterator it = pfList.begin(); ParfactorList::iterator it = pfList.begin();
bool foundNotRequired = false; bool foundNotRequired = false;
while (it != pfList.end()) { while (it != pfList.end()) {
if (util::contains (requiredPfs, *it) == false) { if (Util::contains (requiredPfs, *it) == false) {
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
if (foundNotRequired == false) { if (foundNotRequired == false) {
util::printHeader ("PARFACTORS TO DISCARD"); Util::printHeader ("PARFACTORS TO DISCARD");
foundNotRequired = true; foundNotRequired = true;
} }
(*it)->print(); (*it)->print();
@ -137,7 +137,7 @@ LiftedOperations::absorveEvidence (
if (absorvedPfs.size() == 1 && !absorvedPfs[0]) { if (absorvedPfs.size() == 1 && !absorvedPfs[0]) {
// just remove pf; // just remove pf;
} else { } else {
util::addToVector (newPfs, absorvedPfs); Util::addToVector (newPfs, absorvedPfs);
} }
delete pf; delete pf;
} else { } else {
@ -147,13 +147,13 @@ LiftedOperations::absorveEvidence (
} }
pfList.add (newPfs); pfList.add (newPfs);
} }
if (globals::verbosity > 2 && obsFormulas.empty() == false) { if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
util::printAsteriskLine(); Util::printAsteriskLine();
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl; std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
for (size_t i = 0; i < obsFormulas.size(); i++) { for (size_t i = 0; i < obsFormulas.size(); i++) {
std::cout << " -> " << obsFormulas[i] << std::endl; std::cout << " -> " << obsFormulas[i] << std::endl;
} }
util::printAsteriskLine(); Util::printAsteriskLine();
pfList.print(); pfList.print();
} }
} }

View File

@ -105,7 +105,7 @@ Substitution::getDiscardedLogVars (void) const
std::unordered_map<LogVar, LogVar>::const_iterator it std::unordered_map<LogVar, LogVar>::const_iterator it
= subs_.begin(); = subs_.begin();
while (it != subs_.end()) { while (it != subs_.end()) {
if (util::contains (doneLvs, it->second)) { if (Util::contains (doneLvs, it->second)) {
discardedLvs.push_back (it->first); discardedLvs.push_back (it->first);
} else { } else {
doneLvs.insert (it->second); doneLvs.insert (it->second);

View File

@ -15,13 +15,13 @@ namespace horus {
class Symbol class Symbol
{ {
public: public:
Symbol (void) : id_(util::maxUnsigned()) { } Symbol (void) : id_(Util::maxUnsigned()) { }
Symbol (unsigned id) : id_(id) { } Symbol (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; } operator unsigned (void) const { return id_; }
bool valid (void) const { return id_ != util::maxUnsigned(); } bool valid (void) const { return id_ != Util::maxUnsigned(); }
static Symbol invalid (void) { return Symbol(); } static Symbol invalid (void) { return Symbol(); }
@ -35,7 +35,7 @@ class Symbol
class LogVar class LogVar
{ {
public: public:
LogVar (void) : id_(util::maxUnsigned()) { } LogVar (void) : id_(Util::maxUnsigned()) { }
LogVar (unsigned id) : id_(id) { } LogVar (unsigned id) : id_(id) { }
@ -66,7 +66,7 @@ LogVar::operator++ (void)
inline bool inline bool
LogVar::valid (void) const LogVar::valid (void) const
{ {
return id_ != util::maxUnsigned(); return id_ != Util::maxUnsigned();
} }
} // namespace horus } // namespace horus
@ -166,7 +166,7 @@ class Substitution
inline void inline void
Substitution::add (LogVar X_old, LogVar X_new) Substitution::add (LogVar X_old, LogVar X_new)
{ {
assert (util::contains (subs_, X_old) == false); assert (Util::contains (subs_, X_old) == false);
subs_.insert (std::make_pair (X_old, X_new)); subs_.insert (std::make_pair (X_old, X_new));
} }
@ -175,7 +175,7 @@ Substitution::add (LogVar X_old, LogVar X_new)
inline void inline void
Substitution::rename (LogVar X_old, LogVar X_new) Substitution::rename (LogVar X_old, LogVar X_new)
{ {
assert (util::contains (subs_, X_old)); assert (Util::contains (subs_, X_old));
subs_.find (X_old)->second = X_new; subs_.find (X_old)->second = X_new;
} }
@ -197,7 +197,7 @@ Substitution::newNameFor (LogVar X) const
inline bool inline bool
Substitution::containsReplacementFor (LogVar X) const Substitution::containsReplacementFor (LogVar X) const
{ {
return util::contains (subs_, X); return Util::contains (subs_, X);
} }

View File

@ -25,7 +25,7 @@ LiftedOperator::getValidOps (
multOps = ProductOperator::getValidOps (pfList); multOps = ProductOperator::getValidOps (pfList);
validOps.insert (validOps.end(), multOps.begin(), multOps.end()); validOps.insert (validOps.end(), multOps.begin(), multOps.end());
if (globals::verbosity > 1 || multOps.empty()) { if (Globals::verbosity > 1 || multOps.empty()) {
std::vector<SumOutOperator*> sumOutOps; std::vector<SumOutOperator*> sumOutOps;
std::vector<CountingOperator*> countOps; std::vector<CountingOperator*> countOps;
std::vector<GroundOperator*> groundOps; std::vector<GroundOperator*> groundOps;
@ -103,14 +103,14 @@ ProductOperator::getValidOps (ParfactorList& pfList)
ParfactorList::iterator penultimate = -- pfList.end(); ParfactorList::iterator penultimate = -- pfList.end();
std::set<Parfactor*> pfs; std::set<Parfactor*> pfs;
while (it1 != penultimate) { while (it1 != penultimate) {
if (util::contains (pfs, *it1)) { if (Util::contains (pfs, *it1)) {
++ it1; ++ it1;
continue; continue;
} }
ParfactorList::iterator it2 = it1; ParfactorList::iterator it2 = it1;
++ it2; ++ it2;
while (it2 != pfList.end()) { while (it2 != pfList.end()) {
if (util::contains (pfs, *it2)) { if (Util::contains (pfs, *it2)) {
++ it2; ++ it2;
continue; continue;
} else { } else {
@ -119,7 +119,7 @@ ProductOperator::getValidOps (ParfactorList& pfList)
pfs.insert (*it2); pfs.insert (*it2);
validOps.push_back (new ProductOperator ( validOps.push_back (new ProductOperator (
it1, it2, pfList)); it1, it2, pfList));
if (globals::verbosity < 2) { if (Globals::verbosity < 2) {
return validOps; return validOps;
} }
break; break;
@ -353,7 +353,7 @@ CountingOperator::getLogCost (void)
cost += size * HistogramSet::nrHistograms (counts[i], range); cost += size * HistogramSet::nrHistograms (counts[i], range);
} }
PrvGroup group = (*pfIter_)->argument (fIdx).group(); PrvGroup group = (*pfIter_)->argument (fIdx).group();
size_t lvIndex = util::indexOf ( size_t lvIndex = Util::indexOf (
(*pfIter_)->argument (fIdx).logVars(), X_); (*pfIter_)->argument (fIdx).logVars(), X_);
assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size()); assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size());
ParfactorList::iterator pfIter = pfList_.begin(); ParfactorList::iterator pfIter = pfList_.begin();
@ -503,7 +503,7 @@ GroundOperator::getLogCost (void)
if (willBeAffected) { if (willBeAffected) {
// std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize); // std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
double pfCost = reps + pfSize; double pfCost = reps + pfSize;
totalCost = util::logSum (totalCost, pfCost); totalCost = Util::logSum (totalCost, pfCost);
} }
++ pflIt; ++ pflIt;
} }
@ -552,7 +552,7 @@ GroundOperator::getValidOps (ParfactorList& pfList)
while (it != pfList.end()) { while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments(); const ProbFormulas& formulas = (*it)->arguments();
for (size_t i = 0; i < formulas.size(); i++) { for (size_t i = 0; i < formulas.size(); i++) {
if (util::contains (allGroups, formulas[i].group()) == false) { if (Util::contains (allGroups, formulas[i].group()) == false) {
const LogVars& lvs = formulas[i].logVars(); const LogVars& lvs = formulas[i].logVars();
for (size_t j = 0; j < lvs.size(); j++) { for (size_t j = 0; j < lvs.size(); j++) {
if ((*it)->constr()->isSingleton (lvs[j]) == false) { if ((*it)->constr()->isSingleton (lvs[j]) == false) {
@ -618,7 +618,7 @@ GroundOperator::getAffectedFormulas (void)
if (i != idx && fs[i].contains (X)) { if (i != idx && fs[i].contains (X)) {
std::pair<PrvGroup, unsigned> pair = std::make_pair ( std::pair<PrvGroup, unsigned> pair = std::make_pair (
fs[i].group(), fs[i].indexOf (X)); fs[i].group(), fs[i].indexOf (X));
if (util::contains (affectedFormulas, pair) == false) { if (Util::contains (affectedFormulas, pair) == false) {
q.push (pair); q.push (pair);
affectedFormulas.push_back (pair); affectedFormulas.push_back (pair);
} }
@ -642,8 +642,8 @@ LiftedVe::solveQuery (const Grounds& query)
runSolver (query); runSolver (query);
(*pfList_.begin())->normalize(); (*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params(); Params params = (*pfList_.begin())->params();
if (globals::logDomain) { if (Globals::logDomain) {
util::exp (params); Util::exp (params);
} }
return params; return params;
} }
@ -655,7 +655,7 @@ LiftedVe::printSolverFlags (void) const
{ {
std::stringstream ss; std::stringstream ss;
ss << "lve [" ; ss << "lve [" ;
ss << "log_domain=" << util::toString (globals::logDomain); ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
} }
@ -669,10 +669,10 @@ LiftedVe::runSolver (const Grounds& query)
LiftedOperations::shatterAgainstQuery (pfList_, query); LiftedOperations::shatterAgainstQuery (pfList_, query);
LiftedOperations::runWeakBayesBall (pfList_, query); LiftedOperations::runWeakBayesBall (pfList_, query);
while (true) { while (true) {
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
util::printDashedLine(); Util::printDashedLine();
pfList_.print(); pfList_.print();
if (globals::verbosity > 3) { if (Globals::verbosity > 3) {
LiftedOperator::printValidOps (pfList_, query); LiftedOperator::printValidOps (pfList_, query);
} }
} }
@ -680,9 +680,9 @@ LiftedVe::runSolver (const Grounds& query)
if (op == 0) { if (op == 0) {
break; break;
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << "best operation: " << op->toString(); std::cout << "best operation: " << op->toString();
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
@ -698,7 +698,7 @@ LiftedVe::runSolver (const Grounds& query)
++ pfIter; ++ pfIter;
} }
} }
if (globals::verbosity > 0) { if (Globals::verbosity > 0) {
std::cout << "largest cost = " << std::exp (largestCost_); std::cout << "largest cost = " << std::exp (largestCost_);
std::cout << std::endl; std::cout << std::endl;
std::cout << std::endl; std::cout << std::endl;

View File

@ -29,7 +29,7 @@ Literal::isGround (
size_t size_t
Literal::indexOfLogVar (LogVar X) const Literal::indexOfLogVar (LogVar X) const
{ {
return util::indexOf (logVars_, X); return Util::indexOf (logVars_, X);
} }
@ -248,7 +248,7 @@ Clause::ipgCandidates (void) const
for (size_t i = 0; i < allLvs.size(); i++) { for (size_t i = 0; i < allLvs.size(); i++) {
bool valid = true; bool valid = true;
for (size_t j = 0; j < literals_.size(); j++) { for (size_t j = 0; j < literals_.size(); j++) {
if (util::contains (literals_[j].logVars(), allLvs[i]) == false) { if (Util::contains (literals_[j].logVars(), allLvs[i]) == false) {
valid = false; valid = false;
break; break;
} }
@ -455,7 +455,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
clauses_.push_back(c2); clauses_.push_back(c2);
*/ */
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << "FORMULA INDICATORS:" << std::endl; std::cout << "FORMULA INDICATORS:" << std::endl;
printFormulaIndicators(); printFormulaIndicators();
std::cout << std::endl; std::cout << std::endl;
@ -490,7 +490,7 @@ LiftedWCNF::posWeight (LiteralId lid) const
{ {
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
= weights_.find (lid); = weights_.find (lid);
return it != weights_.end() ? it->second.first : log_aware::one(); return it != weights_.end() ? it->second.first : LogAware::one();
} }
@ -500,7 +500,7 @@ LiftedWCNF::negWeight (LiteralId lid) const
{ {
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
= weights_.find (lid); = weights_.find (lid);
return it != weights_.end() ? it->second.second : log_aware::one(); return it != weights_.end() ? it->second.second : LogAware::one();
} }
@ -508,7 +508,7 @@ LiftedWCNF::negWeight (LiteralId lid) const
std::vector<LiteralId> std::vector<LiteralId>
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup) LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
{ {
assert (util::contains (map_, prvGroup)); assert (Util::contains (map_, prvGroup));
return map_[prvGroup]; return map_[prvGroup];
} }
@ -537,7 +537,7 @@ LiftedWCNF::createClause (LiteralId lid) const
LiteralId LiteralId
LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range) LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
{ {
assert (util::contains (map_, prvGroup)); assert (Util::contains (map_, prvGroup));
return map_[prvGroup][range]; return map_[prvGroup][range];
} }
@ -550,7 +550,7 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
while (it != pfList.end()) { while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments(); const ProbFormulas& formulas = (*it)->arguments();
for (size_t i = 0; i < formulas.size(); i++) { for (size_t i = 0; i < formulas.size(); i++) {
if (util::contains (map_, formulas[i].group()) == false) { if (Util::contains (map_, formulas[i].group()) == false) {
ConstraintTree tempConstr = (*it)->constr()->projectedCopy( ConstraintTree tempConstr = (*it)->constr()->projectedCopy(
formulas[i].logVars()); formulas[i].logVars());
Clause* clause = new Clause (tempConstr); Clause* clause = new Clause (tempConstr);
@ -595,7 +595,7 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
// ¬θxi|u1,...,un v λu1 -> tempClause // ¬θxi|u1,...,un v λu1 -> tempClause
// ¬θxi|u1,...,un v λu2 -> tempClause // ¬θxi|u1,...,un v λu2 -> tempClause
double posWeight = (**it)[indexer]; double posWeight = (**it)[indexer];
addWeight (paramVarLid, posWeight, log_aware::one()); addWeight (paramVarLid, posWeight, LogAware::one());
Clause* clause1 = new Clause (*(*it)->constr()); Clause* clause1 = new Clause (*(*it)->constr());
@ -634,7 +634,7 @@ LiftedWCNF::printFormulaIndicators (void) const
while (it != pfList_.end()) { while (it != pfList_.end()) {
const ProbFormulas& formulas = (*it)->arguments(); const ProbFormulas& formulas = (*it)->arguments();
for (size_t i = 0; i < formulas.size(); i++) { for (size_t i = 0; i < formulas.size(); i++) {
if (util::contains (allGroups, formulas[i].group()) == false) { if (Util::contains (allGroups, formulas[i].group()) == false) {
allGroups.insert (formulas[i].group()); allGroups.insert (formulas[i].group());
std::cout << formulas[i] << " | " ; std::cout << formulas[i] << " | " ;
ConstraintTree tempCt = (*it)->constr()->projectedCopy ( ConstraintTree tempCt = (*it)->constr()->projectedCopy (

View File

@ -27,7 +27,7 @@ Parfactor::Parfactor (
ranges_.push_back (args_[i].range()); ranges_.push_back (args_[i].range());
const LogVars& lvs = args_[i].logVars(); const LogVars& lvs = args_[i].logVars();
for (size_t j = 0; j < lvs.size(); j++) { for (size_t j = 0; j < lvs.size(); j++) {
if (util::contains (logVars, lvs[j]) == false) { if (Util::contains (logVars, lvs[j]) == false) {
logVars.push_back (lvs[j]); logVars.push_back (lvs[j]);
} }
} }
@ -50,7 +50,7 @@ Parfactor::Parfactor (
} }
} }
} }
assert (params_.size() == util::sizeExpected (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -62,7 +62,7 @@ Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
ranges_ = g->ranges(); ranges_ = g->ranges();
distId_ = g->distId(); distId_ = g->distId();
constr_ = new ConstraintTree (g->logVars(), {tuple}); constr_ = new ConstraintTree (g->logVars(), {tuple});
assert (params_.size() == util::sizeExpected (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -74,7 +74,7 @@ Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
ranges_ = g->ranges(); ranges_ = g->ranges();
distId_ = g->distId(); distId_ = g->distId();
constr_ = constr; constr_ = constr;
assert (params_.size() == util::sizeExpected (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -86,7 +86,7 @@ Parfactor::Parfactor (const Parfactor& g)
ranges_ = g.ranges(); ranges_ = g.ranges();
distId_ = g.distId(); distId_ = g.distId();
constr_ = new ConstraintTree (*g.constr()); constr_ = new ConstraintTree (*g.constr());
assert (params_.size() == util::sizeExpected (ranges_)); assert (params_.size() == Util::sizeExpected (ranges_));
} }
@ -159,7 +159,7 @@ Parfactor::sumOutIndex (size_t fIdx)
std::vector<double> numAssigns = HistogramSet::getNumAssigns (N, R); std::vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
Indexer indexer (ranges_, fIdx); Indexer indexer (ranges_, fIdx);
while (indexer.valid()) { while (indexer.valid()) {
if (globals::logDomain) { if (Globals::logDomain) {
params_[indexer] += numAssigns[ indexer[fIdx] ]; params_[indexer] += numAssigns[ indexer[fIdx] ];
} else { } else {
params_[indexer] *= numAssigns[ indexer[fIdx] ]; params_[indexer] *= numAssigns[ indexer[fIdx] ];
@ -179,7 +179,7 @@ Parfactor::sumOutIndex (size_t fIdx)
constr_->remove (excl); constr_->remove (excl);
TFactor<ProbFormula>::sumOutIndex (fIdx); TFactor<ProbFormula>::sumOutIndex (fIdx);
log_aware::pow (params_, exp); LogAware::pow (params_, exp);
} }
@ -253,14 +253,14 @@ Parfactor::countConvert (LogVar X)
ranges_[fIdx] = H; ranges_[fIdx] = H;
MapIndexer mapIndexer (ranges_, fIdx); MapIndexer mapIndexer (ranges_, fIdx);
while (mapIndexer.valid()) { while (mapIndexer.valid()) {
double prod = log_aware::multIdenty(); double prod = LogAware::multIdenty();
size_t i = mapIndexer; size_t i = mapIndexer;
unsigned h = mapIndexer[fIdx]; unsigned h = mapIndexer[fIdx];
for (unsigned r = 0; r < R; r++) { for (unsigned r = 0; r < R; r++) {
if (globals::logDomain) { if (Globals::logDomain) {
prod += log_aware::pow (sumout[i][r], histograms[h][r]); prod += LogAware::pow (sumout[i][r], histograms[h][r]);
} else { } else {
prod *= log_aware::pow (sumout[i][r], histograms[h][r]); prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
} }
} }
params_.push_back (prod); params_.push_back (prod);
@ -390,7 +390,7 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
LogVarSet excl = exclusiveLogVars (fIdx); LogVarSet excl = exclusiveLogVars (fIdx);
assert (args_[fIdx].isCounting() == false); assert (args_[fIdx].isCounting() == false);
assert (constr_->isCountNormalized (excl)); assert (constr_->isCountNormalized (excl));
log_aware::pow (params_, constr_->getConditionalCount (excl)); LogAware::pow (params_, constr_->getConditionalCount (excl));
TFactor<ProbFormula>::absorveEvidence (formula, evidence); TFactor<ProbFormula>::absorveEvidence (formula, evidence);
constr_->remove (excl); constr_->remove (excl);
} }
@ -475,7 +475,7 @@ Parfactor::containsGrounds (const Grounds& grounds) const
} }
LogVars lvs = args_[idx].logVars(); LogVars lvs = args_[idx].logVars();
for (size_t j = 0; j < lvs.size(); j++) { for (size_t j = 0; j < lvs.size(); j++) {
if (util::contains (tupleLvs, lvs[j]) == false) { if (Util::contains (tupleLvs, lvs[j]) == false) {
tuple.push_back (grounds[i].args()[j]); tuple.push_back (grounds[i].args()[j]);
tupleLvs.push_back (lvs[j]); tupleLvs.push_back (lvs[j]);
} }
@ -613,10 +613,10 @@ Parfactor::print (bool printParams) const
cout << args_[i]; cout << args_[i];
} }
cout << endl; cout << endl;
if (args_[0].group() != util::maxUnsigned()) { if (args_[0].group() != Util::maxUnsigned()) {
std::vector<std::string> groups; std::vector<std::string> groups;
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
groups.push_back (std::string ("g") + util::toString (args_[i].group())); groups.push_back (std::string ("g") + Util::toString (args_[i].group()));
} }
cout << "Groups: " << groups << endl; cout << "Groups: " << groups << endl;
} }
@ -844,8 +844,8 @@ Parfactor::getAlignLogVars (Parfactor* g1, Parfactor* g2)
g1->range (i) == g2->range (j) && g1->range (i) == g2->range (j) &&
matchedI.contains (i) == false && matchedI.contains (i) == false &&
matchedJ.contains (j) == false) { matchedJ.contains (j) == false) {
util::addToVector (Xs_1, formulas1[i].logVars()); Util::addToVector (Xs_1, formulas1[i].logVars());
util::addToVector (Xs_2, formulas2[j].logVars()); Util::addToVector (Xs_2, formulas2[j].logVars());
matchedI.insert (i); matchedI.insert (i);
matchedJ.insert (j); matchedJ.insert (j);
} }
@ -869,8 +869,8 @@ Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
assert (g2->constr()->isCountNormalized (Y_2)); assert (g2->constr()->isCountNormalized (Y_2));
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1); unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2); unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
log_aware::pow (g1->params(), 1.0 / condCount2); LogAware::pow (g1->params(), 1.0 / condCount2);
log_aware::pow (g2->params(), 1.0 / condCount1); LogAware::pow (g2->params(), 1.0 / condCount1);
} }

View File

@ -244,13 +244,13 @@ ParfactorList::addToShatteredList (Parfactor* g)
shattRes = shatter (*pfIter, pf); shattRes = shatter (*pfIter, pf);
if (shattRes.first.empty() == false) { if (shattRes.first.empty() == false) {
pfIter = removeAndDelete (pfIter); pfIter = removeAndDelete (pfIter);
util::addToQueue (residuals, shattRes.first); Util::addToQueue (residuals, shattRes.first);
} else { } else {
++ pfIter; ++ pfIter;
} }
if (shattRes.second.empty() == false) { if (shattRes.second.empty() == false) {
delete pf; delete pf;
util::addToQueue (residuals, shattRes.second); Util::addToQueue (residuals, shattRes.second);
pfSplitted = true; pfSplitted = true;
break; break;
} }
@ -261,7 +261,7 @@ ParfactorList::addToShatteredList (Parfactor* g)
if (res.empty()) { if (res.empty()) {
addShattered (pf); addShattered (pf);
} else { } else {
util::addToQueue (residuals, res); Util::addToQueue (residuals, res);
} }
} }
} }
@ -329,7 +329,7 @@ ParfactorList::shatterAgainstMySelf (
size_t fIdx2) size_t fIdx2)
{ {
/* /*
util::printDashedLine(); Util::printDashedLine();
std::cout << "-> SHATTERING" << std::endl; std::cout << "-> SHATTERING" << std::endl;
g->print(); g->print();
std::cout << "-> ON: " << g->argument (fIdx1) << "|" ; std::cout << "-> ON: " << g->argument (fIdx1) << "|" ;
@ -338,7 +338,7 @@ ParfactorList::shatterAgainstMySelf (
std::cout << "-> ON: " << g->argument (fIdx2) << "|" ; std::cout << "-> ON: " << g->argument (fIdx2) << "|" ;
std::cout << g->constr()->tupleSet (g->argument (fIdx2).logVars()) std::cout << g->constr()->tupleSet (g->argument (fIdx2).logVars())
std::cout << std::endl; std::cout << std::endl;
util::printDashedLine(); Util::printDashedLine();
*/ */
ProbFormula& f1 = g->argument (fIdx1); ProbFormula& f1 = g->argument (fIdx1);
ProbFormula& f2 = g->argument (fIdx2); ProbFormula& f2 = g->argument (fIdx2);
@ -399,7 +399,7 @@ ParfactorList::shatterAgainstMySelf (
res.push_back (res1[i]); res.push_back (res1[i]);
} }
} else { } else {
util::addToVector (res, res2); Util::addToVector (res, res2);
for (size_t j = 0; j < res2.size(); j++) { for (size_t j = 0; j < res2.size(); j++) {
} }
if (res1[i] != g) { if (res1[i] != g) {
@ -448,7 +448,7 @@ ParfactorList::shatter (
ProbFormula& f1 = g1->argument (fIdx1); ProbFormula& f1 = g1->argument (fIdx1);
ProbFormula& f2 = g2->argument (fIdx2); ProbFormula& f2 = g2->argument (fIdx2);
/* /*
util::printDashedLine(); Util::printDashedLine();
std::cout << "-> SHATTERING" << std::endl; std::cout << "-> SHATTERING" << std::endl;
g1->print(); g1->print();
std::cout << "-> WITH" << std::endl; std::cout << "-> WITH" << std::endl;
@ -457,7 +457,7 @@ ParfactorList::shatter (
std::cout << g1->constr()->tupleSet (f1.logVars()) << std::endl; std::cout << g1->constr()->tupleSet (f1.logVars()) << std::endl;
std::cout << "-> ON: " << f2 << "|" ; std::cout << "-> ON: " << f2 << "|" ;
std::cout << g2->constr()->tupleSet (f2.logVars()) << std::endl; std::cout << g2->constr()->tupleSet (f2.logVars()) << std::endl;
util::printDashedLine(); Util::printDashedLine();
*/ */
if (f1.isAtom()) { if (f1.isAtom()) {
f2.setGroup (f1.group()); f2.setGroup (f1.group());

View File

@ -23,7 +23,7 @@ ProbFormula::sameSkeletonAs (const ProbFormula& f) const
bool bool
ProbFormula::contains (LogVar lv) const ProbFormula::contains (LogVar lv) const
{ {
return util::contains (logVars_, lv); return Util::contains (logVars_, lv);
} }
@ -39,7 +39,7 @@ ProbFormula::contains (LogVarSet s) const
size_t size_t
ProbFormula::indexOf (LogVar X) const ProbFormula::indexOf (LogVar X) const
{ {
return util::indexOf (logVars_, X); return Util::indexOf (logVars_, X);
} }

View File

@ -6,7 +6,7 @@
namespace horus { namespace horus {
namespace globals { namespace Globals {
bool logDomain = false; bool logDomain = false;
@ -20,7 +20,7 @@ GroundSolverType groundSolver = GroundSolverType::VE;
namespace util { namespace Util {
template <> std::string template <> std::string
toString (const bool& b) toString (const bool& b)
@ -203,25 +203,25 @@ setHorusFlag (std::string option, std::string value)
{ {
bool returnVal = true; bool returnVal = true;
if (option == "lifted_solver") { if (option == "lifted_solver") {
if (value == "lve") globals::liftedSolver = LiftedSolverType::LVE; if (value == "lve") Globals::liftedSolver = LiftedSolverType::LVE;
else if (value == "lbp") globals::liftedSolver = LiftedSolverType::LBP; else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::LBP;
else if (value == "lkc") globals::liftedSolver = LiftedSolverType::LKC; else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::LKC;
else returnVal = invalidValue (option, value); else returnVal = invalidValue (option, value);
} else if (option == "ground_solver" || option == "solver") { } else if (option == "ground_solver" || option == "solver") {
if (value == "hve") globals::groundSolver = GroundSolverType::VE; if (value == "hve") Globals::groundSolver = GroundSolverType::VE;
else if (value == "bp") globals::groundSolver = GroundSolverType::BP; else if (value == "bp") Globals::groundSolver = GroundSolverType::BP;
else if (value == "cbp") globals::groundSolver = GroundSolverType::CBP; else if (value == "cbp") Globals::groundSolver = GroundSolverType::CBP;
else returnVal = invalidValue (option, value); else returnVal = invalidValue (option, value);
} else if (option == "verbosity") { } else if (option == "verbosity") {
std::stringstream ss; std::stringstream ss;
ss << value; ss << value;
ss >> globals::verbosity; ss >> Globals::verbosity;
} else if (option == "use_logarithms") { } else if (option == "use_logarithms") {
if (value == "true") globals::logDomain = true; if (value == "true") Globals::logDomain = true;
else if (value == "false") globals::logDomain = false; else if (value == "false") Globals::logDomain = false;
else returnVal = invalidValue (option, value); else returnVal = invalidValue (option, value);
} else if (option == "hve_elim_heuristic") { } else if (option == "hve_elim_heuristic") {
@ -335,14 +335,14 @@ printDashedLine (std::ostream& os)
namespace log_aware { namespace LogAware {
void void
normalize (Params& v) normalize (Params& v)
{ {
if (globals::logDomain) { if (Globals::logDomain) {
double sum = std::accumulate (v.begin(), v.end(), double sum = std::accumulate (v.begin(), v.end(),
log_aware::addIdenty(), util::logSum); LogAware::addIdenty(), Util::logSum);
assert (sum != -std::numeric_limits<double>::infinity()); assert (sum != -std::numeric_limits<double>::infinity());
v -= sum; v -= sum;
} else { } else {
@ -359,7 +359,7 @@ getL1Distance (const Params& v1, const Params& v2)
{ {
assert (v1.size() == v2.size()); assert (v1.size() == v2.size());
double dist = 0.0; double dist = 0.0;
if (globals::logDomain) { if (Globals::logDomain) {
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0, dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
std::plus<double>(), func_obj::abs_diff_exp<double>()); std::plus<double>(), func_obj::abs_diff_exp<double>());
} else { } else {
@ -376,7 +376,7 @@ getMaxNorm (const Params& v1, const Params& v2)
{ {
assert (v1.size() == v2.size()); assert (v1.size() == v2.size());
double max = 0.0; double max = 0.0;
if (globals::logDomain) { if (Globals::logDomain) {
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0, max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
func_obj::max<double>(), func_obj::abs_diff_exp<double>()); func_obj::max<double>(), func_obj::abs_diff_exp<double>());
} else { } else {
@ -391,7 +391,7 @@ getMaxNorm (const Params& v1, const Params& v2)
double double
pow (double base, unsigned iexp) pow (double base, unsigned iexp)
{ {
return globals::logDomain return Globals::logDomain
? base * iexp ? base * iexp
: std::pow (base, iexp); : std::pow (base, iexp);
} }
@ -402,7 +402,7 @@ double
pow (double base, double exp) pow (double base, double exp)
{ {
// `expoent' should not be in log domain // `expoent' should not be in log domain
return globals::logDomain return Globals::logDomain
? base * exp ? base * exp
: std::pow (base, exp); : std::pow (base, exp);
} }
@ -415,7 +415,7 @@ pow (Params& v, unsigned iexp)
if (iexp == 1) { if (iexp == 1) {
return; return;
} }
globals::logDomain ? v *= iexp : v ^= (int)iexp; Globals::logDomain ? v *= iexp : v ^= (int)iexp;
} }
@ -424,10 +424,10 @@ void
pow (Params& v, double exp) pow (Params& v, double exp)
{ {
// `expoent' should not be in log domain // `expoent' should not be in log domain
globals::logDomain ? v *= exp : v ^= exp; Globals::logDomain ? v *= exp : v ^= exp;
} }
} // namespace log_aware } // namespace LogAware
} // namespace horus } // namespace horus

View File

@ -16,7 +16,6 @@
#include "Horus.h" #include "Horus.h"
namespace horus { namespace horus {
namespace { namespace {
@ -26,7 +25,7 @@ const double NEG_INF = -std::numeric_limits<double>::infinity();
} }
namespace util { namespace Util {
template <typename T> void template <typename T> void
addToVector (std::vector<T>&, const std::vector<T>&); addToVector (std::vector<T>&, const std::vector<T>&);
@ -88,7 +87,7 @@ unsigned nrDigits (int);
bool isInteger (const std::string&); bool isInteger (const std::string&);
std::string parametersToString ( std::string parametersToString (
const Params&, unsigned = constants::PRECISION); const Params&, unsigned = Constants::PRECISION);
std::vector<std::string> getStateLines (const Vars&); std::vector<std::string> getStateLines (const Vars&);
@ -107,7 +106,7 @@ void printDashedLine (std::ostream& os = std::cout);
template <typename T> void template <typename T> void
util::addToVector (std::vector<T>& v, const std::vector<T>& elements) Util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
{ {
v.insert (v.end(), elements.begin(), elements.end()); v.insert (v.end(), elements.begin(), elements.end());
} }
@ -115,7 +114,7 @@ util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
template <typename T> void template <typename T> void
util::addToSet (std::set<T>& s, const std::vector<T>& elements) Util::addToSet (std::set<T>& s, const std::vector<T>& elements)
{ {
s.insert (elements.begin(), elements.end()); s.insert (elements.begin(), elements.end());
} }
@ -123,7 +122,7 @@ util::addToSet (std::set<T>& s, const std::vector<T>& elements)
template <typename T> void template <typename T> void
util::addToQueue (std::queue<T>& q, const std::vector<T>& elements) Util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
{ {
for (size_t i = 0; i < elements.size(); i++) { for (size_t i = 0; i < elements.size(); i++) {
q.push (elements[i]); q.push (elements[i]);
@ -133,7 +132,7 @@ util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
template <typename T> bool template <typename T> bool
util::contains (const std::vector<T>& v, const T& e) Util::contains (const std::vector<T>& v, const T& e)
{ {
return std::find (v.begin(), v.end(), e) != v.end(); return std::find (v.begin(), v.end(), e) != v.end();
} }
@ -141,7 +140,7 @@ util::contains (const std::vector<T>& v, const T& e)
template <typename T> bool template <typename T> bool
util::contains (const std::set<T>& s, const T& e) Util::contains (const std::set<T>& s, const T& e)
{ {
return s.find (e) != s.end(); return s.find (e) != s.end();
} }
@ -149,7 +148,7 @@ util::contains (const std::set<T>& s, const T& e)
template <typename K, typename V> bool template <typename K, typename V> bool
util::contains (const std::unordered_map<K, V>& m, const K& k) Util::contains (const std::unordered_map<K, V>& m, const K& k)
{ {
return m.find (k) != m.end(); return m.find (k) != m.end();
} }
@ -157,7 +156,7 @@ util::contains (const std::unordered_map<K, V>& m, const K& k)
template <typename T> size_t template <typename T> size_t
util::indexOf (const std::vector<T>& v, const T& e) Util::indexOf (const std::vector<T>& v, const T& e)
{ {
return std::distance (v.begin(), return std::distance (v.begin(),
std::find (v.begin(), v.end(), e)); std::find (v.begin(), v.end(), e));
@ -166,7 +165,7 @@ util::indexOf (const std::vector<T>& v, const T& e)
template <class Operation> void template <class Operation> void
util::apply_n_times ( Util::apply_n_times (
Params& v1, Params& v1,
const Params& v2, const Params& v2,
unsigned repetitions, unsigned repetitions,
@ -188,7 +187,7 @@ util::apply_n_times (
template <typename T> void template <typename T> void
util::log (std::vector<T>& v) Util::log (std::vector<T>& v)
{ {
std::transform (v.begin(), v.end(), v.begin(), ::log); std::transform (v.begin(), v.end(), v.begin(), ::log);
} }
@ -196,7 +195,7 @@ util::log (std::vector<T>& v)
template <typename T> void template <typename T> void
util::exp (std::vector<T>& v) Util::exp (std::vector<T>& v)
{ {
std::transform (v.begin(), v.end(), v.begin(), ::exp); std::transform (v.begin(), v.end(), v.begin(), ::exp);
} }
@ -204,7 +203,7 @@ util::exp (std::vector<T>& v)
template <typename T> std::string template <typename T> std::string
util::elementsToString (const std::vector<T>& v, std::string sep) Util::elementsToString (const std::vector<T>& v, std::string sep)
{ {
std::stringstream ss; std::stringstream ss;
for (size_t i = 0; i < v.size(); i++) { for (size_t i = 0; i < v.size(); i++) {
@ -216,7 +215,7 @@ util::elementsToString (const std::vector<T>& v, std::string sep)
template <typename T> std::string template <typename T> std::string
util::toString (const T& t) Util::toString (const T& t)
{ {
std::stringstream ss; std::stringstream ss;
ss << t; ss << t;
@ -226,7 +225,7 @@ util::toString (const T& t)
inline double inline double
util::logSum (double x, double y) Util::logSum (double x, double y)
{ {
// std::log (std::exp (x) + std::exp (y)) can overflow! // std::log (std::exp (x) + std::exp (y)) can overflow!
assert (std::isnan (x) == false); assert (std::isnan (x) == false);
@ -259,23 +258,23 @@ util::logSum (double x, double y)
inline unsigned inline unsigned
util::maxUnsigned (void) Util::maxUnsigned (void)
{ {
return std::numeric_limits<unsigned>::max(); return std::numeric_limits<unsigned>::max();
} }
namespace log_aware { namespace LogAware {
inline double one() { return globals::logDomain ? 0.0 : 1.0; } inline double one() { return Globals::logDomain ? 0.0 : 1.0; }
inline double zero() { return globals::logDomain ? NEG_INF : 0.0; } inline double zero() { return Globals::logDomain ? NEG_INF : 0.0; }
inline double addIdenty() { return globals::logDomain ? NEG_INF : 0.0; } inline double addIdenty() { return Globals::logDomain ? NEG_INF : 0.0; }
inline double multIdenty() { return globals::logDomain ? 0.0 : 1.0; } inline double multIdenty() { return Globals::logDomain ? 0.0 : 1.0; }
inline double withEvidence() { return globals::logDomain ? 0.0 : 1.0; } inline double withEvidence() { return Globals::logDomain ? 0.0 : 1.0; }
inline double noEvidence() { return globals::logDomain ? NEG_INF : 0.0; } inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
inline double log (double v) { return globals::logDomain ? ::log (v) : v; } inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
inline double exp (double v) { return globals::logDomain ? ::exp (v) : v; } inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
void normalize (Params&); void normalize (Params&);
@ -291,7 +290,7 @@ void pow (Params&, unsigned);
void pow (Params&, double); void pow (Params&, double);
} // namespace log_aware } // namespace LogAware
@ -393,7 +392,7 @@ template <typename T> std::ostream&
operator<< (std::ostream& os, const std::vector<T>& v) operator<< (std::ostream& os, const std::vector<T>& v)
{ {
os << "[" ; os << "[" ;
os << util::elementsToString (v, ", "); os << Util::elementsToString (v, ", ");
os << "]" ; os << "]" ;
return os; return os;
} }

View File

@ -13,7 +13,7 @@ Var::Var (const Var* v)
varId_ = v->varId(); varId_ = v->varId();
range_ = v->range(); range_ = v->range();
evidence_ = v->getEvidence(); evidence_ = v->getEvidence();
index_ = util::maxUnsigned(); index_ = Util::maxUnsigned();
} }
@ -25,7 +25,7 @@ Var::Var (VarId varId, unsigned range, int evidence)
varId_ = varId; varId_ = varId;
range_ = range; range_ = range;
evidence_ = evidence; evidence_ = evidence;
index_ = util::maxUnsigned(); index_ = Util::maxUnsigned();
} }
@ -81,7 +81,7 @@ void
Var::addVarInfo ( Var::addVarInfo (
VarId vid, std::string label, const States& states) VarId vid, std::string label, const States& states)
{ {
assert (util::contains (varsInfo_, vid) == false); assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (std::make_pair (vid, VarInfo (label, states))); varsInfo_.insert (std::make_pair (vid, VarInfo (label, states)));
} }
@ -90,7 +90,7 @@ Var::addVarInfo (
VarInfo VarInfo
Var::getVarInfo (VarId vid) Var::getVarInfo (VarId vid)
{ {
assert (util::contains (varsInfo_, vid)); assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second; return varsInfo_.find (vid)->second;
} }

View File

@ -26,7 +26,7 @@ class Var
public: public:
Var (const Var*); Var (const Var*);
Var (VarId, unsigned, int = constants::NO_EVIDENCE); Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
virtual ~Var (void) { }; virtual ~Var (void) { };
@ -79,7 +79,7 @@ class Var
inline bool inline bool
Var::hasEvidence (void) const Var::hasEvidence (void) const
{ {
return evidence_ != constants::NO_EVIDENCE; return evidence_ != Constants::NO_EVIDENCE;
} }

View File

@ -13,7 +13,7 @@ namespace horus {
Params Params
VarElim::solveQuery (VarIds queryVids) VarElim::solveQuery (VarIds queryVids)
{ {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << "Solving query on " ; std::cout << "Solving query on " ;
for (size_t i = 0; i < queryVids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {
if (i != 0) std::cout << ", " ; if (i != 0) std::cout << ", " ;
@ -28,8 +28,8 @@ VarElim::solveQuery (VarIds queryVids)
createFactorList(); createFactorList();
absorveEvidence(); absorveEvidence();
Params params = processFactorList (queryVids); Params params = processFactorList (queryVids);
if (globals::logDomain) { if (Globals::logDomain) {
util::exp (params); Util::exp (params);
} }
return params; return params;
} }
@ -49,7 +49,7 @@ VarElim::printSolverFlags (void) const
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break; case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break; case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
} }
ss << ",log_domain=" << util::toString (globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
} }
@ -81,15 +81,15 @@ VarElim::createFactorList (void)
void void
VarElim::absorveEvidence (void) VarElim::absorveEvidence (void)
{ {
if (globals::verbosity > 2) { if (Globals::verbosity > 2) {
util::printDashedLine(); Util::printDashedLine();
std::cout << "(initial factor list)" << std::endl; std::cout << "(initial factor list)" << std::endl;
printActiveFactors(); printActiveFactors();
} }
const VarNodes& varNodes = fg.varNodes(); const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) { for (size_t i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) { if (varNodes[i]->hasEvidence()) {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << "-> aborving evidence on "; std::cout << "-> aborving evidence on ";
std::cout << varNodes[i]->label() << " = " ; std::cout << varNodes[i]->label() << " = " ;
std::cout << varNodes[i]->getEvidence() << std::endl; std::cout << varNodes[i]->getEvidence() << std::endl;
@ -117,9 +117,9 @@ VarElim::processFactorList (const VarIds& queryVids)
VarIds elimOrder = ElimGraph::getEliminationOrder ( VarIds elimOrder = ElimGraph::getEliminationOrder (
factorList_, queryVids); factorList_, queryVids);
for (size_t i = 0; i < elimOrder.size(); i++) { for (size_t i = 0; i < elimOrder.size(); i++) {
if (globals::verbosity >= 2) { if (Globals::verbosity >= 2) {
if (globals::verbosity >= 3) { if (Globals::verbosity >= 3) {
util::printDashedLine(); Util::printDashedLine();
printActiveFactors(); printActiveFactors();
} }
std::cout << "-> summing out " ; std::cout << "-> summing out " ;
@ -146,7 +146,7 @@ VarElim::processFactorList (const VarIds& queryVids)
result.reorderArguments (unobservedVids); result.reorderArguments (unobservedVids);
result.normalize(); result.normalize();
if (globals::verbosity > 0) { if (Globals::verbosity > 0) {
std::cout << "total factor size: " << totalFactorSize_ << std::endl; std::cout << "total factor size: " << totalFactorSize_ << std::endl;
std::cout << "largest factor size: " << largestFactorSize_ << std::endl; std::cout << "largest factor size: " << largestFactorSize_ << std::endl;
std::cout << std::endl; std::cout << std::endl;

View File

@ -26,24 +26,24 @@ WeightedBp::getPosterioriOf (VarId vid)
assert (var); assert (var);
Params probs; Params probs;
if (var->hasEvidence()) { if (var->hasEvidence()) {
probs.resize (var->range(), log_aware::noEvidence()); probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = log_aware::withEvidence(); probs[var->getEvidence()] = LogAware::withEvidence();
} else { } else {
probs.resize (var->range(), log_aware::multIdenty()); probs.resize (var->range(), LogAware::multIdenty());
const BpLinks& links = ninf(var)->getLinks(); const BpLinks& links = ninf(var)->getLinks();
if (globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
probs += l->powMessage(); probs += l->powMessage();
} }
log_aware::normalize (probs); LogAware::normalize (probs);
util::exp (probs); Util::exp (probs);
} else { } else {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
probs *= l->powMessage(); probs *= l->powMessage();
} }
log_aware::normalize (probs); LogAware::normalize (probs);
} }
} }
return probs; return probs;
@ -56,7 +56,7 @@ WeightedBp::createLinks (void)
{ {
using std::cout; using std::cout;
using std::endl; using std::endl;
if (globals::verbosity > 0) { if (Globals::verbosity > 0) {
cout << "compressed factor graph contains " ; cout << "compressed factor graph contains " ;
cout << fg.nrVarNodes() << " variables and " ; cout << fg.nrVarNodes() << " variables and " ;
cout << fg.nrFacNodes() << " factors " << endl; cout << fg.nrFacNodes() << " factors " << endl;
@ -66,7 +66,7 @@ WeightedBp::createLinks (void)
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
const VarNodes& neighs = facNodes[i]->neighbors(); const VarNodes& neighs = facNodes[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) { for (size_t j = 0; j < neighs.size(); j++) {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << "creating link " ; cout << "creating link " ;
cout << facNodes[i]->getLabel(); cout << facNodes[i]->getLabel();
cout << " -- " ; cout << " -- " ;
@ -77,7 +77,7 @@ WeightedBp::createLinks (void)
facNodes[i], neighs[j], j, weights_[i][j])); facNodes[i], neighs[j], j, weights_[i][j]));
} }
} }
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << endl; cout << endl;
} }
} }
@ -92,7 +92,7 @@ WeightedBp::maxResidualSchedule (void)
calculateMessage (links_[i]); calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]); SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it)); linkMap_.insert (make_pair (links_[i], it));
if (globals::verbosity >= 1) { if (Globals::verbosity >= 1) {
std::cout << "calculating " << links_[i]->toString() << std::endl; std::cout << "calculating " << links_[i]->toString() << std::endl;
} }
} }
@ -100,7 +100,7 @@ WeightedBp::maxResidualSchedule (void)
} }
for (size_t c = 0; c < links_.size(); c++) { for (size_t c = 0; c < links_.size(); c++) {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << std::endl << "current residuals:" << std::endl; std::cout << std::endl << "current residuals:" << std::endl;
for (SortedOrder::iterator it = sortedOrder_.begin(); for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) { it != sortedOrder_.end(); ++it) {
@ -112,7 +112,7 @@ WeightedBp::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin(); SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it; BpLink* link = *it;
if (globals::verbosity >= 1) { if (Globals::verbosity >= 1) {
std::cout << "updating " << (*sortedOrder_.begin())->toString(); std::cout << "updating " << (*sortedOrder_.begin())->toString();
std::cout << std::endl; std::cout << std::endl;
} }
@ -130,7 +130,7 @@ WeightedBp::maxResidualSchedule (void)
const BpLinks& links = ninf(factorNeighbors[i])->getLinks(); const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
for (size_t j = 0; j < links.size(); j++) { for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) { if (links[j]->varNode() != link->varNode()) {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << " calculating " << links[j]->toString(); std::cout << " calculating " << links[j]->toString();
std::cout << std::endl; std::cout << std::endl;
} }
@ -146,7 +146,7 @@ WeightedBp::maxResidualSchedule (void)
const BpLinks& links = ninf(link->facNode())->getLinks(); const BpLinks& links = ninf(link->facNode())->getLinks();
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
if (links[i]->varNode() != link->varNode()) { if (links[i]->varNode() != link->varNode()) {
if (globals::verbosity > 1) { if (Globals::verbosity > 1) {
std::cout << " calculating " << links[i]->toString(); std::cout << " calculating " << links[i]->toString();
std::cout << std::endl; std::cout << std::endl;
} }
@ -171,19 +171,19 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
// calculate the product of messages that were sent // calculate the product of messages that were sent
// to factor `src', except from var `dst' // to factor `src', except from var `dst'
unsigned reps = 1; unsigned reps = 1;
unsigned msgSize = util::sizeExpected (src->factor().ranges()); unsigned msgSize = Util::sizeExpected (src->factor().ranges());
Params msgProduct (msgSize, log_aware::multIdenty()); Params msgProduct (msgSize, LogAware::multIdenty());
if (globals::logDomain) { if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]); const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (l->varNode() == dst && l->index() == link->index())) { if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ; std::cout << ": " ;
} }
util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>()); reps, std::plus<double>());
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
@ -193,13 +193,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]); const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (l->varNode() == dst && l->index() == link->index())) { if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " message from " << links[i]->varNode()->label(); std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ; std::cout << ": " ;
} }
util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]), Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>()); reps, std::multiplies<double>());
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << std::endl; std::cout << std::endl;
} }
} }
@ -209,12 +209,12 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
Factor result (src->factor().arguments(), Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct); src->factor().ranges(), msgProduct);
assert (msgProduct.size() == src->factor().size()); assert (msgProduct.size() == src->factor().size());
if (globals::logDomain) { if (Globals::logDomain) {
result.params() += src->factor().params(); result.params() += src->factor().params();
} else { } else {
result.params() *= src->factor().params(); result.params() *= src->factor().params();
} }
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " message product: " ; std::cout << " message product: " ;
std::cout << msgProduct << std::endl; std::cout << msgProduct << std::endl;
std::cout << " original factor: " ; std::cout << " original factor: " ;
@ -223,13 +223,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
std::cout << result.params() << std::endl; std::cout << result.params() << std::endl;
} }
result.sumOutAllExceptIndex (link->index()); result.sumOutAllExceptIndex (link->index());
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " marginalized: " ; std::cout << " marginalized: " ;
std::cout << result.params() << std::endl; std::cout << result.params() << std::endl;
} }
link->nextMessage() = result.params(); link->nextMessage() = result.params();
log_aware::normalize (link->nextMessage()); LogAware::normalize (link->nextMessage());
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " curr msg: " ; std::cout << " curr msg: " ;
std::cout << link->message() << std::endl; std::cout << link->message() << std::endl;
std::cout << " next msg: " ; std::cout << " next msg: " ;
@ -247,22 +247,22 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
const FacNode* dst = link->facNode(); const FacNode* dst = link->facNode();
Params msg; Params msg;
if (src->hasEvidence()) { if (src->hasEvidence()) {
msg.resize (src->range(), log_aware::noEvidence()); msg.resize (src->range(), LogAware::noEvidence());
double value = link->message()[src->getEvidence()]; double value = link->message()[src->getEvidence()];
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
msg[src->getEvidence()] = value; msg[src->getEvidence()] = value;
std::cout << msg << "^" << link->weight() << "-1" ; std::cout << msg << "^" << link->weight() << "-1" ;
} }
msg[src->getEvidence()] = log_aware::pow (value, link->weight() - 1); msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
} else { } else {
msg = link->message(); msg = link->message();
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << msg << "^" << link->weight() << "-1" ; std::cout << msg << "^" << link->weight() << "-1" ;
} }
log_aware::pow (msg, link->weight() - 1); LogAware::pow (msg, link->weight() - 1);
} }
const BpLinks& links = ninf(src)->getLinks(); const BpLinks& links = ninf(src)->getLinks();
if (globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
if ( ! (l->facNode() == dst && l->index() == link->index())) { if ( ! (l->facNode() == dst && l->index() == link->index())) {
@ -274,13 +274,13 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
WeightedLink* l = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
if ( ! (l->facNode() == dst && l->index() == link->index())) { if ( ! (l->facNode() == dst && l->index() == link->index())) {
msg *= l->powMessage(); msg *= l->powMessage();
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " x " << l->nextMessage() << "^" << link->weight(); std::cout << " x " << l->nextMessage() << "^" << link->weight();
} }
} }
} }
} }
if (constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
std::cout << " = " << msg; std::cout << " = " << msg;
} }
return msg; return msg;

View File

@ -13,7 +13,7 @@ class WeightedLink : public BpLink
public: public:
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight) WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
: BpLink (fn, vn), index_(idx), weight_(weight), : BpLink (fn, vn), index_(idx), weight_(weight),
pwdMsg_(vn->range(), log_aware::one()) { } pwdMsg_(vn->range(), LogAware::one()) { }
size_t index (void) const { return index_; } size_t index (void) const { return index_; }
@ -38,7 +38,7 @@ WeightedLink::updateMessage (void)
{ {
pwdMsg_ = *nextMsg_; pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_); swap (currMsg_, nextMsg_);
log_aware::pow (pwdMsg_, weight_); LogAware::pow (pwdMsg_, weight_);
} }