This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/LiftedBp.cpp

243 lines
5.8 KiB
C++
Raw Normal View History

2013-02-07 20:09:10 +00:00
#include <cassert>
#include <sstream>
#include "LiftedBp.h"
2012-12-27 12:54:58 +00:00
#include "LiftedOperations.h"
#include "WeightedBp.h"
#include "FactorGraph.h"
namespace Horus {
2013-02-07 23:53:13 +00:00
// Builds a lifted BP solver for the given parfactor list.
//
// Construction happens in three stages:
//   1. refineParfactors()  - copy the input into pfList_ and count-normalize it;
//   2. createFactorGraph() - compile pfList_ into fg_, one factor per parfactor;
//   3. attach a WeightedBp solver to fg_, weighted by getWeights().
// NOTE(review): fg_ and solver_ are raw owning pointers released in the
// destructor; the class presumably must not be copied - confirm in LiftedBp.h.
LiftedBp::LiftedBp (const ParfactorList& parfactorList)
    : LiftedSolver (parfactorList)
{
  refineParfactors();    // stage 1: fills pfList_
  createFactorGraph();   // stage 2: fills fg_ from pfList_
  // stage 3: the per-factor weights come from the same pfList_ order
  solver_ = new WeightedBp (*fg_, getWeights());
}
2012-05-30 19:23:41 +01:00
// Releases the objects created in the constructor. The WeightedBp solver was
// constructed over *fg_, so it is destroyed before the factor graph itself.
LiftedBp::~LiftedBp()
{
  delete solver_;   // first: it was built on top of fg_
  delete fg_;       // then the compressed factor graph
}
2012-05-30 19:23:41 +01:00
// Returns the distribution of the query ground atoms.
//
// - Single atom: the marginal of the atom's group in the compressed graph.
// - Several atoms whose groups are pairwise distinct and which all occur in
//   one refined parfactor: the joint is read off that parfactor's factor node.
// - Otherwise: the joint is computed by chained conditioning
//   (getJointByConditioning), which is the general fallback.
//
// @param query  non-empty list of ground atoms (asserted).
// @return       marginal (one atom) or joint (several atoms) parameters.
Params
LiftedBp::solveQuery (const Grounds& query)
{
  assert (query.empty() == false);
  std::vector<PrvGroup> groups = getQueryGroups (query);
  if (query.size() == 1) {
    return solver_->getPosterioriOf (groups[0]);
  }
  // Distinct ground atoms (or a repeated atom) that fall into the same group
  // cannot be told apart inside a single lifted factor: the factor would be
  // asked for a joint over duplicated variable ids, which does not describe
  // the requested atoms. Such queries must go through conditioning instead.
  bool distinctGroups = true;
  for (size_t i = 0; i < groups.size() && distinctGroups; i++) {
    for (size_t j = i + 1; j < groups.size(); j++) {
      if (groups[i] == groups[j]) {
        distinctGroups = false;
        break;
      }
    }
  }
  if (distinctGroups) {
    // createFactorGraph adds one factor per parfactor while iterating pfList_,
    // so the factor node index presumably equals the parfactor position.
    size_t idx = 0;
    ParfactorList::iterator it = pfList_.begin();
    for (; it != pfList_.end(); ++it, ++idx) {
      if ((*it)->containsGrounds (query)) {
        VarIds queryVids (groups.begin(), groups.end());
        return solver_->getFactorJoint (fg_->facNodes()[idx], queryVids);
      }
    }
  }
  return getJointByConditioning (pfList_, query);
}
void
LiftedBp::printSolverFlags() const
2012-05-30 19:23:41 +01:00
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
2012-05-30 19:23:41 +01:00
ss << "lifted bp [" ;
2013-01-08 17:06:40 +00:00
ss << "bp_msg_schedule=" ;
typedef WeightedBp::MsgSchedule MsgSchedule;
switch (WeightedBp::msgSchedule()) {
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
2012-06-01 19:31:07 +01:00
}
2013-01-08 17:06:40 +00:00
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain);
2012-05-30 19:23:41 +01:00
ss << "]" ;
2013-02-07 13:37:15 +00:00
std::cout << ss.str() << std::endl;
2012-05-30 19:23:41 +01:00
}
// Copies the solver's input parfactors into pfList_ and applies iterate()
// until it reports that no parfactor needed count-normalization. With
// verbosity above 2 the refined list is printed.
void
LiftedBp::refineParfactors()
{
  pfList_ = parfactorList;
  bool fixedPoint = false;
  while (!fixedPoint) {
    fixedPoint = iterate();
  }
  if (Globals::verbosity > 2) {
    Util::printHeader ("AFTER REFINEMENT");
    pfList_.print();
  }
}
bool
LiftedBp::iterate()
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& args = (*it)->arguments();
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = LiftedOperations::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it);
pfList_.add (pfs);
return false;
}
}
++ it;
}
return true;
}
2013-02-07 13:37:15 +00:00
std::vector<PrvGroup>
LiftedBp::getQueryGroups (const Grounds& query)
{
2013-02-07 13:37:15 +00:00
std::vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
if ((*it)->containsGround (query[i])) {
queryGroups.push_back ((*it)->findGroup (query[i]));
break;
}
}
}
assert (queryGroups.size() == query.size());
return queryGroups;
}
2012-09-11 18:40:41 +01:00
void
LiftedBp::createFactorGraph()
{
2012-09-11 18:40:41 +01:00
fg_ = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
2013-02-07 13:37:15 +00:00
std::vector<PrvGroup> groups = (*it)->getAllGroups();
VarIds varIds;
for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]);
}
2012-09-11 18:40:41 +01:00
fg_->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
}
}
2013-02-07 13:37:15 +00:00
std::vector<std::vector<unsigned>>
LiftedBp::getWeights() const
{
2013-02-07 13:37:15 +00:00
std::vector<std::vector<unsigned>> weights;
weights.reserve (pfList_.size());
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
const ProbFormulas& args = (*it)->arguments();
weights.push_back ({ });
weights.back().reserve (args.size());
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
}
}
return weights;
}
2012-06-14 11:57:00 +01:00
unsigned
LiftedBp::rangeOfGround (const Ground& gr)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (gr)) {
PrvGroup prvGroup = (*it)->findGroup (gr);
return (*it)->range ((*it)->indexOfGroup (prvGroup));
}
++ it;
}
return Util::maxUnsigned();
2012-06-14 11:57:00 +01:00
}
2012-12-17 18:39:42 +00:00
2012-06-14 11:57:00 +01:00
// Computes the joint distribution of the query atoms with the chain rule:
//   P(q0, ..., qn) = P(q0) * prod_{i>0} P(qi | q0, ..., q(i-1)).
//
// P(q0) comes from a fresh LiftedBp over pfList. For each later atom qi, every
// assignment of the atoms already in the joint (enumerated by an Indexer over
// their ranges) is absorbed as evidence into a copy of pfList. A new LiftedBp
// is then solved for qi. The conditionals are concatenated, and block k is
// scaled by entry k of the joint so far. Within the result, the newest atom
// varies fastest (assumes Indexer enumerates with its last index fastest).
//
// @param pfList  parfactors to solve over (copied per evidence assignment).
// @param query   ground atoms, presumably at least one (query[0] is read).
// @return        joint parameters over the query atoms.
Params
LiftedBp::getJointByConditioning (
    const ParfactorList& pfList,
    const Grounds& query)
{
  LiftedBp marginalSolver (pfList);
  Params joint = marginalSolver.solveQuery ({query[0]});
  Grounds conditioned = {query[0]};
  for (size_t i = 1; i < query.size(); i++) {
    // Evidence formulas (value set per assignment below) and their ranges.
    std::vector<ObservedFormula> evidence;
    Ranges evidenceRanges;
    for (size_t j = 0; j < conditioned.size(); j++) {
      evidence.push_back (ObservedFormula (
          conditioned[j].functor(), 0, conditioned[j].args()));
      evidenceRanges.push_back (rangeOfGround (conditioned[j]));
    }
    // P(query[i] | assignment) for each assignment, concatenated in order.
    Params conditionals;
    for (Indexer indexer (evidenceRanges, false); indexer.valid(); ++ indexer) {
      for (size_t j = 0; j < evidence.size(); j++) {
        evidence[j].setEvidence (indexer[j]);
      }
      ParfactorList withEvidence (pfList);
      LiftedOperations::absorveEvidence (withEvidence, evidence);
      LiftedBp conditionalSolver (withEvidence);
      Params dist = conditionalSolver.solveQuery ({query[i]});
      conditionals.insert (conditionals.end(), dist.begin(), dist.end());
    }
    // Block j / range holds the conditional for assignment number j / range;
    // weight it by that assignment's probability in the joint so far.
    unsigned range = rangeOfGround (query[i]);
    for (size_t j = 0; j < conditionals.size(); j++) {
      conditionals[j] *= joint[j / range];
    }
    joint.swap (conditionals);
    conditioned.push_back (query[i]);
  }
  return joint;
}
} // namespace Horus
2013-02-07 23:53:13 +00:00