Improve namespace names
This commit is contained in:
parent
6f83ceb6f5
commit
973df43fe0
@ -79,8 +79,8 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
|||||||
} else if (n->hasEvidence() && n->isVisited()) {
|
} else if (n->hasEvidence() && n->isVisited()) {
|
||||||
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||||
Ranges ranges = { facNodes[i]->factor().range (0) };
|
Ranges ranges = { facNodes[i]->factor().range (0) };
|
||||||
Params params (ranges[0], LogAware::noEvidence());
|
Params params (ranges[0], log_aware::noEvidence());
|
||||||
params[n->getEvidence()] = LogAware::withEvidence();
|
params[n->getEvidence()] = log_aware::withEvidence();
|
||||||
fg->addFactor (Factor (varIds, ranges, params));
|
fg->addFactor (Factor (varIds, ranges, params));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ namespace horus {
|
|||||||
void
|
void
|
||||||
BayesBallGraph::addNode (BBNode* n)
|
BayesBallGraph::addNode (BBNode* n)
|
||||||
{
|
{
|
||||||
assert (Util::contains (varMap_, n->varId()) == false);
|
assert (util::contains (varMap_, n->varId()) == false);
|
||||||
nodes_.push_back (n);
|
nodes_.push_back (n);
|
||||||
varMap_[n->varId()] = n;
|
varMap_[n->varId()] = n;
|
||||||
}
|
}
|
||||||
|
@ -15,8 +15,8 @@ BpLink::BpLink (FacNode* fn, VarNode* vn)
|
|||||||
{
|
{
|
||||||
fac_ = fn;
|
fac_ = fn;
|
||||||
var_ = vn;
|
var_ = vn;
|
||||||
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
v1_.resize (vn->range(), log_aware::log (1.0 / vn->range()));
|
||||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
v2_.resize (vn->range(), log_aware::log (1.0 / vn->range()));
|
||||||
currMsg_ = &v1_;
|
currMsg_ = &v1_;
|
||||||
nextMsg_ = &v2_;
|
nextMsg_ = &v2_;
|
||||||
residual_ = 0.0;
|
residual_ = 0.0;
|
||||||
@ -35,7 +35,7 @@ BpLink::clearResidual (void)
|
|||||||
void
|
void
|
||||||
BpLink::updateResidual (void)
|
BpLink::updateResidual (void)
|
||||||
{
|
{
|
||||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
residual_ = log_aware::getMaxNorm (v1_, v2_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -111,9 +111,9 @@ BeliefProp::printSolverFlags (void) const
|
|||||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
}
|
}
|
||||||
ss << ",bp_max_iter=" << Util::toString (maxIter_);
|
ss << ",bp_max_iter=" << util::toString (maxIter_);
|
||||||
ss << ",bp_accuracy=" << Util::toString (accuracy_);
|
ss << ",bp_accuracy=" << util::toString (accuracy_);
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << util::toString (globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
std::cout << ss.str() << std::endl;
|
std::cout << ss.str() << std::endl;
|
||||||
}
|
}
|
||||||
@ -130,22 +130,22 @@ BeliefProp::getPosterioriOf (VarId vid)
|
|||||||
VarNode* var = fg.getVarNode (vid);
|
VarNode* var = fg.getVarNode (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
probs.resize (var->range(), log_aware::noEvidence());
|
||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
probs[var->getEvidence()] = log_aware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->range(), LogAware::multIdenty());
|
probs.resize (var->range(), log_aware::multIdenty());
|
||||||
const BpLinks& links = ninf(var)->getLinks();
|
const BpLinks& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
probs += links[i]->message();
|
probs += links[i]->message();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
log_aware::normalize (probs);
|
||||||
Util::exp (probs);
|
util::exp (probs);
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
probs *= links[i]->message();
|
probs *= links[i]->message();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
log_aware::normalize (probs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return probs;
|
return probs;
|
||||||
@ -196,8 +196,8 @@ BeliefProp::getFactorJoint (
|
|||||||
res.reorderArguments (jointVarIds);
|
res.reorderArguments (jointVarIds);
|
||||||
res.normalize();
|
res.normalize();
|
||||||
Params jointDist = res.params();
|
Params jointDist = res.params();
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
Util::exp (jointDist);
|
util::exp (jointDist);
|
||||||
}
|
}
|
||||||
return jointDist;
|
return jointDist;
|
||||||
}
|
}
|
||||||
@ -207,7 +207,7 @@ BeliefProp::getFactorJoint (
|
|||||||
void
|
void
|
||||||
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
|
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
std::cout << "calculating & updating " << link->toString();
|
std::cout << "calculating & updating " << link->toString();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@ -223,7 +223,7 @@ BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
|
|||||||
void
|
void
|
||||||
BeliefProp::calculateMessage (BpLink* link, bool calcResidual)
|
BeliefProp::calculateMessage (BpLink* link, bool calcResidual)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
std::cout << "calculating " << link->toString();
|
std::cout << "calculating " << link->toString();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@ -239,7 +239,7 @@ void
|
|||||||
BeliefProp::updateMessage (BpLink* link)
|
BeliefProp::updateMessage (BpLink* link)
|
||||||
{
|
{
|
||||||
link->updateMessage();
|
link->updateMessage();
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
std::cout << "updating " << link->toString();
|
std::cout << "updating " << link->toString();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@ -254,9 +254,9 @@ BeliefProp::runSolver (void)
|
|||||||
nIters_ = 0;
|
nIters_ = 0;
|
||||||
while (!converged() && nIters_ < maxIter_) {
|
while (!converged() && nIters_ < maxIter_) {
|
||||||
nIters_ ++;
|
nIters_ ++;
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
Util::printHeader (std::string ("Iteration ")
|
util::printHeader (std::string ("Iteration ")
|
||||||
+ Util::toString (nIters_));
|
+ util::toString (nIters_));
|
||||||
}
|
}
|
||||||
switch (schedule_) {
|
switch (schedule_) {
|
||||||
case MsgSchedule::SEQ_RANDOM:
|
case MsgSchedule::SEQ_RANDOM:
|
||||||
@ -280,7 +280,7 @@ BeliefProp::runSolver (void)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 0) {
|
if (globals::verbosity > 0) {
|
||||||
if (nIters_ < maxIter_) {
|
if (nIters_ < maxIter_) {
|
||||||
std::cout << "Belief propagation converged in " ;
|
std::cout << "Belief propagation converged in " ;
|
||||||
std::cout << nIters_ << " iterations" << std::endl;
|
std::cout << nIters_ << " iterations" << std::endl;
|
||||||
@ -322,7 +322,7 @@ BeliefProp::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (size_t c = 0; c < links_.size(); c++) {
|
for (size_t c = 0; c < links_.size(); c++) {
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << "current residuals:" << std::endl;
|
std::cout << "current residuals:" << std::endl;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); ++it) {
|
it != sortedOrder_.end(); ++it) {
|
||||||
@ -358,8 +358,8 @@ BeliefProp::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
Util::printDashedLine();
|
util::printDashedLine();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -375,18 +375,18 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
|||||||
// calculate the product of messages that were sent
|
// calculate the product of messages that were sent
|
||||||
// to factor `src', except from var `dst'
|
// to factor `src', except from var `dst'
|
||||||
unsigned reps = 1;
|
unsigned reps = 1;
|
||||||
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
|
unsigned msgSize = util::sizeExpected (src->factor().ranges());
|
||||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
Params msgProduct (msgSize, log_aware::multIdenty());
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
if (links[i]->varNode() != dst) {
|
if (links[i]->varNode() != dst) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " message from " << links[i]->varNode()->label();
|
std::cout << " message from " << links[i]->varNode()->label();
|
||||||
std::cout << ": " ;
|
std::cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
reps, std::plus<double>());
|
reps, std::plus<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -395,13 +395,13 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
|||||||
} else {
|
} else {
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
if (links[i]->varNode() != dst) {
|
if (links[i]->varNode() != dst) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " message from " << links[i]->varNode()->label();
|
std::cout << " message from " << links[i]->varNode()->label();
|
||||||
std::cout << ": " ;
|
std::cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
reps, std::multiplies<double>());
|
reps, std::multiplies<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -411,19 +411,19 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
|||||||
Factor result (src->factor().arguments(),
|
Factor result (src->factor().arguments(),
|
||||||
src->factor().ranges(), msgProduct);
|
src->factor().ranges(), msgProduct);
|
||||||
result.multiply (src->factor());
|
result.multiply (src->factor());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " message product: " << msgProduct << std::endl;
|
std::cout << " message product: " << msgProduct << std::endl;
|
||||||
std::cout << " original factor: " << src->factor().params();
|
std::cout << " original factor: " << src->factor().params();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
std::cout << " factor product: " << result.params() << std::endl;
|
std::cout << " factor product: " << result.params() << std::endl;
|
||||||
}
|
}
|
||||||
result.sumOutAllExcept (dst->varId());
|
result.sumOutAllExcept (dst->varId());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " marginalized: " << result.params() << std::endl;
|
std::cout << " marginalized: " << result.params() << std::endl;
|
||||||
}
|
}
|
||||||
link->nextMessage() = result.params();
|
link->nextMessage() = result.params();
|
||||||
LogAware::normalize (link->nextMessage());
|
log_aware::normalize (link->nextMessage());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " curr msg: " << link->message() << std::endl;
|
std::cout << " curr msg: " << link->message() << std::endl;
|
||||||
std::cout << " next msg: " << link->nextMessage() << std::endl;
|
std::cout << " next msg: " << link->nextMessage() << std::endl;
|
||||||
}
|
}
|
||||||
@ -437,22 +437,22 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
|||||||
const VarNode* src = link->varNode();
|
const VarNode* src = link->varNode();
|
||||||
Params msg;
|
Params msg;
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
msg.resize (src->range(), log_aware::noEvidence());
|
||||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
msg[src->getEvidence()] = log_aware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
msg.resize (src->range(), LogAware::one());
|
msg.resize (src->range(), log_aware::one());
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << msg;
|
std::cout << msg;
|
||||||
}
|
}
|
||||||
BpLinks::const_iterator it;
|
BpLinks::const_iterator it;
|
||||||
const BpLinks& links = ninf (src)->getLinks();
|
const BpLinks& links = ninf (src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
for (it = links.begin(); it != links.end(); ++it) {
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
if (*it != link) {
|
if (*it != link) {
|
||||||
msg += (*it)->message();
|
msg += (*it)->message();
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " x " << (*it)->message();
|
std::cout << " x " << (*it)->message();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -461,12 +461,12 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
|||||||
if (*it != link) {
|
if (*it != link) {
|
||||||
msg *= (*it)->message();
|
msg *= (*it)->message();
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " x " << (*it)->message();
|
std::cout << " x " << (*it)->message();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " = " << msg;
|
std::cout << " = " << msg;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
@ -516,11 +516,11 @@ BeliefProp::converged (void)
|
|||||||
if (nIters_ == 0) {
|
if (nIters_ == 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
if (nIters_ == 1) {
|
if (nIters_ == 1) {
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << "no residuals" << std::endl << std::endl;
|
std::cout << "no residuals" << std::endl << std::endl;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -536,18 +536,18 @@ BeliefProp::converged (void)
|
|||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
double residual = links_[i]->residual();
|
double residual = links_[i]->residual();
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << links_[i]->toString() + " residual = " << residual;
|
std::cout << links_[i]->toString() + " residual = " << residual;
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
if (residual > accuracy_) {
|
if (residual > accuracy_) {
|
||||||
converged = false;
|
converged = false;
|
||||||
if (Globals::verbosity < 2) {
|
if (globals::verbosity < 2) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -278,7 +278,7 @@ void
|
|||||||
ConstraintTree::moveToTop (const LogVars& lvs)
|
ConstraintTree::moveToTop (const LogVars& lvs)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < lvs.size(); i++) {
|
for (size_t i = 0; i < lvs.size(); i++) {
|
||||||
size_t pos = Util::indexOf (logVars_, lvs[i]);
|
size_t pos = util::indexOf (logVars_, lvs[i]);
|
||||||
assert (pos != logVars_.size());
|
assert (pos != logVars_.size());
|
||||||
for (size_t j = pos; j-- > i; ) {
|
for (size_t j = pos; j-- > i; ) {
|
||||||
swapLogVar (logVars_[j]);
|
swapLogVar (logVars_[j]);
|
||||||
@ -292,7 +292,7 @@ void
|
|||||||
ConstraintTree::moveToBottom (const LogVars& lvs)
|
ConstraintTree::moveToBottom (const LogVars& lvs)
|
||||||
{
|
{
|
||||||
for (size_t i = lvs.size(); i-- > 0; ) {
|
for (size_t i = lvs.size(); i-- > 0; ) {
|
||||||
size_t pos = Util::indexOf (logVars_, lvs[i]);
|
size_t pos = util::indexOf (logVars_, lvs[i]);
|
||||||
assert (pos != logVars_.size());
|
assert (pos != logVars_.size());
|
||||||
size_t stop = logVars_.size() - (lvs.size() - i - 1);
|
size_t stop = logVars_.size() - (lvs.size() - i - 1);
|
||||||
for (size_t j = pos; j < stop - 1; j++) {
|
for (size_t j = pos; j < stop - 1; j++) {
|
||||||
@ -329,7 +329,7 @@ ConstraintTree::join (ConstraintTree* ct, bool oneTwoOne)
|
|||||||
if (intersect.empty()) {
|
if (intersect.empty()) {
|
||||||
// cartesian product
|
// cartesian product
|
||||||
appendOnBottom (root_, ct->root()->childs());
|
appendOnBottom (root_, ct->root()->childs());
|
||||||
Util::addToVector (logVars_, ct->logVars_);
|
util::addToVector (logVars_, ct->logVars_);
|
||||||
logVarSet_ |= ct->logVarSet_;
|
logVarSet_ |= ct->logVarSet_;
|
||||||
} else {
|
} else {
|
||||||
moveToTop (intersect.elements());
|
moveToTop (intersect.elements());
|
||||||
@ -350,7 +350,7 @@ ConstraintTree::join (ConstraintTree* ct, bool oneTwoOne)
|
|||||||
|
|
||||||
LogVars newLvs (ct->logVars().begin() + intersect.size(),
|
LogVars newLvs (ct->logVars().begin() + intersect.size(),
|
||||||
ct->logVars().end());
|
ct->logVars().end());
|
||||||
Util::addToVector (logVars_, newLvs);
|
util::addToVector (logVars_, newLvs);
|
||||||
logVarSet_ |= LogVarSet (newLvs);
|
logVarSet_ |= LogVarSet (newLvs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -360,7 +360,7 @@ ConstraintTree::join (ConstraintTree* ct, bool oneTwoOne)
|
|||||||
unsigned
|
unsigned
|
||||||
ConstraintTree::getLevel (LogVar X) const
|
ConstraintTree::getLevel (LogVar X) const
|
||||||
{
|
{
|
||||||
unsigned level = Util::indexOf (logVars_, X);
|
unsigned level = util::indexOf (logVars_, X);
|
||||||
level += 1; // root is in level 0, first logVar is in level 1
|
level += 1; // root is in level 0, first logVar is in level 1
|
||||||
return level;
|
return level;
|
||||||
}
|
}
|
||||||
@ -496,7 +496,7 @@ ConstraintTree::tupleSet (const LogVars& originalLvs)
|
|||||||
{
|
{
|
||||||
LogVars uniqueLvs;
|
LogVars uniqueLvs;
|
||||||
for (size_t i = 0; i < originalLvs.size(); i++) {
|
for (size_t i = 0; i < originalLvs.size(); i++) {
|
||||||
if (Util::contains (uniqueLvs, originalLvs[i]) == false) {
|
if (util::contains (uniqueLvs, originalLvs[i]) == false) {
|
||||||
uniqueLvs.push_back (originalLvs[i]);
|
uniqueLvs.push_back (originalLvs[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -510,7 +510,7 @@ ConstraintTree::tupleSet (const LogVars& originalLvs)
|
|||||||
std::vector<size_t> indexes;
|
std::vector<size_t> indexes;
|
||||||
indexes.reserve (originalLvs.size());
|
indexes.reserve (originalLvs.size());
|
||||||
for (size_t i = 0; i < originalLvs.size(); i++) {
|
for (size_t i = 0; i < originalLvs.size(); i++) {
|
||||||
indexes.push_back (Util::indexOf (uniqueLvs, originalLvs[i]));
|
indexes.push_back (util::indexOf (uniqueLvs, originalLvs[i]));
|
||||||
}
|
}
|
||||||
Tuples tuples2;
|
Tuples tuples2;
|
||||||
tuples2.reserve (tuples.size());
|
tuples2.reserve (tuples.size());
|
||||||
@ -1030,7 +1030,7 @@ ConstraintTree::appendOnBottom (CTNode* n, const CTChilds& childs)
|
|||||||
void
|
void
|
||||||
ConstraintTree::swapLogVar (LogVar X)
|
ConstraintTree::swapLogVar (LogVar X)
|
||||||
{
|
{
|
||||||
size_t pos = Util::indexOf (logVars_, X);
|
size_t pos = util::indexOf (logVars_, X);
|
||||||
assert (pos != logVars_.size());
|
assert (pos != logVars_.size());
|
||||||
const CTNodes& nodes = getNodesAtLevel (pos);
|
const CTNodes& nodes = getNodesAtLevel (pos);
|
||||||
for (CTNodes::const_iterator nodeIt = nodes.begin();
|
for (CTNodes::const_iterator nodeIt = nodes.begin();
|
||||||
|
@ -52,8 +52,8 @@ CountingBp::printSolverFlags (void) const
|
|||||||
}
|
}
|
||||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << util::toString (globals::logDomain);
|
||||||
ss << ",fif=" << Util::toString (CountingBp::fif_);
|
ss << ",fif=" << util::toString (CountingBp::fif_);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
std::cout << ss.str() << std::endl;
|
std::cout << ss.str() << std::endl;
|
||||||
}
|
}
|
||||||
@ -103,18 +103,18 @@ CountingBp::findIdenticalFactors()
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
facNodes[i]->factor().setDistId (util::maxUnsigned());
|
||||||
}
|
}
|
||||||
unsigned groupCount = 1;
|
unsigned groupCount = 1;
|
||||||
for (size_t i = 0; i < facNodes.size() - 1; i++) {
|
for (size_t i = 0; i < facNodes.size() - 1; i++) {
|
||||||
Factor& f1 = facNodes[i]->factor();
|
Factor& f1 = facNodes[i]->factor();
|
||||||
if (f1.distId() != Util::maxUnsigned()) {
|
if (f1.distId() != util::maxUnsigned()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
f1.setDistId (groupCount);
|
f1.setDistId (groupCount);
|
||||||
for (size_t j = i + 1; j < facNodes.size(); j++) {
|
for (size_t j = i + 1; j < facNodes.size(); j++) {
|
||||||
Factor& f2 = facNodes[j]->factor();
|
Factor& f2 = facNodes[j]->factor();
|
||||||
if (f2.distId() != Util::maxUnsigned()) {
|
if (f2.distId() != util::maxUnsigned()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (f1.size() == f2.size() &&
|
if (f1.size() == f2.size() &&
|
||||||
@ -303,7 +303,7 @@ CountingBp::getSignature (const FacNode* facNode)
|
|||||||
VarId
|
VarId
|
||||||
CountingBp::getRepresentative (VarId vid)
|
CountingBp::getRepresentative (VarId vid)
|
||||||
{
|
{
|
||||||
assert (Util::contains (varClusterMap_, vid));
|
assert (util::contains (varClusterMap_, vid));
|
||||||
VarCluster* vc = varClusterMap_.find (vid)->second;
|
VarCluster* vc = varClusterMap_.find (vid)->second;
|
||||||
return vc->representative()->varId();
|
return vc->representative()->varId();
|
||||||
}
|
}
|
||||||
@ -314,7 +314,7 @@ FacNode*
|
|||||||
CountingBp::getRepresentative (FacNode* fn)
|
CountingBp::getRepresentative (FacNode* fn)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
if (Util::contains (facClusters_[i]->members(), fn)) {
|
if (util::contains (facClusters_[i]->members(), fn)) {
|
||||||
return facClusters_[i]->representative();
|
return facClusters_[i]->representative();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
|
|||||||
VarIds elimOrder;
|
VarIds elimOrder;
|
||||||
unmarked_.reserve (nodes_.size());
|
unmarked_.reserve (nodes_.size());
|
||||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
if (Util::contains (excludedVids, nodes_[i]->varId()) == false) {
|
if (util::contains (excludedVids, nodes_[i]->varId()) == false) {
|
||||||
unmarked_.insert (nodes_[i]);
|
unmarked_.insert (nodes_[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -142,7 +142,7 @@ ElimGraph::getEliminationOrder (
|
|||||||
Factors::const_iterator first = factors.begin();
|
Factors::const_iterator first = factors.begin();
|
||||||
Factors::const_iterator end = factors.end();
|
Factors::const_iterator end = factors.end();
|
||||||
for (; first != end; ++first) {
|
for (; first != end; ++first) {
|
||||||
Util::addToVector (allVids, (*first)->arguments());
|
util::addToVector (allVids, (*first)->arguments());
|
||||||
}
|
}
|
||||||
TinySet<VarId> elimOrder (allVids);
|
TinySet<VarId> elimOrder (allVids);
|
||||||
elimOrder -= TinySet<VarId> (excludedVids);
|
elimOrder -= TinySet<VarId> (excludedVids);
|
||||||
@ -178,7 +178,7 @@ EgNode*
|
|||||||
ElimGraph::getLowestCostNode (void) const
|
ElimGraph::getLowestCostNode (void) const
|
||||||
{
|
{
|
||||||
EgNode* bestNode = 0;
|
EgNode* bestNode = 0;
|
||||||
unsigned minCost = Util::maxUnsigned();
|
unsigned minCost = util::maxUnsigned();
|
||||||
EGNeighs::const_iterator it;
|
EGNeighs::const_iterator it;
|
||||||
switch (elimHeuristic_) {
|
switch (elimHeuristic_) {
|
||||||
case MIN_NEIGHBORS: {
|
case MIN_NEIGHBORS: {
|
||||||
|
@ -27,7 +27,7 @@ Factor::Factor (
|
|||||||
ranges_ = ranges;
|
ranges_ = ranges;
|
||||||
params_ = params;
|
params_ = params;
|
||||||
distId_ = distId;
|
distId_ = distId;
|
||||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
assert (params_.size() == util::sizeExpected (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ Factor::Factor (
|
|||||||
}
|
}
|
||||||
params_ = params;
|
params_ = params;
|
||||||
distId_ = distId;
|
distId_ = distId;
|
||||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
assert (params_.size() == util::sizeExpected (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ Factor::print (void) const
|
|||||||
for (size_t i = 0; i < args_.size(); i++) {
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
vars.push_back (new Var (args_[i], ranges_[i]));
|
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||||
}
|
}
|
||||||
std::vector<std::string> jointStrings = Util::getStateLines (vars);
|
std::vector<std::string> jointStrings = util::getStateLines (vars);
|
||||||
for (size_t i = 0; i < params_.size(); i++) {
|
for (size_t i = 0; i < params_.size(); i++) {
|
||||||
// cout << "[" << distId_ << "] " ;
|
// cout << "[" << distId_ << "] " ;
|
||||||
std::cout << "f(" << jointStrings[i] << ")" ;
|
std::cout << "f(" << jointStrings[i] << ")" ;
|
||||||
@ -149,11 +149,11 @@ void
|
|||||||
Factor::sumOutFirstVariable (void)
|
Factor::sumOutFirstVariable (void)
|
||||||
{
|
{
|
||||||
size_t sep = params_.size() / 2;
|
size_t sep = params_.size() / 2;
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
std::transform (
|
std::transform (
|
||||||
params_.begin(), params_.begin() + sep,
|
params_.begin(), params_.begin() + sep,
|
||||||
params_.begin() + sep, params_.begin(),
|
params_.begin() + sep, params_.begin(),
|
||||||
Util::logSum);
|
util::logSum);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
std::transform (
|
std::transform (
|
||||||
@ -174,10 +174,10 @@ Factor::sumOutLastVariable (void)
|
|||||||
Params::iterator first1 = params_.begin();
|
Params::iterator first1 = params_.begin();
|
||||||
Params::iterator first2 = params_.begin();
|
Params::iterator first2 = params_.begin();
|
||||||
Params::iterator last = params_.end();
|
Params::iterator last = params_.end();
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
while (first2 != last) {
|
while (first2 != last) {
|
||||||
// the arguments can be swaped, but that is ok
|
// the arguments can be swaped, but that is ok
|
||||||
*first1++ = Util::logSum (*first2++, *first2++);
|
*first1++ = util::logSum (*first2++, *first2++);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
while (first2 != last) {
|
while (first2 != last) {
|
||||||
@ -206,13 +206,13 @@ Factor::sumOutArgs (const std::vector<bool>& mask)
|
|||||||
ranges_.push_back (ranges_[i]);
|
ranges_.push_back (ranges_[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Params newps (new_size, LogAware::addIdenty());
|
Params newps (new_size, log_aware::addIdenty());
|
||||||
Params::const_iterator first = params_.begin();
|
Params::const_iterator first = params_.begin();
|
||||||
Params::const_iterator last = params_.end();
|
Params::const_iterator last = params_.end();
|
||||||
MapIndexer indexer (oldRanges, mask);
|
MapIndexer indexer (oldRanges, mask);
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
while (first != last) {
|
while (first != last) {
|
||||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
newps[indexer] = util::logSum (newps[indexer], *first++);
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -34,7 +34,7 @@ class TFactor
|
|||||||
|
|
||||||
void setDistId (unsigned id) { distId_ = id; }
|
void setDistId (unsigned id) { distId_ = id; }
|
||||||
|
|
||||||
void normalize (void) { LogAware::normalize (params_); }
|
void normalize (void) { log_aware::normalize (params_); }
|
||||||
|
|
||||||
void randomize (void);
|
void randomize (void);
|
||||||
|
|
||||||
@ -91,7 +91,7 @@ template <typename T> inline void
|
|||||||
TFactor<T>::setParams (const Params& newParams)
|
TFactor<T>::setParams (const Params& newParams)
|
||||||
{
|
{
|
||||||
params_ = newParams;
|
params_ = newParams;
|
||||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
assert (params_.size() == util::sizeExpected (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -99,7 +99,7 @@ TFactor<T>::setParams (const Params& newParams)
|
|||||||
template <typename T> inline size_t
|
template <typename T> inline size_t
|
||||||
TFactor<T>::indexOf (const T& t) const
|
TFactor<T>::indexOf (const T& t) const
|
||||||
{
|
{
|
||||||
return Util::indexOf (args_, t);
|
return util::indexOf (args_, t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -136,7 +136,7 @@ TFactor<T>::multiply (TFactor<T>& g)
|
|||||||
{
|
{
|
||||||
if (args_ == g.arguments()) {
|
if (args_ == g.arguments()) {
|
||||||
// optimization
|
// optimization
|
||||||
Globals::logDomain
|
globals::logDomain
|
||||||
? params_ += g.params()
|
? params_ += g.params()
|
||||||
: params_ *= g.params();
|
: params_ *= g.params();
|
||||||
return;
|
return;
|
||||||
@ -163,7 +163,7 @@ TFactor<T>::multiply (TFactor<T>& g)
|
|||||||
extend (range_prod);
|
extend (range_prod);
|
||||||
Params::iterator it = params_.begin();
|
Params::iterator it = params_.begin();
|
||||||
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
|
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
for (; indexer.valid(); ++it, ++indexer) {
|
for (; indexer.valid(); ++it, ++indexer) {
|
||||||
*it += g_params[indexer];
|
*it += g_params[indexer];
|
||||||
}
|
}
|
||||||
@ -183,13 +183,13 @@ TFactor<T>::sumOutIndex (size_t idx)
|
|||||||
assert (idx < args_.size());
|
assert (idx < args_.size());
|
||||||
assert (args_.size() > 1);
|
assert (args_.size() > 1);
|
||||||
size_t new_size = params_.size() / ranges_[idx];
|
size_t new_size = params_.size() / ranges_[idx];
|
||||||
Params newps (new_size, LogAware::addIdenty());
|
Params newps (new_size, log_aware::addIdenty());
|
||||||
Params::const_iterator first = params_.begin();
|
Params::const_iterator first = params_.begin();
|
||||||
Params::const_iterator last = params_.end();
|
Params::const_iterator last = params_.end();
|
||||||
MapIndexer indexer (ranges_, idx);
|
MapIndexer indexer (ranges_, idx);
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
for (; first != last; ++indexer) {
|
for (; first != last; ++indexer) {
|
||||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
newps[indexer] = util::logSum (newps[indexer], *first++);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (; first != last; ++indexer) {
|
for (; first != last; ++indexer) {
|
||||||
@ -255,7 +255,7 @@ TFactor<T>::reorderArguments (const std::vector<T>& new_args)
|
|||||||
template <typename T> inline bool
|
template <typename T> inline bool
|
||||||
TFactor<T>::contains (const T& arg) const
|
TFactor<T>::contains (const T& arg) const
|
||||||
{
|
{
|
||||||
return Util::contains (args_, arg);
|
return util::contains (args_, arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -310,7 +310,7 @@ TFactor<T>::cartesianProduct (
|
|||||||
Params::const_iterator first1 = backup.begin();
|
Params::const_iterator first1 = backup.begin();
|
||||||
Params::const_iterator last1 = backup.end();
|
Params::const_iterator last1 = backup.end();
|
||||||
Params::const_iterator tmp;
|
Params::const_iterator tmp;
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
for (; first1 != last1; ++first1) {
|
for (; first1 != last1; ++first1) {
|
||||||
for (tmp = first2; tmp != last2; ++tmp) {
|
for (tmp = first2; tmp != last2; ++tmp) {
|
||||||
params_.push_back ((*first1) + (*tmp));
|
params_.push_back ((*first1) + (*tmp));
|
||||||
@ -335,10 +335,10 @@ class Factor : public TFactor<VarId>
|
|||||||
Factor (const Factor&);
|
Factor (const Factor&);
|
||||||
|
|
||||||
Factor (const VarIds&, const Ranges&, const Params&,
|
Factor (const VarIds&, const Ranges&, const Params&,
|
||||||
unsigned = Util::maxUnsigned());
|
unsigned = util::maxUnsigned());
|
||||||
|
|
||||||
Factor (const Vars&, const Params&,
|
Factor (const Vars&, const Params&,
|
||||||
unsigned = Util::maxUnsigned());
|
unsigned = util::maxUnsigned());
|
||||||
|
|
||||||
void sumOut (VarId);
|
void sumOut (VarId);
|
||||||
|
|
||||||
|
@ -106,9 +106,9 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
|||||||
for (unsigned i = 0; i < nrFactors; i++) {
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
ignoreLines (is);
|
ignoreLines (is);
|
||||||
is >> nrParams;
|
is >> nrParams;
|
||||||
if (nrParams != Util::sizeExpected (allRanges[i])) {
|
if (nrParams != util::sizeExpected (allRanges[i])) {
|
||||||
std::cerr << "Error: invalid number of parameters for factor nº " << i ;
|
std::cerr << "Error: invalid number of parameters for factor nº " << i ;
|
||||||
std::cerr << ", " << Util::sizeExpected (allRanges[i]);
|
std::cerr << ", " << util::sizeExpected (allRanges[i]);
|
||||||
std::cerr << " expected, " << nrParams << " given." << std::endl;
|
std::cerr << " expected, " << nrParams << " given." << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
@ -116,8 +116,8 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
|||||||
for (unsigned j = 0; j < nrParams; j++) {
|
for (unsigned j = 0; j < nrParams; j++) {
|
||||||
is >> params[j];
|
is >> params[j];
|
||||||
}
|
}
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
Util::log (params);
|
util::log (params);
|
||||||
}
|
}
|
||||||
Factor f (allVarIds[i], allRanges[i], params);
|
Factor f (allVarIds[i], allRanges[i], params);
|
||||||
if (bayesFactors_ && allVarIds[i].size() > 1) {
|
if (bayesFactors_ && allVarIds[i].size() > 1) {
|
||||||
@ -171,7 +171,7 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
|||||||
ignoreLines (is);
|
ignoreLines (is);
|
||||||
unsigned nNonzeros;
|
unsigned nNonzeros;
|
||||||
is >> nNonzeros;
|
is >> nNonzeros;
|
||||||
Params params (Util::sizeExpected (ranges), 0);
|
Params params (util::sizeExpected (ranges), 0);
|
||||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||||
ignoreLines (is);
|
ignoreLines (is);
|
||||||
unsigned index;
|
unsigned index;
|
||||||
@ -181,8 +181,8 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
|||||||
is >> val;
|
is >> val;
|
||||||
params[index] = val;
|
params[index] = val;
|
||||||
}
|
}
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
Util::log (params);
|
util::log (params);
|
||||||
}
|
}
|
||||||
std::reverse (vids.begin(), vids.end());
|
std::reverse (vids.begin(), vids.end());
|
||||||
Factor f (vids, ranges, params);
|
Factor f (vids, ranges, params);
|
||||||
@ -306,13 +306,13 @@ FactorGraph::exportToLibDai (const char* fileName) const
|
|||||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
Factor f (facNodes_[i]->factor());
|
Factor f (facNodes_[i]->factor());
|
||||||
out << f.nrArguments() << std::endl;
|
out << f.nrArguments() << std::endl;
|
||||||
out << Util::elementsToString (f.arguments()) << std::endl;
|
out << util::elementsToString (f.arguments()) << std::endl;
|
||||||
out << Util::elementsToString (f.ranges()) << std::endl;
|
out << util::elementsToString (f.ranges()) << std::endl;
|
||||||
VarIds args = f.arguments();
|
VarIds args = f.arguments();
|
||||||
std::reverse (args.begin(), args.end());
|
std::reverse (args.begin(), args.end());
|
||||||
f.reorderArguments (args);
|
f.reorderArguments (args);
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
Util::exp (f.params());
|
util::exp (f.params());
|
||||||
}
|
}
|
||||||
out << f.size() << std::endl;
|
out << f.size() << std::endl;
|
||||||
for (size_t j = 0; j < f.size(); j++) {
|
for (size_t j = 0; j < f.size(); j++) {
|
||||||
@ -347,7 +347,7 @@ FactorGraph::exportToUai (const char* fileName) const
|
|||||||
if (bayesFactors_) {
|
if (bayesFactors_) {
|
||||||
std::swap (args.front(), args.back());
|
std::swap (args.front(), args.back());
|
||||||
}
|
}
|
||||||
out << args.size() << " " << Util::elementsToString (args);
|
out << args.size() << " " << util::elementsToString (args);
|
||||||
out << std::endl;
|
out << std::endl;
|
||||||
}
|
}
|
||||||
out << std::endl;
|
out << std::endl;
|
||||||
@ -359,11 +359,11 @@ FactorGraph::exportToUai (const char* fileName) const
|
|||||||
f.reorderArguments (args);
|
f.reorderArguments (args);
|
||||||
}
|
}
|
||||||
Params params = f.params();
|
Params params = f.params();
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
Util::exp (params);
|
util::exp (params);
|
||||||
}
|
}
|
||||||
out << params.size() << std::endl << " " ;
|
out << params.size() << std::endl << " " ;
|
||||||
out << Util::elementsToString (params);
|
out << util::elementsToString (params);
|
||||||
out << std::endl << std::endl;
|
out << std::endl << std::endl;
|
||||||
}
|
}
|
||||||
out.close();
|
out.close();
|
||||||
|
@ -20,7 +20,7 @@ class VarNode : public Var
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarNode (VarId varId, unsigned nrStates,
|
VarNode (VarId varId, unsigned nrStates,
|
||||||
int evidence = Constants::NO_EVIDENCE)
|
int evidence = constants::NO_EVIDENCE)
|
||||||
: Var (varId, nrStates, evidence) { }
|
: Var (varId, nrStates, evidence) { }
|
||||||
|
|
||||||
VarNode (const Var* v) : Var (v) { }
|
VarNode (const Var* v) : Var (v) { }
|
||||||
|
@ -28,10 +28,10 @@ GroundSolver::printAnswer (const VarIds& vids)
|
|||||||
if (unobservedVids.empty() == false) {
|
if (unobservedVids.empty() == false) {
|
||||||
Params res = solveQuery (unobservedVids);
|
Params res = solveQuery (unobservedVids);
|
||||||
std::vector<std::string> stateLines =
|
std::vector<std::string> stateLines =
|
||||||
Util::getStateLines (unobservedVars);
|
util::getStateLines (unobservedVars);
|
||||||
for (size_t i = 0; i < res.size(); i++) {
|
for (size_t i = 0; i < res.size(); i++) {
|
||||||
std::cout << "P(" << stateLines[i] << ") = " ;
|
std::cout << "P(" << stateLines[i] << ") = " ;
|
||||||
std::cout << std::setprecision (Constants::PRECISION) << res[i];
|
std::cout << std::setprecision (constants::PRECISION) << res[i];
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
|
@ -80,7 +80,7 @@ HistogramSet::getHistograms (unsigned N, unsigned R)
|
|||||||
unsigned
|
unsigned
|
||||||
HistogramSet::nrHistograms (unsigned N, unsigned R)
|
HistogramSet::nrHistograms (unsigned N, unsigned R)
|
||||||
{
|
{
|
||||||
return Util::nrCombinations (N + R - 1, R - 1);
|
return util::nrCombinations (N + R - 1, R - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -102,17 +102,17 @@ std::vector<double>
|
|||||||
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||||
{
|
{
|
||||||
HistogramSet hs (N, R);
|
HistogramSet hs (N, R);
|
||||||
double N_fac = Util::logFactorial (N);
|
double N_fac = util::logFactorial (N);
|
||||||
unsigned H = hs.nrHistograms();
|
unsigned H = hs.nrHistograms();
|
||||||
std::vector<double> numAssigns;
|
std::vector<double> numAssigns;
|
||||||
numAssigns.reserve (H);
|
numAssigns.reserve (H);
|
||||||
for (unsigned h = 0; h < H; h++) {
|
for (unsigned h = 0; h < H; h++) {
|
||||||
double prod = 0.0;
|
double prod = 0.0;
|
||||||
for (unsigned r = 0; r < R; r++) {
|
for (unsigned r = 0; r < R; r++) {
|
||||||
prod += Util::logFactorial (hs[r]);
|
prod += util::logFactorial (hs[r]);
|
||||||
}
|
}
|
||||||
double res = N_fac - prod;
|
double res = N_fac - prod;
|
||||||
numAssigns.push_back (Globals::logDomain ? res : std::exp (res));
|
numAssigns.push_back (globals::logDomain ? res : std::exp (res));
|
||||||
hs.nextHistogram();
|
hs.nextHistogram();
|
||||||
}
|
}
|
||||||
return numAssigns;
|
return numAssigns;
|
||||||
|
@ -50,7 +50,7 @@ enum GroundSolverType
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
namespace globals {
|
||||||
|
|
||||||
extern bool logDomain;
|
extern bool logDomain;
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ extern GroundSolverType groundSolver;
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
namespace Constants {
|
namespace constants {
|
||||||
|
|
||||||
// show message calculation for belief propagation
|
// show message calculation for belief propagation
|
||||||
const bool SHOW_BP_CALCS = false;
|
const bool SHOW_BP_CALCS = false;
|
||||||
@ -73,7 +73,7 @@ const int NO_EVIDENCE = -1;
|
|||||||
// number of digits to show when printing a parameter
|
// number of digits to show when printing a parameter
|
||||||
const unsigned PRECISION = 6;
|
const unsigned PRECISION = 6;
|
||||||
|
|
||||||
}
|
} // namespace constants
|
||||||
|
|
||||||
} // namespace horus
|
} // namespace horus
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ main (int argc, const char* argv[])
|
|||||||
if (horus::FactorGraph::printFactorGraph()) {
|
if (horus::FactorGraph::printFactorGraph()) {
|
||||||
fg.print();
|
fg.print();
|
||||||
}
|
}
|
||||||
if (horus::Globals::verbosity > 0) {
|
if (horus::globals::verbosity > 0) {
|
||||||
std::cout << "factor graph contains " ;
|
std::cout << "factor graph contains " ;
|
||||||
std::cout << fg.nrVarNodes() << " variables and " ;
|
std::cout << fg.nrVarNodes() << " variables and " ;
|
||||||
std::cout << fg.nrFacNodes() << " factors " << std::endl;
|
std::cout << fg.nrFacNodes() << " factors " << std::endl;
|
||||||
@ -80,7 +80,7 @@ readHorusFlags (int argc, const char* argv[])
|
|||||||
std::cerr << USAGE << std::endl;
|
std::cerr << USAGE << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
horus::Util::setHorusFlag (leftArg, rightArg);
|
horus::util::setHorusFlag (leftArg, rightArg);
|
||||||
}
|
}
|
||||||
return i + 1;
|
return i + 1;
|
||||||
}
|
}
|
||||||
@ -116,13 +116,13 @@ readQueryAndEvidence (
|
|||||||
for (int i = start; i < argc; i++) {
|
for (int i = start; i < argc; i++) {
|
||||||
const std::string& arg = argv[i];
|
const std::string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
if (horus::Util::isInteger (arg) == false) {
|
if (horus::util::isInteger (arg) == false) {
|
||||||
std::cerr << "Error: `" << arg << "' " ;
|
std::cerr << "Error: `" << arg << "' " ;
|
||||||
std::cerr << "is not a variable id." ;
|
std::cerr << "is not a variable id." ;
|
||||||
std::cerr << std::endl;
|
std::cerr << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
horus::VarId vid = horus::Util::stringToUnsigned (arg);
|
horus::VarId vid = horus::util::stringToUnsigned (arg);
|
||||||
horus::VarNode* queryVar = fg.getVarNode (vid);
|
horus::VarNode* queryVar = fg.getVarNode (vid);
|
||||||
if (queryVar == false) {
|
if (queryVar == false) {
|
||||||
std::cerr << "Error: unknow variable with id " ;
|
std::cerr << "Error: unknow variable with id " ;
|
||||||
@ -139,12 +139,12 @@ readQueryAndEvidence (
|
|||||||
std::cerr << USAGE << std::endl;
|
std::cerr << USAGE << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
if (horus::Util::isInteger (leftArg) == false) {
|
if (horus::util::isInteger (leftArg) == false) {
|
||||||
std::cerr << "Error: `" << leftArg << "' " ;
|
std::cerr << "Error: `" << leftArg << "' " ;
|
||||||
std::cerr << "is not a variable id." << std::endl;
|
std::cerr << "is not a variable id." << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
horus::VarId vid = horus::Util::stringToUnsigned (leftArg);
|
horus::VarId vid = horus::util::stringToUnsigned (leftArg);
|
||||||
horus::VarNode* observedVar = fg.getVarNode (vid);
|
horus::VarNode* observedVar = fg.getVarNode (vid);
|
||||||
if (observedVar == false) {
|
if (observedVar == false) {
|
||||||
std::cerr << "Error: unknow variable with id " ;
|
std::cerr << "Error: unknow variable with id " ;
|
||||||
@ -156,12 +156,12 @@ readQueryAndEvidence (
|
|||||||
std::cerr << USAGE << std::endl;
|
std::cerr << USAGE << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
if (horus::Util::isInteger (rightArg) == false) {
|
if (horus::util::isInteger (rightArg) == false) {
|
||||||
std::cerr << "Error: `" << rightArg << "' " ;
|
std::cerr << "Error: `" << rightArg << "' " ;
|
||||||
std::cerr << "is not a state index." << std::endl;
|
std::cerr << "is not a state index." << std::endl;
|
||||||
exit (EXIT_FAILURE);
|
exit (EXIT_FAILURE);
|
||||||
}
|
}
|
||||||
unsigned stateIdx = horus::Util::stringToUnsigned (rightArg);
|
unsigned stateIdx = horus::util::stringToUnsigned (rightArg);
|
||||||
if (observedVar->isValidState (stateIdx) == false) {
|
if (observedVar->isValidState (stateIdx) == false) {
|
||||||
std::cerr << "Error: `" << stateIdx << "' " ;
|
std::cerr << "Error: `" << stateIdx << "' " ;
|
||||||
std::cerr << "is not a valid state index for variable with id " ;
|
std::cerr << "is not a valid state index for variable with id " ;
|
||||||
@ -182,7 +182,7 @@ runSolver (
|
|||||||
const horus::VarIds& queryIds)
|
const horus::VarIds& queryIds)
|
||||||
{
|
{
|
||||||
horus::GroundSolver* solver = 0;
|
horus::GroundSolver* solver = 0;
|
||||||
switch (horus::Globals::groundSolver) {
|
switch (horus::globals::groundSolver) {
|
||||||
case horus::GroundSolverType::VE:
|
case horus::GroundSolverType::VE:
|
||||||
solver = new horus::VarElim (fg);
|
solver = new horus::VarElim (fg);
|
||||||
break;
|
break;
|
||||||
@ -195,7 +195,7 @@ runSolver (
|
|||||||
default:
|
default:
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
if (horus::Globals::verbosity > 0) {
|
if (horus::globals::verbosity > 0) {
|
||||||
solver->printSolverFlags();
|
solver->printSolverFlags();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
@ -48,8 +48,8 @@ createLiftedNetwork (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LiftedUtils::printSymbolDictionary();
|
// LiftedUtils::printSymbolDictionary();
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
Util::printHeader ("INITIAL PARFACTORS");
|
util::printHeader ("INITIAL PARFACTORS");
|
||||||
for (size_t i = 0; i < parfactors.size(); i++) {
|
for (size_t i = 0; i < parfactors.size(); i++) {
|
||||||
parfactors[i]->print();
|
parfactors[i]->print();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
@ -58,8 +58,8 @@ createLiftedNetwork (void)
|
|||||||
|
|
||||||
ParfactorList* pfList = new ParfactorList (parfactors);
|
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||||
|
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
Util::printHeader ("SHATTERED PARFACTORS");
|
util::printHeader ("SHATTERED PARFACTORS");
|
||||||
pfList->print();
|
pfList->print();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,7 +120,7 @@ createGroundNetwork (void)
|
|||||||
if (FactorGraph::printFactorGraph()) {
|
if (FactorGraph::printFactorGraph()) {
|
||||||
fg->print();
|
fg->print();
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 0) {
|
if (globals::verbosity > 0) {
|
||||||
std::cout << "factor graph contains " ;
|
std::cout << "factor graph contains " ;
|
||||||
std::cout << fg->nrVarNodes() << " variables and " ;
|
std::cout << fg->nrVarNodes() << " variables and " ;
|
||||||
std::cout << fg->nrFacNodes() << " factors " << std::endl;
|
std::cout << fg->nrFacNodes() << " factors " << std::endl;
|
||||||
@ -139,13 +139,13 @@ runLiftedSolver (void)
|
|||||||
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
|
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
|
||||||
|
|
||||||
LiftedSolver* solver = 0;
|
LiftedSolver* solver = 0;
|
||||||
switch (Globals::liftedSolver) {
|
switch (globals::liftedSolver) {
|
||||||
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
|
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
|
||||||
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
|
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
|
||||||
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
|
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Globals::verbosity > 0) {
|
if (globals::verbosity > 0) {
|
||||||
solver->printSolverFlags();
|
solver->printSolverFlags();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@ -205,7 +205,7 @@ runGroundSolver (void)
|
|||||||
if (fg->bayesianFactors()) {
|
if (fg->bayesianFactors()) {
|
||||||
std::set<VarId> vids;
|
std::set<VarId> vids;
|
||||||
for (size_t i = 0; i < tasks.size(); i++) {
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
Util::addToSet (vids, tasks[i]);
|
util::addToSet (vids, tasks[i]);
|
||||||
}
|
}
|
||||||
mfg = BayesBall::getMinimalFactorGraph (
|
mfg = BayesBall::getMinimalFactorGraph (
|
||||||
*fg, VarIds (vids.begin(), vids.end()));
|
*fg, VarIds (vids.begin(), vids.end()));
|
||||||
@ -213,13 +213,13 @@ runGroundSolver (void)
|
|||||||
|
|
||||||
GroundSolver* solver = 0;
|
GroundSolver* solver = 0;
|
||||||
CountingBp::setFindIdenticalFactorsFlag (false);
|
CountingBp::setFindIdenticalFactorsFlag (false);
|
||||||
switch (Globals::groundSolver) {
|
switch (globals::groundSolver) {
|
||||||
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
|
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
|
||||||
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
|
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
|
||||||
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
|
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Globals::verbosity > 0) {
|
if (globals::verbosity > 0) {
|
||||||
solver->printSolverFlags();
|
solver->printSolverFlags();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@ -251,14 +251,14 @@ setParfactorsParams (void)
|
|||||||
while (distIdsList != YAP_TermNil()) {
|
while (distIdsList != YAP_TermNil()) {
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (
|
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||||
YAP_HeadOfTerm (distIdsList));
|
YAP_HeadOfTerm (distIdsList));
|
||||||
assert (Util::contains (paramsMap, distId) == false);
|
assert (util::contains (paramsMap, distId) == false);
|
||||||
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
|
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
|
||||||
distIdsList = YAP_TailOfTerm (distIdsList);
|
distIdsList = YAP_TailOfTerm (distIdsList);
|
||||||
paramsList = YAP_TailOfTerm (paramsList);
|
paramsList = YAP_TailOfTerm (paramsList);
|
||||||
}
|
}
|
||||||
ParfactorList::iterator it = pfList->begin();
|
ParfactorList::iterator it = pfList->begin();
|
||||||
while (it != pfList->end()) {
|
while (it != pfList->end()) {
|
||||||
assert (Util::contains (paramsMap, (*it)->distId()));
|
assert (util::contains (paramsMap, (*it)->distId()));
|
||||||
(*it)->setParams (paramsMap[(*it)->distId()]);
|
(*it)->setParams (paramsMap[(*it)->distId()]);
|
||||||
++ it;
|
++ it;
|
||||||
}
|
}
|
||||||
@ -277,7 +277,7 @@ setFactorsParams (void)
|
|||||||
while (distIdsList != YAP_TermNil()) {
|
while (distIdsList != YAP_TermNil()) {
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (
|
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||||
YAP_HeadOfTerm (distIdsList));
|
YAP_HeadOfTerm (distIdsList));
|
||||||
assert (Util::contains (paramsMap, distId) == false);
|
assert (util::contains (paramsMap, distId) == false);
|
||||||
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
|
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
|
||||||
distIdsList = YAP_TailOfTerm (distIdsList);
|
distIdsList = YAP_TailOfTerm (distIdsList);
|
||||||
paramsList = YAP_TailOfTerm (paramsList);
|
paramsList = YAP_TailOfTerm (paramsList);
|
||||||
@ -285,7 +285,7 @@ setFactorsParams (void)
|
|||||||
const FacNodes& facNodes = fg->facNodes();
|
const FacNodes& facNodes = fg->facNodes();
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
unsigned distId = facNodes[i]->factor().distId();
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
assert (Util::contains (paramsMap, distId));
|
assert (util::contains (paramsMap, distId));
|
||||||
facNodes[i]->factor().setParams (paramsMap[distId]);
|
facNodes[i]->factor().setParams (paramsMap[distId]);
|
||||||
}
|
}
|
||||||
return TRUE;
|
return TRUE;
|
||||||
@ -343,7 +343,7 @@ setHorusFlag (void)
|
|||||||
} else {
|
} else {
|
||||||
value = ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
value = ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||||
}
|
}
|
||||||
return Util::setHorusFlag (option, value);
|
return util::setHorusFlag (option, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -519,8 +519,8 @@ readParameters (YAP_Term paramL)
|
|||||||
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||||
paramL = YAP_TailOfTerm (paramL);
|
paramL = YAP_TailOfTerm (paramL);
|
||||||
}
|
}
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
Util::log (params);
|
util::log (params);
|
||||||
}
|
}
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ class Indexer
|
|||||||
inline
|
inline
|
||||||
Indexer::Indexer (const Ranges& ranges, bool calcOffsets)
|
Indexer::Indexer (const Ranges& ranges, bool calcOffsets)
|
||||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||||
size_(Util::sizeExpected (ranges))
|
size_(util::sizeExpected (ranges))
|
||||||
{
|
{
|
||||||
if (calcOffsets) {
|
if (calcOffsets) {
|
||||||
calculateOffsets();
|
calculateOffsets();
|
||||||
@ -290,7 +290,7 @@ MapIndexer::MapIndexer (
|
|||||||
}
|
}
|
||||||
offsets_.reserve (allArgs.size());
|
offsets_.reserve (allArgs.size());
|
||||||
for (size_t i = 0; i < allArgs.size(); i++) {
|
for (size_t i = 0; i < allArgs.size(); i++) {
|
||||||
size_t idx = Util::indexOf (wantedArgs, allArgs[i]);
|
size_t idx = util::indexOf (wantedArgs, allArgs[i]);
|
||||||
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
|
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -77,7 +77,7 @@ LiftedBp::printSolverFlags (void) const
|
|||||||
}
|
}
|
||||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << util::toString (globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
std::cout << ss.str() << std::endl;
|
std::cout << ss.str() << std::endl;
|
||||||
}
|
}
|
||||||
@ -90,8 +90,8 @@ LiftedBp::refineParfactors (void)
|
|||||||
pfList_ = parfactorList;
|
pfList_ = parfactorList;
|
||||||
while (iterate() == false);
|
while (iterate() == false);
|
||||||
|
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
Util::printHeader ("AFTER REFINEMENT");
|
util::printHeader ("AFTER REFINEMENT");
|
||||||
pfList_.print();
|
pfList_.print();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -187,7 +187,7 @@ LiftedBp::rangeOfGround (const Ground& gr)
|
|||||||
}
|
}
|
||||||
++ it;
|
++ it;
|
||||||
}
|
}
|
||||||
return Util::maxUnsigned();
|
return util::maxUnsigned();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ OrNode::weight (void) const
|
|||||||
{
|
{
|
||||||
double lw = leftBranch_->weight();
|
double lw = leftBranch_->weight();
|
||||||
double rw = rightBranch_->weight();
|
double rw = rightBranch_->weight();
|
||||||
return Globals::logDomain ? Util::logSum (lw, rw) : lw + rw;
|
return globals::logDomain ? util::logSum (lw, rw) : lw + rw;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ AndNode::weight (void) const
|
|||||||
{
|
{
|
||||||
double lw = leftBranch_->weight();
|
double lw = leftBranch_->weight();
|
||||||
double rw = rightBranch_->weight();
|
double rw = rightBranch_->weight();
|
||||||
return Globals::logDomain ? lw + rw : lw * rw;
|
return globals::logDomain ? lw + rw : lw * rw;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -60,17 +60,17 @@ SetOrNode::~SetOrNode (void)
|
|||||||
double
|
double
|
||||||
SetOrNode::weight (void) const
|
SetOrNode::weight (void) const
|
||||||
{
|
{
|
||||||
double weightSum = LogAware::addIdenty();
|
double weightSum = log_aware::addIdenty();
|
||||||
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
|
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
|
||||||
nrPos_ = nrGroundings_ - i;
|
nrPos_ = nrGroundings_ - i;
|
||||||
nrNeg_ = i;
|
nrNeg_ = i;
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
double nrCombs = Util::nrCombinations (nrGroundings_, i);
|
double nrCombs = util::nrCombinations (nrGroundings_, i);
|
||||||
double w = follow_->weight();
|
double w = follow_->weight();
|
||||||
weightSum = Util::logSum (weightSum, std::log (nrCombs) + w);
|
weightSum = util::logSum (weightSum, std::log (nrCombs) + w);
|
||||||
} else {
|
} else {
|
||||||
double w = follow_->weight();
|
double w = follow_->weight();
|
||||||
weightSum += Util::nrCombinations (nrGroundings_, i) * w;
|
weightSum += util::nrCombinations (nrGroundings_, i) * w;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
nrPos_ = -1;
|
nrPos_ = -1;
|
||||||
@ -90,7 +90,7 @@ SetAndNode::~SetAndNode (void)
|
|||||||
double
|
double
|
||||||
SetAndNode::weight (void) const
|
SetAndNode::weight (void) const
|
||||||
{
|
{
|
||||||
return LogAware::pow (follow_->weight(), nrGroundings_);
|
return log_aware::pow (follow_->weight(), nrGroundings_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -108,8 +108,8 @@ double
|
|||||||
IncExcNode::weight (void) const
|
IncExcNode::weight (void) const
|
||||||
{
|
{
|
||||||
double w = 0.0;
|
double w = 0.0;
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
w = Util::logSum (plus1Branch_->weight(), plus2Branch_->weight());
|
w = util::logSum (plus1Branch_->weight(), plus2Branch_->weight());
|
||||||
w = std::log (std::exp (w) - std::exp (minusBranch_->weight()));
|
w = std::log (std::exp (w) - std::exp (minusBranch_->weight()));
|
||||||
} else {
|
} else {
|
||||||
w = plus1Branch_->weight() + plus2Branch_->weight();
|
w = plus1Branch_->weight() + plus2Branch_->weight();
|
||||||
@ -160,7 +160,7 @@ LeafNode::weight (void) const
|
|||||||
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
||||||
clause_->nrNegCountedLogVars());
|
clause_->nrNegCountedLogVars());
|
||||||
}
|
}
|
||||||
return LogAware::pow (weight, nrGroundings);
|
return log_aware::pow (weight, nrGroundings);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -176,7 +176,7 @@ double
|
|||||||
SmoothNode::weight (void) const
|
SmoothNode::weight (void) const
|
||||||
{
|
{
|
||||||
Clauses cs = clauses();
|
Clauses cs = clauses();
|
||||||
double totalWeight = LogAware::multIdenty();
|
double totalWeight = log_aware::multIdenty();
|
||||||
for (size_t i = 0; i < cs.size(); i++) {
|
for (size_t i = 0; i < cs.size(); i++) {
|
||||||
double posWeight = lwcnf_.posWeight (cs[i]->literals()[0].lid());
|
double posWeight = lwcnf_.posWeight (cs[i]->literals()[0].lid());
|
||||||
double negWeight = lwcnf_.negWeight (cs[i]->literals()[0].lid());
|
double negWeight = lwcnf_.negWeight (cs[i]->literals()[0].lid());
|
||||||
@ -196,8 +196,8 @@ SmoothNode::weight (void) const
|
|||||||
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
nrGroundings *= std::pow (SetOrNode::nrNegatives(),
|
||||||
cs[i]->nrNegCountedLogVars());
|
cs[i]->nrNegCountedLogVars());
|
||||||
}
|
}
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
totalWeight += Util::logSum (posWeight, negWeight) * nrGroundings;
|
totalWeight += util::logSum (posWeight, negWeight) * nrGroundings;
|
||||||
} else {
|
} else {
|
||||||
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
|
totalWeight *= std::pow (posWeight + negWeight, nrGroundings);
|
||||||
}
|
}
|
||||||
@ -210,7 +210,7 @@ SmoothNode::weight (void) const
|
|||||||
double
|
double
|
||||||
TrueNode::weight (void) const
|
TrueNode::weight (void) const
|
||||||
{
|
{
|
||||||
return LogAware::multIdenty();
|
return log_aware::multIdenty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -235,9 +235,9 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
|||||||
if (compilationSucceeded_) {
|
if (compilationSucceeded_) {
|
||||||
smoothCircuit (root_);
|
smoothCircuit (root_);
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
if (compilationSucceeded_) {
|
if (compilationSucceeded_) {
|
||||||
double wmc = LogAware::exp (getWeightedModelCount());
|
double wmc = log_aware::exp (getWeightedModelCount());
|
||||||
std::cout << "Weighted model count = " << wmc;
|
std::cout << "Weighted model count = " << wmc;
|
||||||
std::cout << std::endl << std::endl;
|
std::cout << std::endl << std::endl;
|
||||||
}
|
}
|
||||||
@ -302,7 +302,7 @@ LiftedCircuit::compile (
|
|||||||
Clauses& clauses)
|
Clauses& clauses)
|
||||||
{
|
{
|
||||||
if (compilationSucceeded_ == false
|
if (compilationSucceeded_ == false
|
||||||
&& Globals::verbosity <= 1) {
|
&& globals::verbosity <= 1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -341,7 +341,7 @@ LiftedCircuit::compile (
|
|||||||
}
|
}
|
||||||
|
|
||||||
*follow = new CompilationFailedNode();
|
*follow = new CompilationFailedNode();
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
originClausesMap_[*follow] = clauses;
|
originClausesMap_[*follow] = clauses;
|
||||||
explanationMap_[*follow] = "" ;
|
explanationMap_[*follow] = "" ;
|
||||||
}
|
}
|
||||||
@ -355,7 +355,7 @@ LiftedCircuit::tryUnitPropagation (
|
|||||||
CircuitNode** follow,
|
CircuitNode** follow,
|
||||||
Clauses& clauses)
|
Clauses& clauses)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
backupClauses_ = Clause::copyClauses (clauses);
|
backupClauses_ = Clause::copyClauses (clauses);
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < clauses.size(); i++) {
|
for (size_t i = 0; i < clauses.size(); i++) {
|
||||||
@ -392,7 +392,7 @@ LiftedCircuit::tryUnitPropagation (
|
|||||||
}
|
}
|
||||||
|
|
||||||
AndNode* andNode = new AndNode();
|
AndNode* andNode = new AndNode();
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
originClausesMap_[andNode] = backupClauses_;
|
originClausesMap_[andNode] = backupClauses_;
|
||||||
std::stringstream explanation;
|
std::stringstream explanation;
|
||||||
explanation << " UP on " << clauses[i]->literals()[0];
|
explanation << " UP on " << clauses[i]->literals()[0];
|
||||||
@ -406,7 +406,7 @@ LiftedCircuit::tryUnitPropagation (
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
Clause::deleteClauses (backupClauses_);
|
Clause::deleteClauses (backupClauses_);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -422,7 +422,7 @@ LiftedCircuit::tryIndependence (
|
|||||||
if (clauses.size() == 1) {
|
if (clauses.size() == 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
backupClauses_ = Clause::copyClauses (clauses);
|
backupClauses_ = Clause::copyClauses (clauses);
|
||||||
}
|
}
|
||||||
Clauses depClauses = { clauses[0] };
|
Clauses depClauses = { clauses[0] };
|
||||||
@ -441,7 +441,7 @@ LiftedCircuit::tryIndependence (
|
|||||||
}
|
}
|
||||||
if (indepClauses.empty() == false) {
|
if (indepClauses.empty() == false) {
|
||||||
AndNode* andNode = new AndNode ();
|
AndNode* andNode = new AndNode ();
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
originClausesMap_[andNode] = backupClauses_;
|
originClausesMap_[andNode] = backupClauses_;
|
||||||
explanationMap_[andNode] = " Independence" ;
|
explanationMap_[andNode] = " Independence" ;
|
||||||
}
|
}
|
||||||
@ -450,7 +450,7 @@ LiftedCircuit::tryIndependence (
|
|||||||
(*follow) = andNode;
|
(*follow) = andNode;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
Clause::deleteClauses (backupClauses_);
|
Clause::deleteClauses (backupClauses_);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -463,7 +463,7 @@ LiftedCircuit::tryShannonDecomp (
|
|||||||
CircuitNode** follow,
|
CircuitNode** follow,
|
||||||
Clauses& clauses)
|
Clauses& clauses)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
backupClauses_ = Clause::copyClauses (clauses);
|
backupClauses_ = Clause::copyClauses (clauses);
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < clauses.size(); i++) {
|
for (size_t i = 0; i < clauses.size(); i++) {
|
||||||
@ -481,7 +481,7 @@ LiftedCircuit::tryShannonDecomp (
|
|||||||
otherClauses.push_back (c2);
|
otherClauses.push_back (c2);
|
||||||
|
|
||||||
OrNode* orNode = new OrNode();
|
OrNode* orNode = new OrNode();
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
originClausesMap_[orNode] = backupClauses_;
|
originClausesMap_[orNode] = backupClauses_;
|
||||||
std::stringstream explanation;
|
std::stringstream explanation;
|
||||||
explanation << " SD on " << literals[j];
|
explanation << " SD on " << literals[j];
|
||||||
@ -495,7 +495,7 @@ LiftedCircuit::tryShannonDecomp (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
Clause::deleteClauses (backupClauses_);
|
Clause::deleteClauses (backupClauses_);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -508,7 +508,7 @@ LiftedCircuit::tryInclusionExclusion (
|
|||||||
CircuitNode** follow,
|
CircuitNode** follow,
|
||||||
Clauses& clauses)
|
Clauses& clauses)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
backupClauses_ = Clause::copyClauses (clauses);
|
backupClauses_ = Clause::copyClauses (clauses);
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < clauses.size(); i++) {
|
for (size_t i = 0; i < clauses.size(); i++) {
|
||||||
@ -561,7 +561,7 @@ LiftedCircuit::tryInclusionExclusion (
|
|||||||
clauses.push_back (c2);
|
clauses.push_back (c2);
|
||||||
|
|
||||||
IncExcNode* ieNode = new IncExcNode();
|
IncExcNode* ieNode = new IncExcNode();
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
originClausesMap_[ieNode] = backupClauses_;
|
originClausesMap_[ieNode] = backupClauses_;
|
||||||
std::stringstream explanation;
|
std::stringstream explanation;
|
||||||
explanation << " IncExc on clause nº " << i + 1;
|
explanation << " IncExc on clause nº " << i + 1;
|
||||||
@ -574,7 +574,7 @@ LiftedCircuit::tryInclusionExclusion (
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
Clause::deleteClauses (backupClauses_);
|
Clause::deleteClauses (backupClauses_);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -589,7 +589,7 @@ LiftedCircuit::tryIndepPartialGrounding (
|
|||||||
{
|
{
|
||||||
// assumes that all literals have logical variables
|
// assumes that all literals have logical variables
|
||||||
// else, shannon decomp was possible
|
// else, shannon decomp was possible
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
backupClauses_ = Clause::copyClauses (clauses);
|
backupClauses_ = Clause::copyClauses (clauses);
|
||||||
}
|
}
|
||||||
LogVars rootLogVars;
|
LogVars rootLogVars;
|
||||||
@ -603,7 +603,7 @@ LiftedCircuit::tryIndepPartialGrounding (
|
|||||||
clauses[j]->addIpgLogVar (rootLogVars[j]);
|
clauses[j]->addIpgLogVar (rootLogVars[j]);
|
||||||
}
|
}
|
||||||
SetAndNode* setAndNode = new SetAndNode (ct.size());
|
SetAndNode* setAndNode = new SetAndNode (ct.size());
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
originClausesMap_[setAndNode] = backupClauses_;
|
originClausesMap_[setAndNode] = backupClauses_;
|
||||||
explanationMap_[setAndNode] = " IPG" ;
|
explanationMap_[setAndNode] = " IPG" ;
|
||||||
}
|
}
|
||||||
@ -612,7 +612,7 @@ LiftedCircuit::tryIndepPartialGrounding (
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
Clause::deleteClauses (backupClauses_);
|
Clause::deleteClauses (backupClauses_);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -674,7 +674,7 @@ LiftedCircuit::tryAtomCounting (
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
backupClauses_ = Clause::copyClauses (clauses);
|
backupClauses_ = Clause::copyClauses (clauses);
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < clauses.size(); i++) {
|
for (size_t i = 0; i < clauses.size(); i++) {
|
||||||
@ -686,7 +686,7 @@ LiftedCircuit::tryAtomCounting (
|
|||||||
unsigned nrGroundings = clauses[i]->constr().projectedCopy (
|
unsigned nrGroundings = clauses[i]->constr().projectedCopy (
|
||||||
literals[j].logVars()).size();
|
literals[j].logVars()).size();
|
||||||
SetOrNode* setOrNode = new SetOrNode (nrGroundings);
|
SetOrNode* setOrNode = new SetOrNode (nrGroundings);
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
originClausesMap_[setOrNode] = backupClauses_;
|
originClausesMap_[setOrNode] = backupClauses_;
|
||||||
explanationMap_[setOrNode] = " AC" ;
|
explanationMap_[setOrNode] = " AC" ;
|
||||||
}
|
}
|
||||||
@ -707,7 +707,7 @@ LiftedCircuit::tryAtomCounting (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
Clause::deleteClauses (backupClauses_);
|
Clause::deleteClauses (backupClauses_);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@ -916,7 +916,7 @@ LiftedCircuit::createSmoothNode (
|
|||||||
CircuitNode** prev)
|
CircuitNode** prev)
|
||||||
{
|
{
|
||||||
if (missingLits.empty() == false) {
|
if (missingLits.empty() == false) {
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::unordered_map<CircuitNode*, Clauses>::iterator it
|
std::unordered_map<CircuitNode*, Clauses>::iterator it
|
||||||
= originClausesMap_.find (*prev);
|
= originClausesMap_.find (*prev);
|
||||||
if (it != originClausesMap_.end()) {
|
if (it != originClausesMap_.end()) {
|
||||||
@ -944,7 +944,7 @@ LiftedCircuit::createSmoothNode (
|
|||||||
}
|
}
|
||||||
SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_);
|
SmoothNode* smoothNode = new SmoothNode (clauses, *lwcnf_);
|
||||||
*prev = new AndNode (smoothNode, *prev);
|
*prev = new AndNode (smoothNode, *prev);
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
originClausesMap_[*prev] = backupClauses_;
|
originClausesMap_[*prev] = backupClauses_;
|
||||||
explanationMap_[*prev] = " Smoothing" ;
|
explanationMap_[*prev] = " Smoothing" ;
|
||||||
}
|
}
|
||||||
@ -1211,7 +1211,7 @@ LiftedCircuit::escapeNode (const CircuitNode* node) const
|
|||||||
std::string
|
std::string
|
||||||
LiftedCircuit::getExplanationString (CircuitNode* node)
|
LiftedCircuit::getExplanationString (CircuitNode* node)
|
||||||
{
|
{
|
||||||
return Util::contains (explanationMap_, node)
|
return util::contains (explanationMap_, node)
|
||||||
? explanationMap_[node]
|
? explanationMap_[node]
|
||||||
: "" ;
|
: "" ;
|
||||||
}
|
}
|
||||||
@ -1225,7 +1225,7 @@ LiftedCircuit::printClauses (
|
|||||||
std::string extraOptions)
|
std::string extraOptions)
|
||||||
{
|
{
|
||||||
Clauses clauses;
|
Clauses clauses;
|
||||||
if (Util::contains (originClausesMap_, node)) {
|
if (util::contains (originClausesMap_, node)) {
|
||||||
clauses = originClausesMap_[node];
|
clauses = originClausesMap_[node];
|
||||||
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) {
|
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) {
|
||||||
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
|
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
|
||||||
@ -1288,20 +1288,20 @@ LiftedKc::solveQuery (const Grounds& query)
|
|||||||
std::vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
std::vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
||||||
for (size_t j = 0; j < litIds.size(); j++) {
|
for (size_t j = 0; j < litIds.size(); j++) {
|
||||||
if (indexer[i] == j) {
|
if (indexer[i] == j) {
|
||||||
lwcnf_->addWeight (litIds[j], LogAware::one(),
|
lwcnf_->addWeight (litIds[j], log_aware::one(),
|
||||||
LogAware::one());
|
log_aware::one());
|
||||||
} else {
|
} else {
|
||||||
lwcnf_->addWeight (litIds[j], LogAware::zero(),
|
lwcnf_->addWeight (litIds[j], log_aware::zero(),
|
||||||
LogAware::one());
|
log_aware::one());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
params.push_back (circuit_->getWeightedModelCount());
|
params.push_back (circuit_->getWeightedModelCount());
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
LogAware::normalize (params);
|
log_aware::normalize (params);
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
Util::exp (params);
|
util::exp (params);
|
||||||
}
|
}
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
@ -1313,7 +1313,7 @@ LiftedKc::printSolverFlags (void) const
|
|||||||
{
|
{
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "lifted kc [" ;
|
ss << "lifted kc [" ;
|
||||||
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
ss << "log_domain=" << util::toString (globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
std::cout << ss.str() << std::endl;
|
std::cout << ss.str() << std::endl;
|
||||||
}
|
}
|
||||||
|
@ -47,13 +47,13 @@ LiftedOperations::shatterAgainstQuery (
|
|||||||
}
|
}
|
||||||
pfList.add (newPfs);
|
pfList.add (newPfs);
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
Util::printAsteriskLine();
|
util::printAsteriskLine();
|
||||||
std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
|
std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
|
||||||
for (size_t i = 0; i < query.size(); i++) {
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
std::cout << " -> " << query[i] << std::endl;
|
std::cout << " -> " << query[i] << std::endl;
|
||||||
}
|
}
|
||||||
Util::printAsteriskLine();
|
util::printAsteriskLine();
|
||||||
pfList.print();
|
pfList.print();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -85,11 +85,11 @@ LiftedOperations::runWeakBayesBall (
|
|||||||
PrvGroup group = todo.front();
|
PrvGroup group = todo.front();
|
||||||
ParfactorList::iterator it = pfList.begin();
|
ParfactorList::iterator it = pfList.begin();
|
||||||
while (it != pfList.end()) {
|
while (it != pfList.end()) {
|
||||||
if (Util::contains (requiredPfs, *it) == false &&
|
if (util::contains (requiredPfs, *it) == false &&
|
||||||
(*it)->containsGroup (group)) {
|
(*it)->containsGroup (group)) {
|
||||||
std::vector<PrvGroup> groups = (*it)->getAllGroups();
|
std::vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||||
for (size_t i = 0; i < groups.size(); i++) {
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
if (Util::contains (done, groups[i]) == false) {
|
if (util::contains (done, groups[i]) == false) {
|
||||||
todo.push (groups[i]);
|
todo.push (groups[i]);
|
||||||
done.insert (groups[i]);
|
done.insert (groups[i]);
|
||||||
}
|
}
|
||||||
@ -104,10 +104,10 @@ LiftedOperations::runWeakBayesBall (
|
|||||||
ParfactorList::iterator it = pfList.begin();
|
ParfactorList::iterator it = pfList.begin();
|
||||||
bool foundNotRequired = false;
|
bool foundNotRequired = false;
|
||||||
while (it != pfList.end()) {
|
while (it != pfList.end()) {
|
||||||
if (Util::contains (requiredPfs, *it) == false) {
|
if (util::contains (requiredPfs, *it) == false) {
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
if (foundNotRequired == false) {
|
if (foundNotRequired == false) {
|
||||||
Util::printHeader ("PARFACTORS TO DISCARD");
|
util::printHeader ("PARFACTORS TO DISCARD");
|
||||||
foundNotRequired = true;
|
foundNotRequired = true;
|
||||||
}
|
}
|
||||||
(*it)->print();
|
(*it)->print();
|
||||||
@ -137,7 +137,7 @@ LiftedOperations::absorveEvidence (
|
|||||||
if (absorvedPfs.size() == 1 && !absorvedPfs[0]) {
|
if (absorvedPfs.size() == 1 && !absorvedPfs[0]) {
|
||||||
// just remove pf;
|
// just remove pf;
|
||||||
} else {
|
} else {
|
||||||
Util::addToVector (newPfs, absorvedPfs);
|
util::addToVector (newPfs, absorvedPfs);
|
||||||
}
|
}
|
||||||
delete pf;
|
delete pf;
|
||||||
} else {
|
} else {
|
||||||
@ -147,13 +147,13 @@ LiftedOperations::absorveEvidence (
|
|||||||
}
|
}
|
||||||
pfList.add (newPfs);
|
pfList.add (newPfs);
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
if (globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||||
Util::printAsteriskLine();
|
util::printAsteriskLine();
|
||||||
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
|
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
|
||||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||||
std::cout << " -> " << obsFormulas[i] << std::endl;
|
std::cout << " -> " << obsFormulas[i] << std::endl;
|
||||||
}
|
}
|
||||||
Util::printAsteriskLine();
|
util::printAsteriskLine();
|
||||||
pfList.print();
|
pfList.print();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,7 +105,7 @@ Substitution::getDiscardedLogVars (void) const
|
|||||||
std::unordered_map<LogVar, LogVar>::const_iterator it
|
std::unordered_map<LogVar, LogVar>::const_iterator it
|
||||||
= subs_.begin();
|
= subs_.begin();
|
||||||
while (it != subs_.end()) {
|
while (it != subs_.end()) {
|
||||||
if (Util::contains (doneLvs, it->second)) {
|
if (util::contains (doneLvs, it->second)) {
|
||||||
discardedLvs.push_back (it->first);
|
discardedLvs.push_back (it->first);
|
||||||
} else {
|
} else {
|
||||||
doneLvs.insert (it->second);
|
doneLvs.insert (it->second);
|
||||||
|
@ -15,13 +15,13 @@ namespace horus {
|
|||||||
class Symbol
|
class Symbol
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Symbol (void) : id_(Util::maxUnsigned()) { }
|
Symbol (void) : id_(util::maxUnsigned()) { }
|
||||||
|
|
||||||
Symbol (unsigned id) : id_(id) { }
|
Symbol (unsigned id) : id_(id) { }
|
||||||
|
|
||||||
operator unsigned (void) const { return id_; }
|
operator unsigned (void) const { return id_; }
|
||||||
|
|
||||||
bool valid (void) const { return id_ != Util::maxUnsigned(); }
|
bool valid (void) const { return id_ != util::maxUnsigned(); }
|
||||||
|
|
||||||
static Symbol invalid (void) { return Symbol(); }
|
static Symbol invalid (void) { return Symbol(); }
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ class Symbol
|
|||||||
class LogVar
|
class LogVar
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
LogVar (void) : id_(Util::maxUnsigned()) { }
|
LogVar (void) : id_(util::maxUnsigned()) { }
|
||||||
|
|
||||||
LogVar (unsigned id) : id_(id) { }
|
LogVar (unsigned id) : id_(id) { }
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ LogVar::operator++ (void)
|
|||||||
inline bool
|
inline bool
|
||||||
LogVar::valid (void) const
|
LogVar::valid (void) const
|
||||||
{
|
{
|
||||||
return id_ != Util::maxUnsigned();
|
return id_ != util::maxUnsigned();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace horus
|
} // namespace horus
|
||||||
@ -168,7 +168,7 @@ class Substitution
|
|||||||
inline void
|
inline void
|
||||||
Substitution::add (LogVar X_old, LogVar X_new)
|
Substitution::add (LogVar X_old, LogVar X_new)
|
||||||
{
|
{
|
||||||
assert (Util::contains (subs_, X_old) == false);
|
assert (util::contains (subs_, X_old) == false);
|
||||||
subs_.insert (std::make_pair (X_old, X_new));
|
subs_.insert (std::make_pair (X_old, X_new));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,7 +177,7 @@ Substitution::add (LogVar X_old, LogVar X_new)
|
|||||||
inline void
|
inline void
|
||||||
Substitution::rename (LogVar X_old, LogVar X_new)
|
Substitution::rename (LogVar X_old, LogVar X_new)
|
||||||
{
|
{
|
||||||
assert (Util::contains (subs_, X_old));
|
assert (util::contains (subs_, X_old));
|
||||||
subs_.find (X_old)->second = X_new;
|
subs_.find (X_old)->second = X_new;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ Substitution::newNameFor (LogVar X) const
|
|||||||
inline bool
|
inline bool
|
||||||
Substitution::containsReplacementFor (LogVar X) const
|
Substitution::containsReplacementFor (LogVar X) const
|
||||||
{
|
{
|
||||||
return Util::contains (subs_, X);
|
return util::contains (subs_, X);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ LiftedOperator::getValidOps (
|
|||||||
multOps = ProductOperator::getValidOps (pfList);
|
multOps = ProductOperator::getValidOps (pfList);
|
||||||
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
|
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
|
||||||
|
|
||||||
if (Globals::verbosity > 1 || multOps.empty()) {
|
if (globals::verbosity > 1 || multOps.empty()) {
|
||||||
std::vector<SumOutOperator*> sumOutOps;
|
std::vector<SumOutOperator*> sumOutOps;
|
||||||
std::vector<CountingOperator*> countOps;
|
std::vector<CountingOperator*> countOps;
|
||||||
std::vector<GroundOperator*> groundOps;
|
std::vector<GroundOperator*> groundOps;
|
||||||
@ -103,14 +103,14 @@ ProductOperator::getValidOps (ParfactorList& pfList)
|
|||||||
ParfactorList::iterator penultimate = -- pfList.end();
|
ParfactorList::iterator penultimate = -- pfList.end();
|
||||||
std::set<Parfactor*> pfs;
|
std::set<Parfactor*> pfs;
|
||||||
while (it1 != penultimate) {
|
while (it1 != penultimate) {
|
||||||
if (Util::contains (pfs, *it1)) {
|
if (util::contains (pfs, *it1)) {
|
||||||
++ it1;
|
++ it1;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
ParfactorList::iterator it2 = it1;
|
ParfactorList::iterator it2 = it1;
|
||||||
++ it2;
|
++ it2;
|
||||||
while (it2 != pfList.end()) {
|
while (it2 != pfList.end()) {
|
||||||
if (Util::contains (pfs, *it2)) {
|
if (util::contains (pfs, *it2)) {
|
||||||
++ it2;
|
++ it2;
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
@ -119,7 +119,7 @@ ProductOperator::getValidOps (ParfactorList& pfList)
|
|||||||
pfs.insert (*it2);
|
pfs.insert (*it2);
|
||||||
validOps.push_back (new ProductOperator (
|
validOps.push_back (new ProductOperator (
|
||||||
it1, it2, pfList));
|
it1, it2, pfList));
|
||||||
if (Globals::verbosity < 2) {
|
if (globals::verbosity < 2) {
|
||||||
return validOps;
|
return validOps;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@ -353,7 +353,7 @@ CountingOperator::getLogCost (void)
|
|||||||
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
||||||
}
|
}
|
||||||
PrvGroup group = (*pfIter_)->argument (fIdx).group();
|
PrvGroup group = (*pfIter_)->argument (fIdx).group();
|
||||||
size_t lvIndex = Util::indexOf (
|
size_t lvIndex = util::indexOf (
|
||||||
(*pfIter_)->argument (fIdx).logVars(), X_);
|
(*pfIter_)->argument (fIdx).logVars(), X_);
|
||||||
assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size());
|
assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size());
|
||||||
ParfactorList::iterator pfIter = pfList_.begin();
|
ParfactorList::iterator pfIter = pfList_.begin();
|
||||||
@ -503,7 +503,7 @@ GroundOperator::getLogCost (void)
|
|||||||
if (willBeAffected) {
|
if (willBeAffected) {
|
||||||
// std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
// std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
||||||
double pfCost = reps + pfSize;
|
double pfCost = reps + pfSize;
|
||||||
totalCost = Util::logSum (totalCost, pfCost);
|
totalCost = util::logSum (totalCost, pfCost);
|
||||||
}
|
}
|
||||||
++ pflIt;
|
++ pflIt;
|
||||||
}
|
}
|
||||||
@ -552,7 +552,7 @@ GroundOperator::getValidOps (ParfactorList& pfList)
|
|||||||
while (it != pfList.end()) {
|
while (it != pfList.end()) {
|
||||||
const ProbFormulas& formulas = (*it)->arguments();
|
const ProbFormulas& formulas = (*it)->arguments();
|
||||||
for (size_t i = 0; i < formulas.size(); i++) {
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
if (util::contains (allGroups, formulas[i].group()) == false) {
|
||||||
const LogVars& lvs = formulas[i].logVars();
|
const LogVars& lvs = formulas[i].logVars();
|
||||||
for (size_t j = 0; j < lvs.size(); j++) {
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
if ((*it)->constr()->isSingleton (lvs[j]) == false) {
|
if ((*it)->constr()->isSingleton (lvs[j]) == false) {
|
||||||
@ -618,7 +618,7 @@ GroundOperator::getAffectedFormulas (void)
|
|||||||
if (i != idx && fs[i].contains (X)) {
|
if (i != idx && fs[i].contains (X)) {
|
||||||
std::pair<PrvGroup, unsigned> pair = std::make_pair (
|
std::pair<PrvGroup, unsigned> pair = std::make_pair (
|
||||||
fs[i].group(), fs[i].indexOf (X));
|
fs[i].group(), fs[i].indexOf (X));
|
||||||
if (Util::contains (affectedFormulas, pair) == false) {
|
if (util::contains (affectedFormulas, pair) == false) {
|
||||||
q.push (pair);
|
q.push (pair);
|
||||||
affectedFormulas.push_back (pair);
|
affectedFormulas.push_back (pair);
|
||||||
}
|
}
|
||||||
@ -642,8 +642,8 @@ LiftedVe::solveQuery (const Grounds& query)
|
|||||||
runSolver (query);
|
runSolver (query);
|
||||||
(*pfList_.begin())->normalize();
|
(*pfList_.begin())->normalize();
|
||||||
Params params = (*pfList_.begin())->params();
|
Params params = (*pfList_.begin())->params();
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
Util::exp (params);
|
util::exp (params);
|
||||||
}
|
}
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
@ -655,7 +655,7 @@ LiftedVe::printSolverFlags (void) const
|
|||||||
{
|
{
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "lve [" ;
|
ss << "lve [" ;
|
||||||
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
ss << "log_domain=" << util::toString (globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
std::cout << ss.str() << std::endl;
|
std::cout << ss.str() << std::endl;
|
||||||
}
|
}
|
||||||
@ -669,10 +669,10 @@ LiftedVe::runSolver (const Grounds& query)
|
|||||||
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||||
LiftedOperations::runWeakBayesBall (pfList_, query);
|
LiftedOperations::runWeakBayesBall (pfList_, query);
|
||||||
while (true) {
|
while (true) {
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
Util::printDashedLine();
|
util::printDashedLine();
|
||||||
pfList_.print();
|
pfList_.print();
|
||||||
if (Globals::verbosity > 3) {
|
if (globals::verbosity > 3) {
|
||||||
LiftedOperator::printValidOps (pfList_, query);
|
LiftedOperator::printValidOps (pfList_, query);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -680,9 +680,9 @@ LiftedVe::runSolver (const Grounds& query)
|
|||||||
if (op == 0) {
|
if (op == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << "best operation: " << op->toString();
|
std::cout << "best operation: " << op->toString();
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -698,7 +698,7 @@ LiftedVe::runSolver (const Grounds& query)
|
|||||||
++ pfIter;
|
++ pfIter;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 0) {
|
if (globals::verbosity > 0) {
|
||||||
std::cout << "largest cost = " << std::exp (largestCost_);
|
std::cout << "largest cost = " << std::exp (largestCost_);
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
|
@ -26,7 +26,7 @@ Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const
|
|||||||
size_t
|
size_t
|
||||||
Literal::indexOfLogVar (LogVar X) const
|
Literal::indexOfLogVar (LogVar X) const
|
||||||
{
|
{
|
||||||
return Util::indexOf (logVars_, X);
|
return util::indexOf (logVars_, X);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -245,7 +245,7 @@ Clause::ipgCandidates (void) const
|
|||||||
for (size_t i = 0; i < allLvs.size(); i++) {
|
for (size_t i = 0; i < allLvs.size(); i++) {
|
||||||
bool valid = true;
|
bool valid = true;
|
||||||
for (size_t j = 0; j < literals_.size(); j++) {
|
for (size_t j = 0; j < literals_.size(); j++) {
|
||||||
if (Util::contains (literals_[j].logVars(), allLvs[i]) == false) {
|
if (util::contains (literals_[j].logVars(), allLvs[i]) == false) {
|
||||||
valid = false;
|
valid = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -444,7 +444,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
|||||||
clauses_.push_back(c2);
|
clauses_.push_back(c2);
|
||||||
*/
|
*/
|
||||||
|
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << "FORMULA INDICATORS:" << std::endl;
|
std::cout << "FORMULA INDICATORS:" << std::endl;
|
||||||
printFormulaIndicators();
|
printFormulaIndicators();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
@ -479,7 +479,7 @@ LiftedWCNF::posWeight (LiteralId lid) const
|
|||||||
{
|
{
|
||||||
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
||||||
= weights_.find (lid);
|
= weights_.find (lid);
|
||||||
return it != weights_.end() ? it->second.first : LogAware::one();
|
return it != weights_.end() ? it->second.first : log_aware::one();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -489,7 +489,7 @@ LiftedWCNF::negWeight (LiteralId lid) const
|
|||||||
{
|
{
|
||||||
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
||||||
= weights_.find (lid);
|
= weights_.find (lid);
|
||||||
return it != weights_.end() ? it->second.second : LogAware::one();
|
return it != weights_.end() ? it->second.second : log_aware::one();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -497,7 +497,7 @@ LiftedWCNF::negWeight (LiteralId lid) const
|
|||||||
std::vector<LiteralId>
|
std::vector<LiteralId>
|
||||||
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
|
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
|
||||||
{
|
{
|
||||||
assert (Util::contains (map_, prvGroup));
|
assert (util::contains (map_, prvGroup));
|
||||||
return map_[prvGroup];
|
return map_[prvGroup];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -526,7 +526,7 @@ LiftedWCNF::createClause (LiteralId lid) const
|
|||||||
LiteralId
|
LiteralId
|
||||||
LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
|
LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
|
||||||
{
|
{
|
||||||
assert (Util::contains (map_, prvGroup));
|
assert (util::contains (map_, prvGroup));
|
||||||
return map_[prvGroup][range];
|
return map_[prvGroup][range];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -539,7 +539,7 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
|
|||||||
while (it != pfList.end()) {
|
while (it != pfList.end()) {
|
||||||
const ProbFormulas& formulas = (*it)->arguments();
|
const ProbFormulas& formulas = (*it)->arguments();
|
||||||
for (size_t i = 0; i < formulas.size(); i++) {
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
if (Util::contains (map_, formulas[i].group()) == false) {
|
if (util::contains (map_, formulas[i].group()) == false) {
|
||||||
ConstraintTree tempConstr = (*it)->constr()->projectedCopy(
|
ConstraintTree tempConstr = (*it)->constr()->projectedCopy(
|
||||||
formulas[i].logVars());
|
formulas[i].logVars());
|
||||||
Clause* clause = new Clause (tempConstr);
|
Clause* clause = new Clause (tempConstr);
|
||||||
@ -584,7 +584,7 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
|||||||
// ¬θxi|u1,...,un v λu1 -> tempClause
|
// ¬θxi|u1,...,un v λu1 -> tempClause
|
||||||
// ¬θxi|u1,...,un v λu2 -> tempClause
|
// ¬θxi|u1,...,un v λu2 -> tempClause
|
||||||
double posWeight = (**it)[indexer];
|
double posWeight = (**it)[indexer];
|
||||||
addWeight (paramVarLid, posWeight, LogAware::one());
|
addWeight (paramVarLid, posWeight, log_aware::one());
|
||||||
|
|
||||||
Clause* clause1 = new Clause (*(*it)->constr());
|
Clause* clause1 = new Clause (*(*it)->constr());
|
||||||
|
|
||||||
@ -623,7 +623,7 @@ LiftedWCNF::printFormulaIndicators (void) const
|
|||||||
while (it != pfList_.end()) {
|
while (it != pfList_.end()) {
|
||||||
const ProbFormulas& formulas = (*it)->arguments();
|
const ProbFormulas& formulas = (*it)->arguments();
|
||||||
for (size_t i = 0; i < formulas.size(); i++) {
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
if (util::contains (allGroups, formulas[i].group()) == false) {
|
||||||
allGroups.insert (formulas[i].group());
|
allGroups.insert (formulas[i].group());
|
||||||
std::cout << formulas[i] << " | " ;
|
std::cout << formulas[i] << " | " ;
|
||||||
ConstraintTree tempCt = (*it)->constr()->projectedCopy (
|
ConstraintTree tempCt = (*it)->constr()->projectedCopy (
|
||||||
|
@ -27,7 +27,7 @@ Parfactor::Parfactor (
|
|||||||
ranges_.push_back (args_[i].range());
|
ranges_.push_back (args_[i].range());
|
||||||
const LogVars& lvs = args_[i].logVars();
|
const LogVars& lvs = args_[i].logVars();
|
||||||
for (size_t j = 0; j < lvs.size(); j++) {
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
if (Util::contains (logVars, lvs[j]) == false) {
|
if (util::contains (logVars, lvs[j]) == false) {
|
||||||
logVars.push_back (lvs[j]);
|
logVars.push_back (lvs[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -50,7 +50,7 @@ Parfactor::Parfactor (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
assert (params_.size() == util::sizeExpected (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -62,7 +62,7 @@ Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
|||||||
ranges_ = g->ranges();
|
ranges_ = g->ranges();
|
||||||
distId_ = g->distId();
|
distId_ = g->distId();
|
||||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
assert (params_.size() == util::sizeExpected (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -74,7 +74,7 @@ Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
|||||||
ranges_ = g->ranges();
|
ranges_ = g->ranges();
|
||||||
distId_ = g->distId();
|
distId_ = g->distId();
|
||||||
constr_ = constr;
|
constr_ = constr;
|
||||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
assert (params_.size() == util::sizeExpected (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ Parfactor::Parfactor (const Parfactor& g)
|
|||||||
ranges_ = g.ranges();
|
ranges_ = g.ranges();
|
||||||
distId_ = g.distId();
|
distId_ = g.distId();
|
||||||
constr_ = new ConstraintTree (*g.constr());
|
constr_ = new ConstraintTree (*g.constr());
|
||||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
assert (params_.size() == util::sizeExpected (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -159,7 +159,7 @@ Parfactor::sumOutIndex (size_t fIdx)
|
|||||||
std::vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
std::vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||||
Indexer indexer (ranges_, fIdx);
|
Indexer indexer (ranges_, fIdx);
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
params_[indexer] += numAssigns[ indexer[fIdx] ];
|
params_[indexer] += numAssigns[ indexer[fIdx] ];
|
||||||
} else {
|
} else {
|
||||||
params_[indexer] *= numAssigns[ indexer[fIdx] ];
|
params_[indexer] *= numAssigns[ indexer[fIdx] ];
|
||||||
@ -179,7 +179,7 @@ Parfactor::sumOutIndex (size_t fIdx)
|
|||||||
constr_->remove (excl);
|
constr_->remove (excl);
|
||||||
|
|
||||||
TFactor<ProbFormula>::sumOutIndex (fIdx);
|
TFactor<ProbFormula>::sumOutIndex (fIdx);
|
||||||
LogAware::pow (params_, exp);
|
log_aware::pow (params_, exp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -253,14 +253,14 @@ Parfactor::countConvert (LogVar X)
|
|||||||
ranges_[fIdx] = H;
|
ranges_[fIdx] = H;
|
||||||
MapIndexer mapIndexer (ranges_, fIdx);
|
MapIndexer mapIndexer (ranges_, fIdx);
|
||||||
while (mapIndexer.valid()) {
|
while (mapIndexer.valid()) {
|
||||||
double prod = LogAware::multIdenty();
|
double prod = log_aware::multIdenty();
|
||||||
size_t i = mapIndexer;
|
size_t i = mapIndexer;
|
||||||
unsigned h = mapIndexer[fIdx];
|
unsigned h = mapIndexer[fIdx];
|
||||||
for (unsigned r = 0; r < R; r++) {
|
for (unsigned r = 0; r < R; r++) {
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
prod += log_aware::pow (sumout[i][r], histograms[h][r]);
|
||||||
} else {
|
} else {
|
||||||
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
|
prod *= log_aware::pow (sumout[i][r], histograms[h][r]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
params_.push_back (prod);
|
params_.push_back (prod);
|
||||||
@ -390,7 +390,7 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
|||||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||||
assert (args_[fIdx].isCounting() == false);
|
assert (args_[fIdx].isCounting() == false);
|
||||||
assert (constr_->isCountNormalized (excl));
|
assert (constr_->isCountNormalized (excl));
|
||||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
log_aware::pow (params_, constr_->getConditionalCount (excl));
|
||||||
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||||
constr_->remove (excl);
|
constr_->remove (excl);
|
||||||
}
|
}
|
||||||
@ -475,7 +475,7 @@ Parfactor::containsGrounds (const Grounds& grounds) const
|
|||||||
}
|
}
|
||||||
LogVars lvs = args_[idx].logVars();
|
LogVars lvs = args_[idx].logVars();
|
||||||
for (size_t j = 0; j < lvs.size(); j++) {
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
if (Util::contains (tupleLvs, lvs[j]) == false) {
|
if (util::contains (tupleLvs, lvs[j]) == false) {
|
||||||
tuple.push_back (grounds[i].args()[j]);
|
tuple.push_back (grounds[i].args()[j]);
|
||||||
tupleLvs.push_back (lvs[j]);
|
tupleLvs.push_back (lvs[j]);
|
||||||
}
|
}
|
||||||
@ -613,10 +613,10 @@ Parfactor::print (bool printParams) const
|
|||||||
cout << args_[i];
|
cout << args_[i];
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
if (args_[0].group() != Util::maxUnsigned()) {
|
if (args_[0].group() != util::maxUnsigned()) {
|
||||||
std::vector<std::string> groups;
|
std::vector<std::string> groups;
|
||||||
for (size_t i = 0; i < args_.size(); i++) {
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
groups.push_back (std::string ("g") + Util::toString (args_[i].group()));
|
groups.push_back (std::string ("g") + util::toString (args_[i].group()));
|
||||||
}
|
}
|
||||||
cout << "Groups: " << groups << endl;
|
cout << "Groups: " << groups << endl;
|
||||||
}
|
}
|
||||||
@ -844,8 +844,8 @@ Parfactor::getAlignLogVars (Parfactor* g1, Parfactor* g2)
|
|||||||
g1->range (i) == g2->range (j) &&
|
g1->range (i) == g2->range (j) &&
|
||||||
matchedI.contains (i) == false &&
|
matchedI.contains (i) == false &&
|
||||||
matchedJ.contains (j) == false) {
|
matchedJ.contains (j) == false) {
|
||||||
Util::addToVector (Xs_1, formulas1[i].logVars());
|
util::addToVector (Xs_1, formulas1[i].logVars());
|
||||||
Util::addToVector (Xs_2, formulas2[j].logVars());
|
util::addToVector (Xs_2, formulas2[j].logVars());
|
||||||
matchedI.insert (i);
|
matchedI.insert (i);
|
||||||
matchedJ.insert (j);
|
matchedJ.insert (j);
|
||||||
}
|
}
|
||||||
@ -869,8 +869,8 @@ Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
|||||||
assert (g2->constr()->isCountNormalized (Y_2));
|
assert (g2->constr()->isCountNormalized (Y_2));
|
||||||
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
||||||
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
||||||
LogAware::pow (g1->params(), 1.0 / condCount2);
|
log_aware::pow (g1->params(), 1.0 / condCount2);
|
||||||
LogAware::pow (g2->params(), 1.0 / condCount1);
|
log_aware::pow (g2->params(), 1.0 / condCount1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -244,13 +244,13 @@ ParfactorList::addToShatteredList (Parfactor* g)
|
|||||||
shattRes = shatter (*pfIter, pf);
|
shattRes = shatter (*pfIter, pf);
|
||||||
if (shattRes.first.empty() == false) {
|
if (shattRes.first.empty() == false) {
|
||||||
pfIter = removeAndDelete (pfIter);
|
pfIter = removeAndDelete (pfIter);
|
||||||
Util::addToQueue (residuals, shattRes.first);
|
util::addToQueue (residuals, shattRes.first);
|
||||||
} else {
|
} else {
|
||||||
++ pfIter;
|
++ pfIter;
|
||||||
}
|
}
|
||||||
if (shattRes.second.empty() == false) {
|
if (shattRes.second.empty() == false) {
|
||||||
delete pf;
|
delete pf;
|
||||||
Util::addToQueue (residuals, shattRes.second);
|
util::addToQueue (residuals, shattRes.second);
|
||||||
pfSplitted = true;
|
pfSplitted = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -261,7 +261,7 @@ ParfactorList::addToShatteredList (Parfactor* g)
|
|||||||
if (res.empty()) {
|
if (res.empty()) {
|
||||||
addShattered (pf);
|
addShattered (pf);
|
||||||
} else {
|
} else {
|
||||||
Util::addToQueue (residuals, res);
|
util::addToQueue (residuals, res);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -329,7 +329,7 @@ ParfactorList::shatterAgainstMySelf (
|
|||||||
size_t fIdx2)
|
size_t fIdx2)
|
||||||
{
|
{
|
||||||
/*
|
/*
|
||||||
Util::printDashedLine();
|
util::printDashedLine();
|
||||||
std::cout << "-> SHATTERING" << std::endl;
|
std::cout << "-> SHATTERING" << std::endl;
|
||||||
g->print();
|
g->print();
|
||||||
std::cout << "-> ON: " << g->argument (fIdx1) << "|" ;
|
std::cout << "-> ON: " << g->argument (fIdx1) << "|" ;
|
||||||
@ -338,7 +338,7 @@ ParfactorList::shatterAgainstMySelf (
|
|||||||
std::cout << "-> ON: " << g->argument (fIdx2) << "|" ;
|
std::cout << "-> ON: " << g->argument (fIdx2) << "|" ;
|
||||||
std::cout << g->constr()->tupleSet (g->argument (fIdx2).logVars())
|
std::cout << g->constr()->tupleSet (g->argument (fIdx2).logVars())
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
Util::printDashedLine();
|
util::printDashedLine();
|
||||||
*/
|
*/
|
||||||
ProbFormula& f1 = g->argument (fIdx1);
|
ProbFormula& f1 = g->argument (fIdx1);
|
||||||
ProbFormula& f2 = g->argument (fIdx2);
|
ProbFormula& f2 = g->argument (fIdx2);
|
||||||
@ -399,7 +399,7 @@ ParfactorList::shatterAgainstMySelf (
|
|||||||
res.push_back (res1[i]);
|
res.push_back (res1[i]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Util::addToVector (res, res2);
|
util::addToVector (res, res2);
|
||||||
for (size_t j = 0; j < res2.size(); j++) {
|
for (size_t j = 0; j < res2.size(); j++) {
|
||||||
}
|
}
|
||||||
if (res1[i] != g) {
|
if (res1[i] != g) {
|
||||||
@ -448,7 +448,7 @@ ParfactorList::shatter (
|
|||||||
ProbFormula& f1 = g1->argument (fIdx1);
|
ProbFormula& f1 = g1->argument (fIdx1);
|
||||||
ProbFormula& f2 = g2->argument (fIdx2);
|
ProbFormula& f2 = g2->argument (fIdx2);
|
||||||
/*
|
/*
|
||||||
Util::printDashedLine();
|
util::printDashedLine();
|
||||||
std::cout << "-> SHATTERING" << std::endl;
|
std::cout << "-> SHATTERING" << std::endl;
|
||||||
g1->print();
|
g1->print();
|
||||||
std::cout << "-> WITH" << std::endl;
|
std::cout << "-> WITH" << std::endl;
|
||||||
@ -457,7 +457,7 @@ ParfactorList::shatter (
|
|||||||
std::cout << g1->constr()->tupleSet (f1.logVars()) << std::endl;
|
std::cout << g1->constr()->tupleSet (f1.logVars()) << std::endl;
|
||||||
std::cout << "-> ON: " << f2 << "|" ;
|
std::cout << "-> ON: " << f2 << "|" ;
|
||||||
std::cout << g2->constr()->tupleSet (f2.logVars()) << std::endl;
|
std::cout << g2->constr()->tupleSet (f2.logVars()) << std::endl;
|
||||||
Util::printDashedLine();
|
util::printDashedLine();
|
||||||
*/
|
*/
|
||||||
if (f1.isAtom()) {
|
if (f1.isAtom()) {
|
||||||
f2.setGroup (f1.group());
|
f2.setGroup (f1.group());
|
||||||
|
@ -23,7 +23,7 @@ ProbFormula::sameSkeletonAs (const ProbFormula& f) const
|
|||||||
bool
|
bool
|
||||||
ProbFormula::contains (LogVar lv) const
|
ProbFormula::contains (LogVar lv) const
|
||||||
{
|
{
|
||||||
return Util::contains (logVars_, lv);
|
return util::contains (logVars_, lv);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ ProbFormula::contains (LogVarSet s) const
|
|||||||
size_t
|
size_t
|
||||||
ProbFormula::indexOf (LogVar X) const
|
ProbFormula::indexOf (LogVar X) const
|
||||||
{
|
{
|
||||||
return Util::indexOf (logVars_, X);
|
return util::indexOf (logVars_, X);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
namespace horus {
|
namespace horus {
|
||||||
|
|
||||||
namespace Globals {
|
namespace globals {
|
||||||
|
|
||||||
bool logDomain = false;
|
bool logDomain = false;
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ GroundSolverType groundSolver = GroundSolverType::VE;
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace Util {
|
namespace util {
|
||||||
|
|
||||||
template <> std::string
|
template <> std::string
|
||||||
toString (const bool& b)
|
toString (const bool& b)
|
||||||
@ -203,25 +203,25 @@ setHorusFlag (std::string option, std::string value)
|
|||||||
{
|
{
|
||||||
bool returnVal = true;
|
bool returnVal = true;
|
||||||
if (option == "lifted_solver") {
|
if (option == "lifted_solver") {
|
||||||
if (value == "lve") Globals::liftedSolver = LiftedSolverType::LVE;
|
if (value == "lve") globals::liftedSolver = LiftedSolverType::LVE;
|
||||||
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::LBP;
|
else if (value == "lbp") globals::liftedSolver = LiftedSolverType::LBP;
|
||||||
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::LKC;
|
else if (value == "lkc") globals::liftedSolver = LiftedSolverType::LKC;
|
||||||
else returnVal = invalidValue (option, value);
|
else returnVal = invalidValue (option, value);
|
||||||
|
|
||||||
} else if (option == "ground_solver" || option == "solver") {
|
} else if (option == "ground_solver" || option == "solver") {
|
||||||
if (value == "hve") Globals::groundSolver = GroundSolverType::VE;
|
if (value == "hve") globals::groundSolver = GroundSolverType::VE;
|
||||||
else if (value == "bp") Globals::groundSolver = GroundSolverType::BP;
|
else if (value == "bp") globals::groundSolver = GroundSolverType::BP;
|
||||||
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CBP;
|
else if (value == "cbp") globals::groundSolver = GroundSolverType::CBP;
|
||||||
else returnVal = invalidValue (option, value);
|
else returnVal = invalidValue (option, value);
|
||||||
|
|
||||||
} else if (option == "verbosity") {
|
} else if (option == "verbosity") {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << value;
|
ss << value;
|
||||||
ss >> Globals::verbosity;
|
ss >> globals::verbosity;
|
||||||
|
|
||||||
} else if (option == "use_logarithms") {
|
} else if (option == "use_logarithms") {
|
||||||
if (value == "true") Globals::logDomain = true;
|
if (value == "true") globals::logDomain = true;
|
||||||
else if (value == "false") Globals::logDomain = false;
|
else if (value == "false") globals::logDomain = false;
|
||||||
else returnVal = invalidValue (option, value);
|
else returnVal = invalidValue (option, value);
|
||||||
|
|
||||||
} else if (option == "hve_elim_heuristic") {
|
} else if (option == "hve_elim_heuristic") {
|
||||||
@ -335,14 +335,14 @@ printDashedLine (std::ostream& os)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace LogAware {
|
namespace log_aware {
|
||||||
|
|
||||||
void
|
void
|
||||||
normalize (Params& v)
|
normalize (Params& v)
|
||||||
{
|
{
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
double sum = std::accumulate (v.begin(), v.end(),
|
double sum = std::accumulate (v.begin(), v.end(),
|
||||||
LogAware::addIdenty(), Util::logSum);
|
log_aware::addIdenty(), util::logSum);
|
||||||
assert (sum != -std::numeric_limits<double>::infinity());
|
assert (sum != -std::numeric_limits<double>::infinity());
|
||||||
v -= sum;
|
v -= sum;
|
||||||
} else {
|
} else {
|
||||||
@ -359,7 +359,7 @@ getL1Distance (const Params& v1, const Params& v2)
|
|||||||
{
|
{
|
||||||
assert (v1.size() == v2.size());
|
assert (v1.size() == v2.size());
|
||||||
double dist = 0.0;
|
double dist = 0.0;
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
std::plus<double>(), FuncObject::abs_diff_exp<double>());
|
std::plus<double>(), FuncObject::abs_diff_exp<double>());
|
||||||
} else {
|
} else {
|
||||||
@ -376,7 +376,7 @@ getMaxNorm (const Params& v1, const Params& v2)
|
|||||||
{
|
{
|
||||||
assert (v1.size() == v2.size());
|
assert (v1.size() == v2.size());
|
||||||
double max = 0.0;
|
double max = 0.0;
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
|
FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
|
||||||
} else {
|
} else {
|
||||||
@ -391,7 +391,7 @@ getMaxNorm (const Params& v1, const Params& v2)
|
|||||||
double
|
double
|
||||||
pow (double base, unsigned iexp)
|
pow (double base, unsigned iexp)
|
||||||
{
|
{
|
||||||
return Globals::logDomain
|
return globals::logDomain
|
||||||
? base * iexp
|
? base * iexp
|
||||||
: std::pow (base, iexp);
|
: std::pow (base, iexp);
|
||||||
}
|
}
|
||||||
@ -402,7 +402,7 @@ double
|
|||||||
pow (double base, double exp)
|
pow (double base, double exp)
|
||||||
{
|
{
|
||||||
// `expoent' should not be in log domain
|
// `expoent' should not be in log domain
|
||||||
return Globals::logDomain
|
return globals::logDomain
|
||||||
? base * exp
|
? base * exp
|
||||||
: std::pow (base, exp);
|
: std::pow (base, exp);
|
||||||
}
|
}
|
||||||
@ -415,7 +415,7 @@ pow (Params& v, unsigned iexp)
|
|||||||
if (iexp == 1) {
|
if (iexp == 1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Globals::logDomain ? v *= iexp : v ^= (int)iexp;
|
globals::logDomain ? v *= iexp : v ^= (int)iexp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -424,10 +424,10 @@ void
|
|||||||
pow (Params& v, double exp)
|
pow (Params& v, double exp)
|
||||||
{
|
{
|
||||||
// `expoent' should not be in log domain
|
// `expoent' should not be in log domain
|
||||||
Globals::logDomain ? v *= exp : v ^= exp;
|
globals::logDomain ? v *= exp : v ^= exp;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace LogAware
|
} // namespace log_aware
|
||||||
|
|
||||||
} // namespace horus
|
} // namespace horus
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
namespace horus {
|
namespace horus {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -25,7 +26,7 @@ const double NEG_INF = -std::numeric_limits<double>::infinity();
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
namespace Util {
|
namespace util {
|
||||||
|
|
||||||
template <typename T> void
|
template <typename T> void
|
||||||
addToVector (std::vector<T>&, const std::vector<T>&);
|
addToVector (std::vector<T>&, const std::vector<T>&);
|
||||||
@ -87,7 +88,7 @@ unsigned nrDigits (int);
|
|||||||
bool isInteger (const std::string&);
|
bool isInteger (const std::string&);
|
||||||
|
|
||||||
std::string parametersToString (
|
std::string parametersToString (
|
||||||
const Params&, unsigned = Constants::PRECISION);
|
const Params&, unsigned = constants::PRECISION);
|
||||||
|
|
||||||
std::vector<std::string> getStateLines (const Vars&);
|
std::vector<std::string> getStateLines (const Vars&);
|
||||||
|
|
||||||
@ -106,7 +107,7 @@ void printDashedLine (std::ostream& os = std::cout);
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
|
util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
|
||||||
{
|
{
|
||||||
v.insert (v.end(), elements.begin(), elements.end());
|
v.insert (v.end(), elements.begin(), elements.end());
|
||||||
}
|
}
|
||||||
@ -114,7 +115,7 @@ Util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::addToSet (std::set<T>& s, const std::vector<T>& elements)
|
util::addToSet (std::set<T>& s, const std::vector<T>& elements)
|
||||||
{
|
{
|
||||||
s.insert (elements.begin(), elements.end());
|
s.insert (elements.begin(), elements.end());
|
||||||
}
|
}
|
||||||
@ -122,7 +123,7 @@ Util::addToSet (std::set<T>& s, const std::vector<T>& elements)
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
|
util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < elements.size(); i++) {
|
for (size_t i = 0; i < elements.size(); i++) {
|
||||||
q.push (elements[i]);
|
q.push (elements[i]);
|
||||||
@ -132,7 +133,7 @@ Util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> bool
|
template <typename T> bool
|
||||||
Util::contains (const std::vector<T>& v, const T& e)
|
util::contains (const std::vector<T>& v, const T& e)
|
||||||
{
|
{
|
||||||
return std::find (v.begin(), v.end(), e) != v.end();
|
return std::find (v.begin(), v.end(), e) != v.end();
|
||||||
}
|
}
|
||||||
@ -140,7 +141,7 @@ Util::contains (const std::vector<T>& v, const T& e)
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> bool
|
template <typename T> bool
|
||||||
Util::contains (const std::set<T>& s, const T& e)
|
util::contains (const std::set<T>& s, const T& e)
|
||||||
{
|
{
|
||||||
return s.find (e) != s.end();
|
return s.find (e) != s.end();
|
||||||
}
|
}
|
||||||
@ -148,7 +149,7 @@ Util::contains (const std::set<T>& s, const T& e)
|
|||||||
|
|
||||||
|
|
||||||
template <typename K, typename V> bool
|
template <typename K, typename V> bool
|
||||||
Util::contains (const std::unordered_map<K, V>& m, const K& k)
|
util::contains (const std::unordered_map<K, V>& m, const K& k)
|
||||||
{
|
{
|
||||||
return m.find (k) != m.end();
|
return m.find (k) != m.end();
|
||||||
}
|
}
|
||||||
@ -156,7 +157,7 @@ Util::contains (const std::unordered_map<K, V>& m, const K& k)
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> size_t
|
template <typename T> size_t
|
||||||
Util::indexOf (const std::vector<T>& v, const T& e)
|
util::indexOf (const std::vector<T>& v, const T& e)
|
||||||
{
|
{
|
||||||
return std::distance (v.begin(),
|
return std::distance (v.begin(),
|
||||||
std::find (v.begin(), v.end(), e));
|
std::find (v.begin(), v.end(), e));
|
||||||
@ -165,7 +166,7 @@ Util::indexOf (const std::vector<T>& v, const T& e)
|
|||||||
|
|
||||||
|
|
||||||
template <class Operation> void
|
template <class Operation> void
|
||||||
Util::apply_n_times (
|
util::apply_n_times (
|
||||||
Params& v1,
|
Params& v1,
|
||||||
const Params& v2,
|
const Params& v2,
|
||||||
unsigned repetitions,
|
unsigned repetitions,
|
||||||
@ -187,7 +188,7 @@ Util::apply_n_times (
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::log (std::vector<T>& v)
|
util::log (std::vector<T>& v)
|
||||||
{
|
{
|
||||||
std::transform (v.begin(), v.end(), v.begin(), ::log);
|
std::transform (v.begin(), v.end(), v.begin(), ::log);
|
||||||
}
|
}
|
||||||
@ -195,7 +196,7 @@ Util::log (std::vector<T>& v)
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::exp (std::vector<T>& v)
|
util::exp (std::vector<T>& v)
|
||||||
{
|
{
|
||||||
std::transform (v.begin(), v.end(), v.begin(), ::exp);
|
std::transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||||
}
|
}
|
||||||
@ -203,7 +204,7 @@ Util::exp (std::vector<T>& v)
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> std::string
|
template <typename T> std::string
|
||||||
Util::elementsToString (const std::vector<T>& v, std::string sep)
|
util::elementsToString (const std::vector<T>& v, std::string sep)
|
||||||
{
|
{
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
for (size_t i = 0; i < v.size(); i++) {
|
for (size_t i = 0; i < v.size(); i++) {
|
||||||
@ -215,7 +216,7 @@ Util::elementsToString (const std::vector<T>& v, std::string sep)
|
|||||||
|
|
||||||
|
|
||||||
template <typename T> std::string
|
template <typename T> std::string
|
||||||
Util::toString (const T& t)
|
util::toString (const T& t)
|
||||||
{
|
{
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << t;
|
ss << t;
|
||||||
@ -225,7 +226,7 @@ Util::toString (const T& t)
|
|||||||
|
|
||||||
|
|
||||||
inline double
|
inline double
|
||||||
Util::logSum (double x, double y)
|
util::logSum (double x, double y)
|
||||||
{
|
{
|
||||||
// std::log (std::exp (x) + std::exp (y)) can overflow!
|
// std::log (std::exp (x) + std::exp (y)) can overflow!
|
||||||
assert (std::isnan (x) == false);
|
assert (std::isnan (x) == false);
|
||||||
@ -258,23 +259,23 @@ Util::logSum (double x, double y)
|
|||||||
|
|
||||||
|
|
||||||
inline unsigned
|
inline unsigned
|
||||||
Util::maxUnsigned (void)
|
util::maxUnsigned (void)
|
||||||
{
|
{
|
||||||
return std::numeric_limits<unsigned>::max();
|
return std::numeric_limits<unsigned>::max();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace LogAware {
|
namespace log_aware {
|
||||||
|
|
||||||
inline double one() { return Globals::logDomain ? 0.0 : 1.0; }
|
inline double one() { return globals::logDomain ? 0.0 : 1.0; }
|
||||||
inline double zero() { return Globals::logDomain ? NEG_INF : 0.0; }
|
inline double zero() { return globals::logDomain ? NEG_INF : 0.0; }
|
||||||
inline double addIdenty() { return Globals::logDomain ? NEG_INF : 0.0; }
|
inline double addIdenty() { return globals::logDomain ? NEG_INF : 0.0; }
|
||||||
inline double multIdenty() { return Globals::logDomain ? 0.0 : 1.0; }
|
inline double multIdenty() { return globals::logDomain ? 0.0 : 1.0; }
|
||||||
inline double withEvidence() { return Globals::logDomain ? 0.0 : 1.0; }
|
inline double withEvidence() { return globals::logDomain ? 0.0 : 1.0; }
|
||||||
inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
|
inline double noEvidence() { return globals::logDomain ? NEG_INF : 0.0; }
|
||||||
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
|
inline double log (double v) { return globals::logDomain ? ::log (v) : v; }
|
||||||
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
|
inline double exp (double v) { return globals::logDomain ? ::exp (v) : v; }
|
||||||
|
|
||||||
void normalize (Params&);
|
void normalize (Params&);
|
||||||
|
|
||||||
@ -290,7 +291,7 @@ void pow (Params&, unsigned);
|
|||||||
|
|
||||||
void pow (Params&, double);
|
void pow (Params&, double);
|
||||||
|
|
||||||
} // namespace LogAware
|
} // namespace log_aware
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -392,7 +393,7 @@ template <typename T>
|
|||||||
std::ostream& operator<< (std::ostream& os, const std::vector<T>& v)
|
std::ostream& operator<< (std::ostream& os, const std::vector<T>& v)
|
||||||
{
|
{
|
||||||
os << "[" ;
|
os << "[" ;
|
||||||
os << Util::elementsToString (v, ", ");
|
os << util::elementsToString (v, ", ");
|
||||||
os << "]" ;
|
os << "]" ;
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ Var::Var (const Var* v)
|
|||||||
varId_ = v->varId();
|
varId_ = v->varId();
|
||||||
range_ = v->range();
|
range_ = v->range();
|
||||||
evidence_ = v->getEvidence();
|
evidence_ = v->getEvidence();
|
||||||
index_ = Util::maxUnsigned();
|
index_ = util::maxUnsigned();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -25,7 +25,7 @@ Var::Var (VarId varId, unsigned range, int evidence)
|
|||||||
varId_ = varId;
|
varId_ = varId;
|
||||||
range_ = range;
|
range_ = range;
|
||||||
evidence_ = evidence;
|
evidence_ = evidence;
|
||||||
index_ = Util::maxUnsigned();
|
index_ = util::maxUnsigned();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ inline void
|
|||||||
Var::addVarInfo (
|
Var::addVarInfo (
|
||||||
VarId vid, std::string label, const States& states)
|
VarId vid, std::string label, const States& states)
|
||||||
{
|
{
|
||||||
assert (Util::contains (varsInfo_, vid) == false);
|
assert (util::contains (varsInfo_, vid) == false);
|
||||||
varsInfo_.insert (std::make_pair (vid, VarInfo (label, states)));
|
varsInfo_.insert (std::make_pair (vid, VarInfo (label, states)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ Var::addVarInfo (
|
|||||||
inline VarInfo
|
inline VarInfo
|
||||||
Var::getVarInfo (VarId vid)
|
Var::getVarInfo (VarId vid)
|
||||||
{
|
{
|
||||||
assert (Util::contains (varsInfo_, vid));
|
assert (util::contains (varsInfo_, vid));
|
||||||
return varsInfo_.find (vid)->second;
|
return varsInfo_.find (vid)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class Var
|
|||||||
public:
|
public:
|
||||||
Var (const Var*);
|
Var (const Var*);
|
||||||
|
|
||||||
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
Var (VarId, unsigned, int = constants::NO_EVIDENCE);
|
||||||
|
|
||||||
virtual ~Var (void) { };
|
virtual ~Var (void) { };
|
||||||
|
|
||||||
@ -79,7 +79,7 @@ class Var
|
|||||||
inline bool
|
inline bool
|
||||||
Var::hasEvidence (void) const
|
Var::hasEvidence (void) const
|
||||||
{
|
{
|
||||||
return evidence_ != Constants::NO_EVIDENCE;
|
return evidence_ != constants::NO_EVIDENCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ namespace horus {
|
|||||||
Params
|
Params
|
||||||
VarElim::solveQuery (VarIds queryVids)
|
VarElim::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << "Solving query on " ;
|
std::cout << "Solving query on " ;
|
||||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
if (i != 0) std::cout << ", " ;
|
if (i != 0) std::cout << ", " ;
|
||||||
@ -28,8 +28,8 @@ VarElim::solveQuery (VarIds queryVids)
|
|||||||
createFactorList();
|
createFactorList();
|
||||||
absorveEvidence();
|
absorveEvidence();
|
||||||
Params params = processFactorList (queryVids);
|
Params params = processFactorList (queryVids);
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
Util::exp (params);
|
util::exp (params);
|
||||||
}
|
}
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
@ -49,7 +49,7 @@ VarElim::printSolverFlags (void) const
|
|||||||
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
|
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
|
||||||
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
||||||
}
|
}
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << util::toString (globals::logDomain);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
std::cout << ss.str() << std::endl;
|
std::cout << ss.str() << std::endl;
|
||||||
}
|
}
|
||||||
@ -81,15 +81,15 @@ VarElim::createFactorList (void)
|
|||||||
void
|
void
|
||||||
VarElim::absorveEvidence (void)
|
VarElim::absorveEvidence (void)
|
||||||
{
|
{
|
||||||
if (Globals::verbosity > 2) {
|
if (globals::verbosity > 2) {
|
||||||
Util::printDashedLine();
|
util::printDashedLine();
|
||||||
std::cout << "(initial factor list)" << std::endl;
|
std::cout << "(initial factor list)" << std::endl;
|
||||||
printActiveFactors();
|
printActiveFactors();
|
||||||
}
|
}
|
||||||
const VarNodes& varNodes = fg.varNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
if (varNodes[i]->hasEvidence()) {
|
if (varNodes[i]->hasEvidence()) {
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << "-> aborving evidence on ";
|
std::cout << "-> aborving evidence on ";
|
||||||
std::cout << varNodes[i]->label() << " = " ;
|
std::cout << varNodes[i]->label() << " = " ;
|
||||||
std::cout << varNodes[i]->getEvidence() << std::endl;
|
std::cout << varNodes[i]->getEvidence() << std::endl;
|
||||||
@ -117,9 +117,9 @@ VarElim::processFactorList (const VarIds& queryVids)
|
|||||||
VarIds elimOrder = ElimGraph::getEliminationOrder (
|
VarIds elimOrder = ElimGraph::getEliminationOrder (
|
||||||
factorList_, queryVids);
|
factorList_, queryVids);
|
||||||
for (size_t i = 0; i < elimOrder.size(); i++) {
|
for (size_t i = 0; i < elimOrder.size(); i++) {
|
||||||
if (Globals::verbosity >= 2) {
|
if (globals::verbosity >= 2) {
|
||||||
if (Globals::verbosity >= 3) {
|
if (globals::verbosity >= 3) {
|
||||||
Util::printDashedLine();
|
util::printDashedLine();
|
||||||
printActiveFactors();
|
printActiveFactors();
|
||||||
}
|
}
|
||||||
std::cout << "-> summing out " ;
|
std::cout << "-> summing out " ;
|
||||||
@ -146,7 +146,7 @@ VarElim::processFactorList (const VarIds& queryVids)
|
|||||||
|
|
||||||
result.reorderArguments (unobservedVids);
|
result.reorderArguments (unobservedVids);
|
||||||
result.normalize();
|
result.normalize();
|
||||||
if (Globals::verbosity > 0) {
|
if (globals::verbosity > 0) {
|
||||||
std::cout << "total factor size: " << totalFactorSize_ << std::endl;
|
std::cout << "total factor size: " << totalFactorSize_ << std::endl;
|
||||||
std::cout << "largest factor size: " << largestFactorSize_ << std::endl;
|
std::cout << "largest factor size: " << largestFactorSize_ << std::endl;
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
|
@ -26,24 +26,24 @@ WeightedBp::getPosterioriOf (VarId vid)
|
|||||||
assert (var);
|
assert (var);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
probs.resize (var->range(), log_aware::noEvidence());
|
||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
probs[var->getEvidence()] = log_aware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->range(), LogAware::multIdenty());
|
probs.resize (var->range(), log_aware::multIdenty());
|
||||||
const BpLinks& links = ninf(var)->getLinks();
|
const BpLinks& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
probs += l->powMessage();
|
probs += l->powMessage();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
log_aware::normalize (probs);
|
||||||
Util::exp (probs);
|
util::exp (probs);
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
probs *= l->powMessage();
|
probs *= l->powMessage();
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
log_aware::normalize (probs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return probs;
|
return probs;
|
||||||
@ -56,7 +56,7 @@ WeightedBp::createLinks (void)
|
|||||||
{
|
{
|
||||||
using std::cout;
|
using std::cout;
|
||||||
using std::endl;
|
using std::endl;
|
||||||
if (Globals::verbosity > 0) {
|
if (globals::verbosity > 0) {
|
||||||
cout << "compressed factor graph contains " ;
|
cout << "compressed factor graph contains " ;
|
||||||
cout << fg.nrVarNodes() << " variables and " ;
|
cout << fg.nrVarNodes() << " variables and " ;
|
||||||
cout << fg.nrFacNodes() << " factors " << endl;
|
cout << fg.nrFacNodes() << " factors " << endl;
|
||||||
@ -66,7 +66,7 @@ WeightedBp::createLinks (void)
|
|||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
for (size_t j = 0; j < neighs.size(); j++) {
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
cout << "creating link " ;
|
cout << "creating link " ;
|
||||||
cout << facNodes[i]->getLabel();
|
cout << facNodes[i]->getLabel();
|
||||||
cout << " -- " ;
|
cout << " -- " ;
|
||||||
@ -77,7 +77,7 @@ WeightedBp::createLinks (void)
|
|||||||
facNodes[i], neighs[j], j, weights_[i][j]));
|
facNodes[i], neighs[j], j, weights_[i][j]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -92,7 +92,7 @@ WeightedBp::maxResidualSchedule (void)
|
|||||||
calculateMessage (links_[i]);
|
calculateMessage (links_[i]);
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
if (Globals::verbosity >= 1) {
|
if (globals::verbosity >= 1) {
|
||||||
std::cout << "calculating " << links_[i]->toString() << std::endl;
|
std::cout << "calculating " << links_[i]->toString() << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -100,7 +100,7 @@ WeightedBp::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (size_t c = 0; c < links_.size(); c++) {
|
for (size_t c = 0; c < links_.size(); c++) {
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << std::endl << "current residuals:" << std::endl;
|
std::cout << std::endl << "current residuals:" << std::endl;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); ++it) {
|
it != sortedOrder_.end(); ++it) {
|
||||||
@ -112,7 +112,7 @@ WeightedBp::maxResidualSchedule (void)
|
|||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
BpLink* link = *it;
|
BpLink* link = *it;
|
||||||
if (Globals::verbosity >= 1) {
|
if (globals::verbosity >= 1) {
|
||||||
std::cout << "updating " << (*sortedOrder_.begin())->toString();
|
std::cout << "updating " << (*sortedOrder_.begin())->toString();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@ -130,7 +130,7 @@ WeightedBp::maxResidualSchedule (void)
|
|||||||
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
for (size_t j = 0; j < links.size(); j++) {
|
for (size_t j = 0; j < links.size(); j++) {
|
||||||
if (links[j]->varNode() != link->varNode()) {
|
if (links[j]->varNode() != link->varNode()) {
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << " calculating " << links[j]->toString();
|
std::cout << " calculating " << links[j]->toString();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@ -146,7 +146,7 @@ WeightedBp::maxResidualSchedule (void)
|
|||||||
const BpLinks& links = ninf(link->facNode())->getLinks();
|
const BpLinks& links = ninf(link->facNode())->getLinks();
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->varNode() != link->varNode()) {
|
if (links[i]->varNode() != link->varNode()) {
|
||||||
if (Globals::verbosity > 1) {
|
if (globals::verbosity > 1) {
|
||||||
std::cout << " calculating " << links[i]->toString();
|
std::cout << " calculating " << links[i]->toString();
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
@ -171,19 +171,19 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
// calculate the product of messages that were sent
|
// calculate the product of messages that were sent
|
||||||
// to factor `src', except from var `dst'
|
// to factor `src', except from var `dst'
|
||||||
unsigned reps = 1;
|
unsigned reps = 1;
|
||||||
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
|
unsigned msgSize = util::sizeExpected (src->factor().ranges());
|
||||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
Params msgProduct (msgSize, log_aware::multIdenty());
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " message from " << links[i]->varNode()->label();
|
std::cout << " message from " << links[i]->varNode()->label();
|
||||||
std::cout << ": " ;
|
std::cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
reps, std::plus<double>());
|
reps, std::plus<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -193,13 +193,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " message from " << links[i]->varNode()->label();
|
std::cout << " message from " << links[i]->varNode()->label();
|
||||||
std::cout << ": " ;
|
std::cout << ": " ;
|
||||||
}
|
}
|
||||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
reps, std::multiplies<double>());
|
reps, std::multiplies<double>());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -209,12 +209,12 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
Factor result (src->factor().arguments(),
|
Factor result (src->factor().arguments(),
|
||||||
src->factor().ranges(), msgProduct);
|
src->factor().ranges(), msgProduct);
|
||||||
assert (msgProduct.size() == src->factor().size());
|
assert (msgProduct.size() == src->factor().size());
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
result.params() += src->factor().params();
|
result.params() += src->factor().params();
|
||||||
} else {
|
} else {
|
||||||
result.params() *= src->factor().params();
|
result.params() *= src->factor().params();
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " message product: " ;
|
std::cout << " message product: " ;
|
||||||
std::cout << msgProduct << std::endl;
|
std::cout << msgProduct << std::endl;
|
||||||
std::cout << " original factor: " ;
|
std::cout << " original factor: " ;
|
||||||
@ -223,13 +223,13 @@ WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
|||||||
std::cout << result.params() << std::endl;
|
std::cout << result.params() << std::endl;
|
||||||
}
|
}
|
||||||
result.sumOutAllExceptIndex (link->index());
|
result.sumOutAllExceptIndex (link->index());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " marginalized: " ;
|
std::cout << " marginalized: " ;
|
||||||
std::cout << result.params() << std::endl;
|
std::cout << result.params() << std::endl;
|
||||||
}
|
}
|
||||||
link->nextMessage() = result.params();
|
link->nextMessage() = result.params();
|
||||||
LogAware::normalize (link->nextMessage());
|
log_aware::normalize (link->nextMessage());
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " curr msg: " ;
|
std::cout << " curr msg: " ;
|
||||||
std::cout << link->message() << std::endl;
|
std::cout << link->message() << std::endl;
|
||||||
std::cout << " next msg: " ;
|
std::cout << " next msg: " ;
|
||||||
@ -247,22 +247,22 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
|||||||
const FacNode* dst = link->facNode();
|
const FacNode* dst = link->facNode();
|
||||||
Params msg;
|
Params msg;
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
msg.resize (src->range(), log_aware::noEvidence());
|
||||||
double value = link->message()[src->getEvidence()];
|
double value = link->message()[src->getEvidence()];
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
msg[src->getEvidence()] = value;
|
msg[src->getEvidence()] = value;
|
||||||
std::cout << msg << "^" << link->weight() << "-1" ;
|
std::cout << msg << "^" << link->weight() << "-1" ;
|
||||||
}
|
}
|
||||||
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
|
msg[src->getEvidence()] = log_aware::pow (value, link->weight() - 1);
|
||||||
} else {
|
} else {
|
||||||
msg = link->message();
|
msg = link->message();
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << msg << "^" << link->weight() << "-1" ;
|
std::cout << msg << "^" << link->weight() << "-1" ;
|
||||||
}
|
}
|
||||||
LogAware::pow (msg, link->weight() - 1);
|
log_aware::pow (msg, link->weight() - 1);
|
||||||
}
|
}
|
||||||
const BpLinks& links = ninf(src)->getLinks();
|
const BpLinks& links = ninf(src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (globals::logDomain) {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||||
@ -274,13 +274,13 @@ WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
|||||||
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||||
msg *= l->powMessage();
|
msg *= l->powMessage();
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " x " << l->nextMessage() << "^" << link->weight();
|
std::cout << " x " << l->nextMessage() << "^" << link->weight();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (constants::SHOW_BP_CALCS) {
|
||||||
std::cout << " = " << msg;
|
std::cout << " = " << msg;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
|
@ -13,7 +13,7 @@ class WeightedLink : public BpLink
|
|||||||
public:
|
public:
|
||||||
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
|
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
|
||||||
: BpLink (fn, vn), index_(idx), weight_(weight),
|
: BpLink (fn, vn), index_(idx), weight_(weight),
|
||||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
pwdMsg_(vn->range(), log_aware::one()) { }
|
||||||
|
|
||||||
size_t index (void) const { return index_; }
|
size_t index (void) const { return index_; }
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ WeightedLink::updateMessage (void)
|
|||||||
{
|
{
|
||||||
pwdMsg_ = *nextMsg_;
|
pwdMsg_ = *nextMsg_;
|
||||||
swap (currMsg_, nextMsg_);
|
swap (currMsg_, nextMsg_);
|
||||||
LogAware::pow (pwdMsg_, weight_);
|
log_aware::pow (pwdMsg_, weight_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user