Place the constructor on the right place
This commit is contained in:
parent
57339760b9
commit
421d6f72ee
@ -32,6 +32,18 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph::~FactorGraph (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||||
|
delete varNodes_[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
delete facNodes_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||||
{
|
{
|
||||||
@ -167,18 +179,6 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph::~FactorGraph (void)
|
|
||||||
{
|
|
||||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
|
||||||
delete varNodes_[i];
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
|
||||||
delete facNodes_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addFactor (const Factor& factor)
|
FactorGraph::addFactor (const Factor& factor)
|
||||||
{
|
{
|
||||||
|
84
packages/CLPBN/horus2/BayesBall.cpp
Normal file
84
packages/CLPBN/horus2/BayesBall.cpp
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "BayesBall.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||||
|
{
|
||||||
|
assert (fg_.bayesianFactors());
|
||||||
|
Scheduling scheduling;
|
||||||
|
for (size_t i = 0; i < queryIds.size(); i++) {
|
||||||
|
assert (dag_.getNode (queryIds[i]));
|
||||||
|
BBNode* n = dag_.getNode (queryIds[i]);
|
||||||
|
scheduling.push (ScheduleInfo (n, false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!scheduling.empty()) {
|
||||||
|
ScheduleInfo& sch = scheduling.front();
|
||||||
|
BBNode* n = sch.node;
|
||||||
|
n->setAsVisited();
|
||||||
|
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||||
|
if (n->isMarkedOnTop() == false) {
|
||||||
|
n->markOnTop();
|
||||||
|
scheduleParents (n, scheduling);
|
||||||
|
}
|
||||||
|
if (n->isMarkedOnBottom() == false) {
|
||||||
|
n->markOnBottom();
|
||||||
|
scheduleChilds (n, scheduling);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (sch.visitedFromParent) {
|
||||||
|
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
|
||||||
|
n->markOnTop();
|
||||||
|
scheduleParents (n, scheduling);
|
||||||
|
}
|
||||||
|
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
|
||||||
|
n->markOnBottom();
|
||||||
|
scheduleChilds (n, scheduling);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scheduling.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
constructGraph (fg);
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BayesBall::constructGraph (FactorGraph* fg) const
|
||||||
|
{
|
||||||
|
const FacNodes& facNodes = fg_.facNodes();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
const BBNode* n = dag_.getNode (
|
||||||
|
facNodes[i]->factor().argument (0));
|
||||||
|
if (n->isMarkedOnTop()) {
|
||||||
|
fg->addFactor (facNodes[i]->factor());
|
||||||
|
} else if (n->hasEvidence() && n->isVisited()) {
|
||||||
|
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||||
|
Ranges ranges = { facNodes[i]->factor().range (0) };
|
||||||
|
Params params (ranges[0], LogAware::noEvidence());
|
||||||
|
params[n->getEvidence()] = LogAware::withEvidence();
|
||||||
|
fg->addFactor (Factor (varIds, ranges, params));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const VarNodes& varNodes = fg_.varNodes();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
if (varNodes[i]->hasEvidence()) {
|
||||||
|
VarNode* vn = fg->getVarNode (varNodes[i]->varId());
|
||||||
|
if (vn) {
|
||||||
|
vn->setEvidence (varNodes[i]->getEvidence());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
85
packages/CLPBN/horus2/BayesBall.h
Normal file
85
packages/CLPBN/horus2/BayesBall.h
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
#ifndef HORUS_BAYESBALL_H
|
||||||
|
#define HORUS_BAYESBALL_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <queue>
|
||||||
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "BayesBallGraph.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
struct ScheduleInfo
|
||||||
|
{
|
||||||
|
ScheduleInfo (BBNode* n, bool vfp, bool vfc) :
|
||||||
|
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||||
|
|
||||||
|
BBNode* node;
|
||||||
|
bool visitedFromParent;
|
||||||
|
bool visitedFromChild;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
|
||||||
|
|
||||||
|
|
||||||
|
class BayesBall
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BayesBall (FactorGraph& fg)
|
||||||
|
: fg_(fg) , dag_(fg.getStructure())
|
||||||
|
{
|
||||||
|
dag_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||||
|
|
||||||
|
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||||
|
{
|
||||||
|
BayesBall bb (fg);
|
||||||
|
return bb.getMinimalFactorGraph (vids);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
void constructGraph (FactorGraph* fg) const;
|
||||||
|
|
||||||
|
void scheduleParents (const BBNode* n, Scheduling& sch) const;
|
||||||
|
|
||||||
|
void scheduleChilds (const BBNode* n, Scheduling& sch) const;
|
||||||
|
|
||||||
|
FactorGraph& fg_;
|
||||||
|
|
||||||
|
BayesBallGraph& dag_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline void
|
||||||
|
BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
|
||||||
|
{
|
||||||
|
const vector<BBNode*>& ps = n->parents();
|
||||||
|
for (vector<BBNode*>::const_iterator it = ps.begin();
|
||||||
|
it != ps.end(); ++it) {
|
||||||
|
sch.push (ScheduleInfo (*it, false, true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline void
|
||||||
|
BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const
|
||||||
|
{
|
||||||
|
const vector<BBNode*>& cs = n->childs();
|
||||||
|
for (vector<BBNode*>::const_iterator it = cs.begin();
|
||||||
|
it != cs.end(); ++it) {
|
||||||
|
sch.push (ScheduleInfo (*it, true, false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // HORUS_BAYESBALL_H
|
||||||
|
|
106
packages/CLPBN/horus2/BayesBallGraph.cpp
Normal file
106
packages/CLPBN/horus2/BayesBallGraph.cpp
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "BayesBallGraph.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BayesBallGraph::addNode (BBNode* n)
|
||||||
|
{
|
||||||
|
assert (Util::contains (varMap_, n->varId()) == false);
|
||||||
|
nodes_.push_back (n);
|
||||||
|
varMap_[n->varId()] = n;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BayesBallGraph::addEdge (VarId vid1, VarId vid2)
|
||||||
|
{
|
||||||
|
unordered_map<VarId, BBNode*>::iterator it1;
|
||||||
|
unordered_map<VarId, BBNode*>::iterator it2;
|
||||||
|
it1 = varMap_.find (vid1);
|
||||||
|
it2 = varMap_.find (vid2);
|
||||||
|
assert (it1 != varMap_.end());
|
||||||
|
assert (it2 != varMap_.end());
|
||||||
|
it1->second->addChild (it2->second);
|
||||||
|
it2->second->addParent (it1->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
const BBNode*
|
||||||
|
BayesBallGraph::getNode (VarId vid) const
|
||||||
|
{
|
||||||
|
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||||
|
it = varMap_.find (vid);
|
||||||
|
return it != varMap_.end() ? it->second : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
BBNode*
|
||||||
|
BayesBallGraph::getNode (VarId vid)
|
||||||
|
{
|
||||||
|
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||||
|
it = varMap_.find (vid);
|
||||||
|
return it != varMap_.end() ? it->second : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BayesBallGraph::setIndexes (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
|
nodes_[i]->setIndex (i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BayesBallGraph::clear (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
|
nodes_[i]->clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BayesBallGraph::exportToGraphViz (const char* fileName)
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
out << "digraph {" << endl;
|
||||||
|
out << "ranksep=1" << endl;
|
||||||
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
|
out << nodes_[i]->varId() ;
|
||||||
|
out << " [" ;
|
||||||
|
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||||
|
if (nodes_[i]->hasEvidence()) {
|
||||||
|
out << ",style=filled, fillcolor=yellow" ;
|
||||||
|
}
|
||||||
|
out << "]" << endl;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
|
const vector<BBNode*>& childs = nodes_[i]->childs();
|
||||||
|
for (size_t j = 0; j < childs.size(); j++) {
|
||||||
|
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||||
|
out << " [style=bold]" << endl ;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out << "}" << endl;
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
84
packages/CLPBN/horus2/BayesBallGraph.h
Normal file
84
packages/CLPBN/horus2/BayesBallGraph.h
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
#ifndef HORUS_BAYESBALLGRAPH_H
|
||||||
|
#define HORUS_BAYESBALLGRAPH_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <queue>
|
||||||
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "Var.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
class BBNode : public Var
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BBNode (Var* v) : Var (v) , visited_(false),
|
||||||
|
markedOnTop_(false), markedOnBottom_(false) { }
|
||||||
|
|
||||||
|
const vector<BBNode*>& childs (void) const { return childs_; }
|
||||||
|
|
||||||
|
vector<BBNode*>& childs (void) { return childs_; }
|
||||||
|
|
||||||
|
const vector<BBNode*>& parents (void) const { return parents_; }
|
||||||
|
|
||||||
|
vector<BBNode*>& parents (void) { return parents_; }
|
||||||
|
|
||||||
|
void addParent (BBNode* p) { parents_.push_back (p); }
|
||||||
|
|
||||||
|
void addChild (BBNode* c) { childs_.push_back (c); }
|
||||||
|
|
||||||
|
bool isVisited (void) const { return visited_; }
|
||||||
|
|
||||||
|
void setAsVisited (void) { visited_ = true; }
|
||||||
|
|
||||||
|
bool isMarkedOnTop (void) const { return markedOnTop_; }
|
||||||
|
|
||||||
|
void markOnTop (void) { markedOnTop_ = true; }
|
||||||
|
|
||||||
|
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
|
||||||
|
|
||||||
|
void markOnBottom (void) { markedOnBottom_ = true; }
|
||||||
|
|
||||||
|
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool visited_;
|
||||||
|
bool markedOnTop_;
|
||||||
|
bool markedOnBottom_;
|
||||||
|
|
||||||
|
vector<BBNode*> childs_;
|
||||||
|
vector<BBNode*> parents_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class BayesBallGraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BayesBallGraph (void) { }
|
||||||
|
|
||||||
|
void addNode (BBNode* n);
|
||||||
|
|
||||||
|
void addEdge (VarId vid1, VarId vid2);
|
||||||
|
|
||||||
|
const BBNode* getNode (VarId vid) const;
|
||||||
|
|
||||||
|
BBNode* getNode (VarId vid);
|
||||||
|
|
||||||
|
bool empty (void) const { return nodes_.empty(); }
|
||||||
|
|
||||||
|
void setIndexes (void);
|
||||||
|
|
||||||
|
void clear (void);
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*);
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<BBNode*> nodes_;
|
||||||
|
|
||||||
|
unordered_map<VarId, BBNode*> varMap_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_BAYESBALLGRAPH_H
|
||||||
|
|
471
packages/CLPBN/horus2/BeliefProp.cpp
Normal file
471
packages/CLPBN/horus2/BeliefProp.cpp
Normal file
@ -0,0 +1,471 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "BeliefProp.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "Indexer.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
||||||
|
{
|
||||||
|
runned_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
BeliefProp::~BeliefProp (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < varsI_.size(); i++) {
|
||||||
|
delete varsI_[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facsI_.size(); i++) {
|
||||||
|
delete facsI_[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BeliefProp::solveQuery (VarIds queryVids)
|
||||||
|
{
|
||||||
|
assert (queryVids.empty() == false);
|
||||||
|
return queryVids.size() == 1
|
||||||
|
? getPosterioriOf (queryVids[0])
|
||||||
|
: getJointDistributionOf (queryVids);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "belief propagation [" ;
|
||||||
|
ss << "schedule=" ;
|
||||||
|
typedef BpOptions::Schedule Sch;
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
|
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
|
case Sch::PARALLEL: ss << "parallel"; break;
|
||||||
|
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
|
}
|
||||||
|
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
|
||||||
|
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
|
||||||
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BeliefProp::getPosterioriOf (VarId vid)
|
||||||
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
|
assert (fg.getVarNode (vid));
|
||||||
|
VarNode* var = fg.getVarNode (vid);
|
||||||
|
Params probs;
|
||||||
|
if (var->hasEvidence()) {
|
||||||
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
|
} else {
|
||||||
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
|
const BpLinks& links = ninf(var)->getLinks();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
probs += links[i]->message();
|
||||||
|
}
|
||||||
|
LogAware::normalize (probs);
|
||||||
|
Util::exp (probs);
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
probs *= links[i]->message();
|
||||||
|
}
|
||||||
|
LogAware::normalize (probs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
|
||||||
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
|
VarNode* vn = fg.getVarNode (jointVarIds[0]);
|
||||||
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
|
size_t idx = facNodes.size();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
if (facNodes[i]->factor().contains (jointVarIds)) {
|
||||||
|
idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (idx == facNodes.size()) {
|
||||||
|
return getJointByConditioning (jointVarIds);
|
||||||
|
}
|
||||||
|
return getFactorJoint (facNodes[idx], jointVarIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BeliefProp::getFactorJoint (
|
||||||
|
FacNode* fn,
|
||||||
|
const VarIds& jointVarIds)
|
||||||
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
|
Factor res (fn->factor());
|
||||||
|
const BpLinks& links = ninf(fn)->getLinks();
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
Factor msg ({links[i]->varNode()->varId()},
|
||||||
|
{links[i]->varNode()->range()},
|
||||||
|
getVarToFactorMsg (links[i]));
|
||||||
|
res.multiply (msg);
|
||||||
|
}
|
||||||
|
res.sumOutAllExcept (jointVarIds);
|
||||||
|
res.reorderArguments (jointVarIds);
|
||||||
|
res.normalize();
|
||||||
|
Params jointDist = res.params();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::exp (jointDist);
|
||||||
|
}
|
||||||
|
return jointDist;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::runSolver (void)
|
||||||
|
{
|
||||||
|
initializeSolver();
|
||||||
|
nIters_ = 0;
|
||||||
|
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||||
|
nIters_ ++;
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||||
|
}
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
|
std::random_shuffle (links_.begin(), links_.end());
|
||||||
|
// no break
|
||||||
|
case BpOptions::Schedule::SEQ_FIXED:
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
calculateAndUpdateMessage (links_[i]);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::PARALLEL:
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
calculateMessage (links_[i]);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
updateMessage(links_[i]);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||||
|
maxResidualSchedule();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
|
cout << "Belief propagation converged in " ;
|
||||||
|
cout << nIters_ << " iterations" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
runned_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::createLinks (void)
|
||||||
|
{
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < neighbors.size(); j++) {
|
||||||
|
links_.push_back (new BpLink (facNodes[i], neighbors[j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::maxResidualSchedule (void)
|
||||||
|
{
|
||||||
|
if (nIters_ == 1) {
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
calculateMessage (links_[i]);
|
||||||
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t c = 0; c < links_.size(); c++) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "current residuals:" << endl;
|
||||||
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
it != sortedOrder_.end(); ++it) {
|
||||||
|
cout << " " << setw (30) << left << (*it)->toString();
|
||||||
|
cout << "residual = " << (*it)->residual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
BpLink* link = *it;
|
||||||
|
if (link->residual() < BpOptions::accuracy) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
updateMessage (link);
|
||||||
|
link->clearResidual();
|
||||||
|
sortedOrder_.erase (it);
|
||||||
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
|
// update the messages that depend on message source --> destin
|
||||||
|
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||||
|
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||||
|
if (factorNeighbors[i] != link->facNode()) {
|
||||||
|
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
|
for (size_t j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->varNode() != link->varNode()) {
|
||||||
|
calculateMessage (links[j]);
|
||||||
|
BpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
Util::printDashedLine();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||||
|
{
|
||||||
|
FacNode* src = link->facNode();
|
||||||
|
const VarNode* dst = link->varNode();
|
||||||
|
const BpLinks& links = ninf(src)->getLinks();
|
||||||
|
// calculate the product of messages that were sent
|
||||||
|
// to factor `src', except from var `dst'
|
||||||
|
unsigned reps = 1;
|
||||||
|
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
|
||||||
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
|
if (links[i]->varNode() != dst) {
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::plus<double>());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
|
if (links[i]->varNode() != dst) {
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::multiplies<double>());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Factor result (src->factor().arguments(),
|
||||||
|
src->factor().ranges(), msgProduct);
|
||||||
|
result.multiply (src->factor());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message product: " << msgProduct << endl;
|
||||||
|
cout << " original factor: " << src->factor().params() << endl;
|
||||||
|
cout << " factor product: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
result.sumOutAllExcept (dst->varId());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " marginalized: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
link->nextMessage() = result.params();
|
||||||
|
LogAware::normalize (link->nextMessage());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " curr msg: " << link->message() << endl;
|
||||||
|
cout << " next msg: " << link->nextMessage() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||||
|
{
|
||||||
|
const VarNode* src = link->varNode();
|
||||||
|
Params msg;
|
||||||
|
if (src->hasEvidence()) {
|
||||||
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
|
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||||
|
} else {
|
||||||
|
msg.resize (src->range(), LogAware::one());
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << msg;
|
||||||
|
}
|
||||||
|
BpLinks::const_iterator it;
|
||||||
|
const BpLinks& links = ninf (src)->getLinks();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
|
if (*it != link) {
|
||||||
|
msg += (*it)->message();
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " x " << (*it)->message();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
|
if (*it != link) {
|
||||||
|
msg *= (*it)->message();
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " x " << (*it)->message();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " = " << msg;
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
|
{
|
||||||
|
return GroundSolver::getJointByConditioning (
|
||||||
|
GroundSolverType::BP, fg, jointVarIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::initializeSolver (void)
|
||||||
|
{
|
||||||
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
|
varsI_.reserve (varNodes.size());
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
varsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
facsI_.reserve (facNodes.size());
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
facsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
createLinks();
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
FacNode* src = links_[i]->facNode();
|
||||||
|
VarNode* dst = links_[i]->varNode();
|
||||||
|
ninf (dst)->addBpLink (links_[i]);
|
||||||
|
ninf (src)->addBpLink (links_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BeliefProp::converged (void)
|
||||||
|
{
|
||||||
|
if (links_.size() == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (nIters_ == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
if (nIters_ == 1) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "no residuals" << endl << endl;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool converged = true;
|
||||||
|
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||||
|
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||||
|
if (maxResidual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
} else {
|
||||||
|
converged = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
double residual = links_[i]->residual();
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
|
}
|
||||||
|
if (residual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
if (Globals::verbosity < 2) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return converged;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BeliefProp::printLinkInformation (void) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
BpLink* l = links_[i];
|
||||||
|
cout << l->toString() << ":" << endl;
|
||||||
|
cout << " curr msg = " ;
|
||||||
|
cout << l->message() << endl;
|
||||||
|
cout << " next msg = " ;
|
||||||
|
cout << l->nextMessage() << endl;
|
||||||
|
cout << " residual = " << l->residual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
188
packages/CLPBN/horus2/BeliefProp.h
Normal file
188
packages/CLPBN/horus2/BeliefProp.h
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
#ifndef HORUS_BELIEFPROP_H
|
||||||
|
#define HORUS_BELIEFPROP_H
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "GroundSolver.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
class BpLink
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BpLink (FacNode* fn, VarNode* vn)
|
||||||
|
{
|
||||||
|
fac_ = fn;
|
||||||
|
var_ = vn;
|
||||||
|
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||||
|
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||||
|
currMsg_ = &v1_;
|
||||||
|
nextMsg_ = &v2_;
|
||||||
|
residual_ = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~BpLink (void) { };
|
||||||
|
|
||||||
|
FacNode* facNode (void) const { return fac_; }
|
||||||
|
|
||||||
|
VarNode* varNode (void) const { return var_; }
|
||||||
|
|
||||||
|
const Params& message (void) const { return *currMsg_; }
|
||||||
|
|
||||||
|
Params& nextMessage (void) { return *nextMsg_; }
|
||||||
|
|
||||||
|
double residual (void) const { return residual_; }
|
||||||
|
|
||||||
|
void clearResidual (void) { residual_ = 0.0; }
|
||||||
|
|
||||||
|
void updateResidual (void)
|
||||||
|
{
|
||||||
|
residual_ = LogAware::getMaxNorm (v1_,v2_);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void updateMessage (void)
|
||||||
|
{
|
||||||
|
swap (currMsg_, nextMsg_);
|
||||||
|
}
|
||||||
|
|
||||||
|
string toString (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << fac_->getLabel();
|
||||||
|
ss << " -- " ;
|
||||||
|
ss << var_->label();
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
FacNode* fac_;
|
||||||
|
VarNode* var_;
|
||||||
|
Params v1_;
|
||||||
|
Params v2_;
|
||||||
|
Params* currMsg_;
|
||||||
|
Params* nextMsg_;
|
||||||
|
double residual_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<BpLink*> BpLinks;
|
||||||
|
|
||||||
|
|
||||||
|
class SPNodeInfo
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
void addBpLink (BpLink* link) { links_.push_back (link); }
|
||||||
|
const BpLinks& getLinks (void) { return links_; }
|
||||||
|
private:
|
||||||
|
BpLinks links_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class BeliefProp : public GroundSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BeliefProp (const FactorGraph&);
|
||||||
|
|
||||||
|
virtual ~BeliefProp (void);
|
||||||
|
|
||||||
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
|
virtual void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
virtual Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
|
virtual Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void runSolver (void);
|
||||||
|
|
||||||
|
virtual void createLinks (void);
|
||||||
|
|
||||||
|
virtual void maxResidualSchedule (void);
|
||||||
|
|
||||||
|
virtual void calcFactorToVarMsg (BpLink*);
|
||||||
|
|
||||||
|
virtual Params getVarToFactorMsg (const BpLink*) const;
|
||||||
|
|
||||||
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Params getFactorJoint (FacNode* fn, const VarIds&);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
SPNodeInfo* ninf (const VarNode* var) const
|
||||||
|
{
|
||||||
|
return varsI_[var->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
|
SPNodeInfo* ninf (const FacNode* fac) const
|
||||||
|
{
|
||||||
|
return facsI_[fac->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
|
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
||||||
|
{
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
cout << "calculating & updating " << link->toString() << endl;
|
||||||
|
}
|
||||||
|
calcFactorToVarMsg (link);
|
||||||
|
if (calcResidual) {
|
||||||
|
link->updateResidual();
|
||||||
|
}
|
||||||
|
link->updateMessage();
|
||||||
|
}
|
||||||
|
|
||||||
|
void calculateMessage (BpLink* link, bool calcResidual = true)
|
||||||
|
{
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
cout << "calculating " << link->toString() << endl;
|
||||||
|
}
|
||||||
|
calcFactorToVarMsg (link);
|
||||||
|
if (calcResidual) {
|
||||||
|
link->updateResidual();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void updateMessage (BpLink* link)
|
||||||
|
{
|
||||||
|
link->updateMessage();
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
cout << "updating " << link->toString() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CompareResidual
|
||||||
|
{
|
||||||
|
inline bool operator() (const BpLink* link1, const BpLink* link2)
|
||||||
|
{
|
||||||
|
return link1->residual() > link2->residual();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
BpLinks links_;
|
||||||
|
unsigned nIters_;
|
||||||
|
vector<SPNodeInfo*> varsI_;
|
||||||
|
vector<SPNodeInfo*> facsI_;
|
||||||
|
bool runned_;
|
||||||
|
|
||||||
|
typedef multiset<BpLink*, CompareResidual> SortedOrder;
|
||||||
|
SortedOrder sortedOrder_;
|
||||||
|
|
||||||
|
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||||
|
BpLinkMap linkMap_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void initializeSolver (void);
|
||||||
|
|
||||||
|
bool converged (void);
|
||||||
|
|
||||||
|
virtual void printLinkInformation (void) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_BELIEFPROP_H
|
||||||
|
|
1174
packages/CLPBN/horus2/ConstraintTree.cpp
Normal file
1174
packages/CLPBN/horus2/ConstraintTree.cpp
Normal file
File diff suppressed because it is too large
Load Diff
237
packages/CLPBN/horus2/ConstraintTree.h
Normal file
237
packages/CLPBN/horus2/ConstraintTree.h
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
#ifndef HORUS_CONSTRAINTTREE_H
|
||||||
|
#define HORUS_CONSTRAINTTREE_H
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "TinySet.h"
|
||||||
|
#include "LiftedUtils.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
class CTNode;
|
||||||
|
typedef vector<CTNode*> CTNodes;
|
||||||
|
|
||||||
|
class ConstraintTree;
|
||||||
|
typedef vector<ConstraintTree*> ConstraintTrees;
|
||||||
|
|
||||||
|
|
||||||
|
class CTNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
|
||||||
|
struct CompareSymbol
|
||||||
|
{
|
||||||
|
bool operator() (const CTNode* n1, const CTNode* n2) const
|
||||||
|
{
|
||||||
|
return n1->symbol() < n2->symbol();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
typedef TinySet<CTNode*, CompareSymbol> CTChilds_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
CTNode (const CTNode& n, const CTChilds_& chs = CTChilds_())
|
||||||
|
: symbol_(n.symbol()), childs_(chs), level_(n.level()) { }
|
||||||
|
|
||||||
|
CTNode (Symbol s, unsigned l, const CTChilds_& chs = CTChilds_())
|
||||||
|
: symbol_(s), childs_(chs), level_(l) { }
|
||||||
|
|
||||||
|
unsigned level (void) const { return level_; }
|
||||||
|
|
||||||
|
void setLevel (unsigned level) { level_ = level; }
|
||||||
|
|
||||||
|
Symbol symbol (void) const { return symbol_; }
|
||||||
|
|
||||||
|
void setSymbol (const Symbol s) { symbol_ = s; }
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
CTChilds_& childs (void) { return childs_; }
|
||||||
|
|
||||||
|
const CTChilds_& childs (void) const { return childs_; }
|
||||||
|
|
||||||
|
size_t nrChilds (void) const { return childs_.size(); }
|
||||||
|
|
||||||
|
bool isRoot (void) const { return level_ == 0; }
|
||||||
|
|
||||||
|
bool isLeaf (void) const { return childs_.empty(); }
|
||||||
|
|
||||||
|
CTChilds_::iterator findSymbol (Symbol symb)
|
||||||
|
{
|
||||||
|
CTNode tmp (symb, 0);
|
||||||
|
return childs_.find (&tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeSubtree (CTNode*, bool = true);
|
||||||
|
|
||||||
|
void removeChild (CTNode*);
|
||||||
|
|
||||||
|
void removeChilds (void);
|
||||||
|
|
||||||
|
void removeAndDeleteChild (CTNode*);
|
||||||
|
|
||||||
|
void removeAndDeleteAllChilds (void);
|
||||||
|
|
||||||
|
SymbolSet childSymbols (void) const;
|
||||||
|
|
||||||
|
static CTNode* copySubtree (const CTNode*);
|
||||||
|
|
||||||
|
static void deleteSubtree (CTNode*);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void updateChildLevels (CTNode*, unsigned);
|
||||||
|
|
||||||
|
Symbol symbol_;
|
||||||
|
CTChilds_ childs_;
|
||||||
|
unsigned level_;
|
||||||
|
};
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &out, const CTNode&);
|
||||||
|
|
||||||
|
|
||||||
|
typedef TinySet<CTNode*, CTNode::CompareSymbol> CTChilds;
|
||||||
|
|
||||||
|
|
||||||
|
class ConstraintTree
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ConstraintTree (unsigned);
|
||||||
|
|
||||||
|
ConstraintTree (const LogVars&);
|
||||||
|
|
||||||
|
ConstraintTree (const LogVars&, const Tuples&);
|
||||||
|
|
||||||
|
ConstraintTree (vector<vector<string>> names);
|
||||||
|
|
||||||
|
ConstraintTree (const ConstraintTree&);
|
||||||
|
|
||||||
|
ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars)
|
||||||
|
: root_(new CTNode (0, 0, rootChilds)),
|
||||||
|
logVars_(logVars),
|
||||||
|
logVarSet_(logVars) { }
|
||||||
|
|
||||||
|
~ConstraintTree (void);
|
||||||
|
|
||||||
|
CTNode* root (void) const { return root_; }
|
||||||
|
|
||||||
|
bool empty (void) const { return root_->childs().empty(); }
|
||||||
|
|
||||||
|
const LogVars& logVars (void) const
|
||||||
|
{
|
||||||
|
assert (LogVarSet (logVars_) == logVarSet_);
|
||||||
|
return logVars_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const LogVarSet& logVarSet (void) const
|
||||||
|
{
|
||||||
|
assert (LogVarSet (logVars_) == logVarSet_);
|
||||||
|
return logVarSet_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t nrLogVars (void) const
|
||||||
|
{
|
||||||
|
return logVars_.size();
|
||||||
|
assert (LogVarSet (logVars_) == logVarSet_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void addTuple (const Tuple&);
|
||||||
|
|
||||||
|
bool containsTuple (const Tuple&);
|
||||||
|
|
||||||
|
void moveToTop (const LogVars&);
|
||||||
|
|
||||||
|
void moveToBottom (const LogVars&);
|
||||||
|
|
||||||
|
void join (ConstraintTree*, bool oneTwoOne = false);
|
||||||
|
|
||||||
|
unsigned getLevel (LogVar) const;
|
||||||
|
|
||||||
|
void rename (LogVar, LogVar);
|
||||||
|
|
||||||
|
void applySubstitution (const Substitution&);
|
||||||
|
|
||||||
|
void project (const LogVarSet&);
|
||||||
|
|
||||||
|
ConstraintTree projectedCopy (const LogVarSet&);
|
||||||
|
|
||||||
|
void remove (const LogVarSet&);
|
||||||
|
|
||||||
|
bool isSingleton (LogVar);
|
||||||
|
|
||||||
|
LogVarSet singletons (void);
|
||||||
|
|
||||||
|
TupleSet tupleSet (unsigned = 0) const;
|
||||||
|
|
||||||
|
TupleSet tupleSet (const LogVars&);
|
||||||
|
|
||||||
|
unsigned size (void) const;
|
||||||
|
|
||||||
|
unsigned nrSymbols (LogVar);
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*, bool = false) const;
|
||||||
|
|
||||||
|
bool isCountNormalized (const LogVarSet&);
|
||||||
|
|
||||||
|
unsigned getConditionalCount (const LogVarSet&);
|
||||||
|
|
||||||
|
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
||||||
|
|
||||||
|
bool isCartesianProduct (const LogVarSet&);
|
||||||
|
|
||||||
|
std::pair<ConstraintTree*, ConstraintTree*> split (const Tuple&);
|
||||||
|
|
||||||
|
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||||
|
const LogVars&, ConstraintTree*, const LogVars&);
|
||||||
|
|
||||||
|
ConstraintTrees countNormalize (const LogVarSet&);
|
||||||
|
|
||||||
|
ConstraintTrees jointCountNormalize (
|
||||||
|
ConstraintTree*, ConstraintTree*, LogVar, LogVar, LogVar);
|
||||||
|
|
||||||
|
LogVars expand (LogVar);
|
||||||
|
|
||||||
|
ConstraintTrees ground (LogVar);
|
||||||
|
|
||||||
|
void cloneLogVar (LogVar, LogVar);
|
||||||
|
|
||||||
|
ConstraintTree& operator= (const ConstraintTree& ct);
|
||||||
|
|
||||||
|
private:
|
||||||
|
unsigned countTuples (const CTNode*) const;
|
||||||
|
|
||||||
|
CTNodes getNodesBelow (CTNode*) const;
|
||||||
|
|
||||||
|
CTNodes getNodesAtLevel (unsigned) const;
|
||||||
|
|
||||||
|
unsigned nrNodes (const CTNode* n) const;
|
||||||
|
|
||||||
|
void appendOnBottom (CTNode* n1, const CTChilds&);
|
||||||
|
|
||||||
|
void swapLogVar (LogVar);
|
||||||
|
|
||||||
|
bool join (CTNode*, const Tuple&, size_t, CTNode*);
|
||||||
|
|
||||||
|
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
|
||||||
|
|
||||||
|
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||||
|
const CTNode*, unsigned);
|
||||||
|
|
||||||
|
static void split (
|
||||||
|
CTNode*, CTNode*, CTChilds&, CTChilds&, unsigned);
|
||||||
|
|
||||||
|
CTNode* root_;
|
||||||
|
LogVars logVars_;
|
||||||
|
LogVarSet logVarSet_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif // HORUS_CONSTRAINTTREE_H
|
||||||
|
|
424
packages/CLPBN/horus2/CountingBp.cpp
Normal file
424
packages/CLPBN/horus2/CountingBp.cpp
Normal file
@ -0,0 +1,424 @@
|
|||||||
|
#include "CountingBp.h"
|
||||||
|
#include "WeightedBp.h"
|
||||||
|
|
||||||
|
|
||||||
|
bool CountingBp::checkForIdenticalFactors = true;
|
||||||
|
|
||||||
|
|
||||||
|
CountingBp::CountingBp (const FactorGraph& fg)
|
||||||
|
: GroundSolver (fg), freeColor_(0)
|
||||||
|
{
|
||||||
|
findIdenticalFactors();
|
||||||
|
setInitialColors();
|
||||||
|
createGroups();
|
||||||
|
compressedFg_ = getCompressedFactorGraph();
|
||||||
|
solver_ = new WeightedBp (*compressedFg_, getWeights());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
CountingBp::~CountingBp (void)
|
||||||
|
{
|
||||||
|
delete solver_;
|
||||||
|
delete compressedFg_;
|
||||||
|
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||||
|
delete varClusters_[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
delete facClusters_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBp::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "counting bp [" ;
|
||||||
|
ss << "schedule=" ;
|
||||||
|
typedef BpOptions::Schedule Sch;
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
|
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
|
case Sch::PARALLEL: ss << "parallel"; break;
|
||||||
|
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
|
}
|
||||||
|
ss << ",max_iter=" << BpOptions::maxIter;
|
||||||
|
ss << ",accuracy=" << BpOptions::accuracy;
|
||||||
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << ",chkif=" <<
|
||||||
|
Util::toString (CountingBp::checkForIdenticalFactors);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
CountingBp::solveQuery (VarIds queryVids)
|
||||||
|
{
|
||||||
|
assert (queryVids.empty() == false);
|
||||||
|
Params res;
|
||||||
|
if (queryVids.size() == 1) {
|
||||||
|
res = solver_->getPosterioriOf (getRepresentative (queryVids[0]));
|
||||||
|
} else {
|
||||||
|
VarNode* vn = fg.getVarNode (queryVids[0]);
|
||||||
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
|
size_t idx = facNodes.size();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
if (facNodes[i]->factor().contains (queryVids)) {
|
||||||
|
idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
if (idx == facNodes.size()) {
|
||||||
|
res = GroundSolver::getJointByConditioning (
|
||||||
|
GroundSolverType::CBP, fg, queryVids);
|
||||||
|
} else {
|
||||||
|
VarIds reprArgs;
|
||||||
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
|
reprArgs.push_back (getRepresentative (queryVids[i]));
|
||||||
|
}
|
||||||
|
FacNode* reprFac = getRepresentative (facNodes[idx]);
|
||||||
|
assert (reprFac != 0);
|
||||||
|
res = solver_->getFactorJoint (reprFac, reprArgs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBp::findIdenticalFactors()
|
||||||
|
{
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
if (checkForIdenticalFactors == false ||
|
||||||
|
facNodes.size() == 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
||||||
|
}
|
||||||
|
unsigned groupCount = 1;
|
||||||
|
for (size_t i = 0; i < facNodes.size() - 1; i++) {
|
||||||
|
Factor& f1 = facNodes[i]->factor();
|
||||||
|
if (f1.distId() != Util::maxUnsigned()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
f1.setDistId (groupCount);
|
||||||
|
for (size_t j = i + 1; j < facNodes.size(); j++) {
|
||||||
|
Factor& f2 = facNodes[j]->factor();
|
||||||
|
if (f2.distId() != Util::maxUnsigned()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (f1.size() == f2.size() &&
|
||||||
|
f1.ranges() == f2.ranges() &&
|
||||||
|
f1.params() == f2.params()) {
|
||||||
|
f2.setDistId (groupCount);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
groupCount ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBp::setInitialColors (void)
|
||||||
|
{
|
||||||
|
varColors_.resize (fg.nrVarNodes());
|
||||||
|
facColors_.resize (fg.nrFacNodes());
|
||||||
|
// create the initial variable colors
|
||||||
|
VarColorMap colorMap;
|
||||||
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
unsigned range = varNodes[i]->range();
|
||||||
|
VarColorMap::iterator it = colorMap.find (range);
|
||||||
|
if (it == colorMap.end()) {
|
||||||
|
it = colorMap.insert (make_pair (
|
||||||
|
range, Colors (range + 1, -1))).first;
|
||||||
|
}
|
||||||
|
unsigned idx = varNodes[i]->hasEvidence()
|
||||||
|
? varNodes[i]->getEvidence()
|
||||||
|
: range;
|
||||||
|
Colors& stateColors = it->second;
|
||||||
|
if (stateColors[idx] == -1) {
|
||||||
|
stateColors[idx] = getNewColor();
|
||||||
|
}
|
||||||
|
setColor (varNodes[i], stateColors[idx]);
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
// create the initial factor colors
|
||||||
|
DistColorMap distColors;
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
|
DistColorMap::iterator it = distColors.find (distId);
|
||||||
|
if (it == distColors.end()) {
|
||||||
|
it = distColors.insert (make_pair (distId, getNewColor())).first;
|
||||||
|
}
|
||||||
|
setColor (facNodes[i], it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBp::createGroups (void)
|
||||||
|
{
|
||||||
|
VarSignMap varGroups;
|
||||||
|
FacSignMap facGroups;
|
||||||
|
unsigned nIters = 0;
|
||||||
|
bool groupsHaveChanged = true;
|
||||||
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
|
||||||
|
while (groupsHaveChanged || nIters == 1) {
|
||||||
|
nIters ++;
|
||||||
|
|
||||||
|
// set a new color to the variables with the same signature
|
||||||
|
size_t prevVarGroupsSize = varGroups.size();
|
||||||
|
varGroups.clear();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
const VarSignature& signature = getSignature (varNodes[i]);
|
||||||
|
VarSignMap::iterator it = varGroups.find (signature);
|
||||||
|
if (it == varGroups.end()) {
|
||||||
|
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (varNodes[i]);
|
||||||
|
}
|
||||||
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); ++it) {
|
||||||
|
Color newColor = getNewColor();
|
||||||
|
VarNodes& groupMembers = it->second;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t prevFactorGroupsSize = facGroups.size();
|
||||||
|
facGroups.clear();
|
||||||
|
// set a new color to the factors with the same signature
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
const FacSignature& signature = getSignature (facNodes[i]);
|
||||||
|
FacSignMap::iterator it = facGroups.find (signature);
|
||||||
|
if (it == facGroups.end()) {
|
||||||
|
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (facNodes[i]);
|
||||||
|
}
|
||||||
|
for (FacSignMap::iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); ++it) {
|
||||||
|
Color newColor = getNewColor();
|
||||||
|
FacNodes& groupMembers = it->second;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||||
|
|| prevFactorGroupsSize != facGroups.size();
|
||||||
|
}
|
||||||
|
// printGroups (varGroups, facGroups);
|
||||||
|
createClusters (varGroups, facGroups);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBp::createClusters (
|
||||||
|
const VarSignMap& varGroups,
|
||||||
|
const FacSignMap& facGroups)
|
||||||
|
{
|
||||||
|
varClusters_.reserve (varGroups.size());
|
||||||
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); ++it) {
|
||||||
|
const VarNodes& groupVars = it->second;
|
||||||
|
VarCluster* vc = new VarCluster (groupVars);
|
||||||
|
for (size_t i = 0; i < groupVars.size(); i++) {
|
||||||
|
varClusterMap_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||||
|
}
|
||||||
|
varClusters_.push_back (vc);
|
||||||
|
}
|
||||||
|
|
||||||
|
facClusters_.reserve (facGroups.size());
|
||||||
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); ++it) {
|
||||||
|
FacNode* groupFactor = it->second[0];
|
||||||
|
const VarNodes& neighs = groupFactor->neighbors();
|
||||||
|
VarClusters varClusters;
|
||||||
|
varClusters.reserve (neighs.size());
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
VarId vid = neighs[i]->varId();
|
||||||
|
varClusters.push_back (varClusterMap_.find (vid)->second);
|
||||||
|
}
|
||||||
|
facClusters_.push_back (new FacCluster (it->second, varClusters));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarSignature
|
||||||
|
CountingBp::getSignature (const VarNode* varNode)
|
||||||
|
{
|
||||||
|
const FacNodes& neighs = varNode->neighbors();
|
||||||
|
VarSignature sign;
|
||||||
|
sign.reserve (neighs.size() + 1);
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
sign.push_back (make_pair (
|
||||||
|
getColor (neighs[i]),
|
||||||
|
neighs[i]->factor().indexOf (varNode->varId())));
|
||||||
|
}
|
||||||
|
std::sort (sign.begin(), sign.end());
|
||||||
|
sign.push_back (make_pair (getColor (varNode), 0));
|
||||||
|
return sign;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FacSignature
|
||||||
|
CountingBp::getSignature (const FacNode* facNode)
|
||||||
|
{
|
||||||
|
const VarNodes& neighs = facNode->neighbors();
|
||||||
|
FacSignature sign;
|
||||||
|
sign.reserve (neighs.size() + 1);
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
sign.push_back (getColor (neighs[i]));
|
||||||
|
}
|
||||||
|
sign.push_back (getColor (facNode));
|
||||||
|
return sign;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarId
|
||||||
|
CountingBp::getRepresentative (VarId vid)
|
||||||
|
{
|
||||||
|
assert (Util::contains (varClusterMap_, vid));
|
||||||
|
VarCluster* vc = varClusterMap_.find (vid)->second;
|
||||||
|
return vc->representative()->varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FacNode*
|
||||||
|
CountingBp::getRepresentative (FacNode* fn)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
if (Util::contains (facClusters_[i]->members(), fn)) {
|
||||||
|
return facClusters_[i]->representative();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
CountingBp::getCompressedFactorGraph (void)
|
||||||
|
{
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||||
|
VarNode* newVar = new VarNode (varClusters_[i]->first());
|
||||||
|
varClusters_[i]->setRepresentative (newVar);
|
||||||
|
fg->addVarNode (newVar);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
Vars vars;
|
||||||
|
const VarClusters& clusters = facClusters_[i]->varClusters();
|
||||||
|
for (size_t j = 0; j < clusters.size(); j++) {
|
||||||
|
vars.push_back (clusters[j]->representative());
|
||||||
|
}
|
||||||
|
const Factor& groundFac = facClusters_[i]->first()->factor();
|
||||||
|
FacNode* fn = new FacNode (Factor (
|
||||||
|
vars, groundFac.params(), groundFac.distId()));
|
||||||
|
facClusters_[i]->setRepresentative (fn);
|
||||||
|
fg->addFacNode (fn);
|
||||||
|
for (size_t j = 0; j < vars.size(); j++) {
|
||||||
|
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<vector<unsigned>>
|
||||||
|
CountingBp::getWeights (void) const
|
||||||
|
{
|
||||||
|
vector<vector<unsigned>> weights;
|
||||||
|
weights.reserve (facClusters_.size());
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
const VarClusters& neighs = facClusters_[i]->varClusters();
|
||||||
|
weights.push_back ({ });
|
||||||
|
weights.back().reserve (neighs.size());
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
weights.back().push_back (getWeight (
|
||||||
|
facClusters_[i], neighs[j], j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
CountingBp::getWeight (
|
||||||
|
const FacCluster* fc,
|
||||||
|
const VarCluster* vc,
|
||||||
|
size_t index) const
|
||||||
|
{
|
||||||
|
unsigned weight = 0;
|
||||||
|
VarId reprVid = vc->representative()->varId();
|
||||||
|
VarNode* groundVar = fg.getVarNode (reprVid);
|
||||||
|
const FacNodes& neighs = groundVar->neighbors();
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
FacNodes::const_iterator it;
|
||||||
|
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
|
||||||
|
if (it != fc->members().end() &&
|
||||||
|
(*it)->factor().indexOf (reprVid) == index) {
|
||||||
|
weight ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBp::printGroups (
|
||||||
|
const VarSignMap& varGroups,
|
||||||
|
const FacSignMap& facGroups) const
|
||||||
|
{
|
||||||
|
unsigned count = 1;
|
||||||
|
cout << "variable groups:" << endl;
|
||||||
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); ++it) {
|
||||||
|
const VarNodes& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << count << ": " ;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->label() << " " ;
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
count = 1;
|
||||||
|
cout << endl << "factor groups:" << endl;
|
||||||
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); ++it) {
|
||||||
|
const FacNodes& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << ++count << ": " ;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
182
packages/CLPBN/horus2/CountingBp.h
Normal file
182
packages/CLPBN/horus2/CountingBp.h
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
#ifndef HORUS_COUNTINGBP_H
|
||||||
|
#define HORUS_COUNTINGBP_H
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "GroundSolver.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Util.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
class VarCluster;
|
||||||
|
class FacCluster;
|
||||||
|
class WeightedBp;
|
||||||
|
|
||||||
|
typedef long Color;
|
||||||
|
typedef vector<Color> Colors;
|
||||||
|
typedef vector<std::pair<Color,unsigned>> VarSignature;
|
||||||
|
typedef vector<Color> FacSignature;
|
||||||
|
|
||||||
|
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||||
|
typedef unordered_map<unsigned, Colors> VarColorMap;
|
||||||
|
|
||||||
|
typedef unordered_map<VarSignature, VarNodes> VarSignMap;
|
||||||
|
typedef unordered_map<FacSignature, FacNodes> FacSignMap;
|
||||||
|
|
||||||
|
typedef unordered_map<VarId, VarCluster*> VarClusterMap;
|
||||||
|
|
||||||
|
typedef vector<VarCluster*> VarClusters;
|
||||||
|
typedef vector<FacCluster*> FacClusters;
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
inline size_t hash_combine (size_t seed, const T& v)
|
||||||
|
{
|
||||||
|
return seed ^ (hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
template <typename T1, typename T2> struct hash<std::pair<T1,T2>>
|
||||||
|
{
|
||||||
|
size_t operator() (const std::pair<T1,T2>& p) const
|
||||||
|
{
|
||||||
|
return hash_combine (std::hash<T1>()(p.first), p.second);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T> struct hash<std::vector<T>>
|
||||||
|
{
|
||||||
|
size_t operator() (const std::vector<T>& vec) const
|
||||||
|
{
|
||||||
|
size_t h = 0;
|
||||||
|
typename vector<T>::const_iterator first = vec.begin();
|
||||||
|
typename vector<T>::const_iterator last = vec.end();
|
||||||
|
for (; first != last; ++first) {
|
||||||
|
h = hash_combine (h, *first);
|
||||||
|
}
|
||||||
|
return h;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class VarCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
VarCluster (const VarNodes& vs) : members_(vs) { }
|
||||||
|
|
||||||
|
const VarNode* first (void) const { return members_.front(); }
|
||||||
|
|
||||||
|
const VarNodes& members (void) const { return members_; }
|
||||||
|
|
||||||
|
VarNode* representative (void) const { return repr_; }
|
||||||
|
|
||||||
|
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
VarNodes members_;
|
||||||
|
VarNode* repr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class FacCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
||||||
|
: members_(fcs), varClusters_(vcs) { }
|
||||||
|
|
||||||
|
const FacNode* first (void) const { return members_.front(); }
|
||||||
|
|
||||||
|
const FacNodes& members (void) const { return members_; }
|
||||||
|
|
||||||
|
FacNode* representative (void) const { return repr_; }
|
||||||
|
|
||||||
|
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
||||||
|
|
||||||
|
VarClusters& varClusters (void) { return varClusters_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
FacNodes members_;
|
||||||
|
FacNode* repr_;
|
||||||
|
VarClusters varClusters_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class CountingBp : public GroundSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CountingBp (const FactorGraph& fg);
|
||||||
|
|
||||||
|
~CountingBp (void);
|
||||||
|
|
||||||
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
|
static bool checkForIdenticalFactors;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Color getNewColor (void)
|
||||||
|
{
|
||||||
|
++ freeColor_;
|
||||||
|
return freeColor_ - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Color getColor (const VarNode* vn) const
|
||||||
|
{
|
||||||
|
return varColors_[vn->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
|
Color getColor (const FacNode* fn) const
|
||||||
|
{
|
||||||
|
return facColors_[fn->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
|
void setColor (const VarNode* vn, Color c)
|
||||||
|
{
|
||||||
|
varColors_[vn->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setColor (const FacNode* fn, Color c)
|
||||||
|
{
|
||||||
|
facColors_[fn->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
void findIdenticalFactors (void);
|
||||||
|
|
||||||
|
void setInitialColors (void);
|
||||||
|
|
||||||
|
void createGroups (void);
|
||||||
|
|
||||||
|
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||||
|
|
||||||
|
VarSignature getSignature (const VarNode*);
|
||||||
|
|
||||||
|
FacSignature getSignature (const FacNode*);
|
||||||
|
|
||||||
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
|
VarId getRepresentative (VarId vid);
|
||||||
|
|
||||||
|
FacNode* getRepresentative (FacNode*);
|
||||||
|
|
||||||
|
FactorGraph* getCompressedFactorGraph (void);
|
||||||
|
|
||||||
|
vector<vector<unsigned>> getWeights (void) const;
|
||||||
|
|
||||||
|
unsigned getWeight (const FacCluster*,
|
||||||
|
const VarCluster*, size_t index) const;
|
||||||
|
|
||||||
|
|
||||||
|
Color freeColor_;
|
||||||
|
Colors varColors_;
|
||||||
|
Colors facColors_;
|
||||||
|
VarClusters varClusters_;
|
||||||
|
FacClusters facClusters_;
|
||||||
|
VarClusterMap varClusterMap_;
|
||||||
|
const FactorGraph* compressedFg_;
|
||||||
|
WeightedBp* solver_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_COUNTINGBP_H
|
||||||
|
|
243
packages/CLPBN/horus2/ElimGraph.cpp
Normal file
243
packages/CLPBN/horus2/ElimGraph.cpp
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "ElimGraph.h"
|
||||||
|
|
||||||
|
ElimHeuristic ElimGraph::elimHeuristic = MIN_NEIGHBORS;
|
||||||
|
|
||||||
|
|
||||||
|
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < factors.size(); i++) {
|
||||||
|
if (factors[i] == 0) { // if contained just one var with evidence
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const VarIds& vids = factors[i]->arguments();
|
||||||
|
for (size_t j = 0; j < vids.size() - 1; j++) {
|
||||||
|
EgNode* n1 = getEgNode (vids[j]);
|
||||||
|
if (n1 == 0) {
|
||||||
|
n1 = new EgNode (vids[j], factors[i]->range (j));
|
||||||
|
addNode (n1);
|
||||||
|
}
|
||||||
|
for (size_t k = j + 1; k < vids.size(); k++) {
|
||||||
|
EgNode* n2 = getEgNode (vids[k]);
|
||||||
|
if (n2 == 0) {
|
||||||
|
n2 = new EgNode (vids[k], factors[i]->range (k));
|
||||||
|
addNode (n2);
|
||||||
|
}
|
||||||
|
if (neighbors (n1, n2) == false) {
|
||||||
|
addEdge (n1, n2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (vids.size() == 1) {
|
||||||
|
if (getEgNode (vids[0]) == 0) {
|
||||||
|
addNode (new EgNode (vids[0], factors[i]->range (0)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ElimGraph::~ElimGraph (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
|
delete nodes_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarIds
|
||||||
|
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||||
|
{
|
||||||
|
VarIds elimOrder;
|
||||||
|
unmarked_.reserve (nodes_.size());
|
||||||
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
|
if (Util::contains (exclude, nodes_[i]->varId()) == false) {
|
||||||
|
unmarked_.insert (nodes_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
size_t nrVarsToEliminate = nodes_.size() - exclude.size();
|
||||||
|
for (size_t i = 0; i < nrVarsToEliminate; i++) {
|
||||||
|
EgNode* node = getLowestCostNode();
|
||||||
|
unmarked_.remove (node);
|
||||||
|
const EGNeighs& neighs = node->neighbors();
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
neighs[j]->removeNeighbor (node);
|
||||||
|
}
|
||||||
|
elimOrder.push_back (node->varId());
|
||||||
|
connectAllNeighbors (node);
|
||||||
|
}
|
||||||
|
return elimOrder;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::print (void) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
|
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||||
|
EGNeighs neighs = nodes_[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
cout << " " << neighs[j]->label();
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::exportToGraphViz (
|
||||||
|
const char* fileName,
|
||||||
|
bool showNeighborless,
|
||||||
|
const VarIds& highlightVarIds) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
out << "strict graph {" << endl;
|
||||||
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
|
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||||
|
out << '"' << nodes_[i]->label() << '"' << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < highlightVarIds.size(); i++) {
|
||||||
|
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||||
|
if (node) {
|
||||||
|
out << '"' << node->label() << '"' ;
|
||||||
|
out << " [shape=box3d]" << endl;
|
||||||
|
} else {
|
||||||
|
cerr << "Error: invalid variable id: " << highlightVarIds[i] << "." ;
|
||||||
|
cerr << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
|
EGNeighs neighs = nodes_[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||||
|
out << '"' << neighs[j]->label() << '"' << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out << "}" << endl;
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarIds
|
||||||
|
ElimGraph::getEliminationOrder (
|
||||||
|
const Factors& factors,
|
||||||
|
VarIds excludedVids)
|
||||||
|
{
|
||||||
|
if (elimHeuristic == ElimHeuristic::SEQUENTIAL) {
|
||||||
|
VarIds allVids;
|
||||||
|
Factors::const_iterator first = factors.begin();
|
||||||
|
Factors::const_iterator end = factors.end();
|
||||||
|
for (; first != end; ++first) {
|
||||||
|
Util::addToVector (allVids, (*first)->arguments());
|
||||||
|
}
|
||||||
|
TinySet<VarId> elimOrder (allVids);
|
||||||
|
elimOrder -= TinySet<VarId> (excludedVids);
|
||||||
|
return elimOrder.elements();
|
||||||
|
}
|
||||||
|
ElimGraph graph (factors);
|
||||||
|
return graph.getEliminatingOrder (excludedVids);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::addNode (EgNode* n)
|
||||||
|
{
|
||||||
|
nodes_.push_back (n);
|
||||||
|
n->setIndex (nodes_.size() - 1);
|
||||||
|
varMap_.insert (make_pair (n->varId(), n));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
EgNode*
|
||||||
|
ElimGraph::getEgNode (VarId vid) const
|
||||||
|
{
|
||||||
|
unordered_map<VarId, EgNode*>::const_iterator it;
|
||||||
|
it = varMap_.find (vid);
|
||||||
|
return (it != varMap_.end()) ? it->second : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
EgNode*
|
||||||
|
ElimGraph::getLowestCostNode (void) const
|
||||||
|
{
|
||||||
|
EgNode* bestNode = 0;
|
||||||
|
unsigned minCost = std::numeric_limits<unsigned>::max();
|
||||||
|
EGNeighs::const_iterator it;
|
||||||
|
switch (elimHeuristic) {
|
||||||
|
case MIN_NEIGHBORS: {
|
||||||
|
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||||
|
unsigned cost = getNeighborsCost (*it);
|
||||||
|
if (cost < minCost) {
|
||||||
|
bestNode = *it;
|
||||||
|
minCost = cost;
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
break;
|
||||||
|
case MIN_WEIGHT: {
|
||||||
|
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||||
|
unsigned cost = getWeightCost (*it);
|
||||||
|
if (cost < minCost) {
|
||||||
|
bestNode = *it;
|
||||||
|
minCost = cost;
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
break;
|
||||||
|
case MIN_FILL: {
|
||||||
|
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||||
|
unsigned cost = getFillCost (*it);
|
||||||
|
if (cost < minCost) {
|
||||||
|
bestNode = *it;
|
||||||
|
minCost = cost;
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
break;
|
||||||
|
case WEIGHTED_MIN_FILL: {
|
||||||
|
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||||
|
unsigned cost = getWeightedFillCost (*it);
|
||||||
|
if (cost < minCost) {
|
||||||
|
bestNode = *it;
|
||||||
|
minCost = cost;
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
|
assert (bestNode);
|
||||||
|
return bestNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::connectAllNeighbors (const EgNode* n)
|
||||||
|
{
|
||||||
|
const EGNeighs& neighs = n->neighbors();
|
||||||
|
if (neighs.size() > 0) {
|
||||||
|
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||||
|
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||||
|
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||||
|
addEdge (neighs[i], neighs[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
139
packages/CLPBN/horus2/ElimGraph.h
Normal file
139
packages/CLPBN/horus2/ElimGraph.h
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
#ifndef HORUS_ELIMGRAPH_H
|
||||||
|
#define HORUS_ELIMGRAPH_H
|
||||||
|
|
||||||
|
#include "unordered_map"
|
||||||
|
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "TinySet.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
enum ElimHeuristic
|
||||||
|
{
|
||||||
|
SEQUENTIAL,
|
||||||
|
MIN_NEIGHBORS,
|
||||||
|
MIN_WEIGHT,
|
||||||
|
MIN_FILL,
|
||||||
|
WEIGHTED_MIN_FILL
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class EgNode;
|
||||||
|
|
||||||
|
typedef TinySet<EgNode*> EGNeighs;
|
||||||
|
|
||||||
|
|
||||||
|
class EgNode : public Var
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||||
|
|
||||||
|
void addNeighbor (EgNode* n) { neighs_.insert (n); }
|
||||||
|
|
||||||
|
void removeNeighbor (EgNode* n) { neighs_.remove (n); }
|
||||||
|
|
||||||
|
bool isNeighbor (EgNode* n) const { return neighs_.contains (n); }
|
||||||
|
|
||||||
|
const EGNeighs& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
EGNeighs neighs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class ElimGraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ElimGraph (const Factors&);
|
||||||
|
|
||||||
|
~ElimGraph (void);
|
||||||
|
|
||||||
|
VarIds getEliminatingOrder (const VarIds&);
|
||||||
|
|
||||||
|
void print (void) const;
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*, bool = true,
|
||||||
|
const VarIds& = VarIds()) const;
|
||||||
|
|
||||||
|
static VarIds getEliminationOrder (const Factors&, VarIds);
|
||||||
|
|
||||||
|
static ElimHeuristic elimHeuristic;
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
void addEdge (EgNode* n1, EgNode* n2)
|
||||||
|
{
|
||||||
|
assert (n1 != n2);
|
||||||
|
n1->addNeighbor (n2);
|
||||||
|
n2->addNeighbor (n1);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned getNeighborsCost (const EgNode* n) const
|
||||||
|
{
|
||||||
|
return n->neighbors().size();
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned getWeightCost (const EgNode* n) const
|
||||||
|
{
|
||||||
|
unsigned cost = 1;
|
||||||
|
const EGNeighs& neighs = n->neighbors();
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
cost *= neighs[i]->range();
|
||||||
|
}
|
||||||
|
return cost;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned getFillCost (const EgNode* n) const
|
||||||
|
{
|
||||||
|
unsigned cost = 0;
|
||||||
|
const EGNeighs& neighs = n->neighbors();
|
||||||
|
if (neighs.size() > 0) {
|
||||||
|
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||||
|
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||||
|
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||||
|
cost ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cost;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned getWeightedFillCost (const EgNode* n) const
|
||||||
|
{
|
||||||
|
unsigned cost = 0;
|
||||||
|
const EGNeighs& neighs = n->neighbors();
|
||||||
|
if (neighs.size() > 0) {
|
||||||
|
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||||
|
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||||
|
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||||
|
cost += neighs[i]->range() * neighs[j]->range();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cost;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool neighbors (EgNode* n1, EgNode* n2) const
|
||||||
|
{
|
||||||
|
return n1->isNeighbor (n2);
|
||||||
|
}
|
||||||
|
|
||||||
|
void addNode (EgNode*);
|
||||||
|
|
||||||
|
EgNode* getEgNode (VarId) const;
|
||||||
|
|
||||||
|
EgNode* getLowestCostNode (void) const;
|
||||||
|
|
||||||
|
void connectAllNeighbors (const EgNode*);
|
||||||
|
|
||||||
|
vector<EgNode*> nodes_;
|
||||||
|
TinySet<EgNode*> unmarked_;
|
||||||
|
unordered_map<VarId, EgNode*> varMap_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_ELIMGRAPH_H
|
||||||
|
|
237
packages/CLPBN/horus2/Factor.cpp
Normal file
237
packages/CLPBN/horus2/Factor.cpp
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "Indexer.h"
|
||||||
|
|
||||||
|
|
||||||
|
Factor::Factor (const Factor& g)
|
||||||
|
{
|
||||||
|
clone (g);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Factor::Factor (
|
||||||
|
const VarIds& vids,
|
||||||
|
const Ranges& ranges,
|
||||||
|
const Params& params,
|
||||||
|
unsigned distId)
|
||||||
|
{
|
||||||
|
args_ = vids;
|
||||||
|
ranges_ = ranges;
|
||||||
|
params_ = params;
|
||||||
|
distId_ = distId;
|
||||||
|
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Factor::Factor (
|
||||||
|
const Vars& vars,
|
||||||
|
const Params& params,
|
||||||
|
unsigned distId)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < vars.size(); i++) {
|
||||||
|
args_.push_back (vars[i]->varId());
|
||||||
|
ranges_.push_back (vars[i]->range());
|
||||||
|
}
|
||||||
|
params_ = params;
|
||||||
|
distId_ = distId;
|
||||||
|
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOut (VarId vid)
|
||||||
|
{
|
||||||
|
if (vid == args_.front() && ranges_.front() == 2) {
|
||||||
|
// optimization
|
||||||
|
sumOutFirstVariable();
|
||||||
|
} else if (vid == args_.back() && ranges_.back() == 2) {
|
||||||
|
// optimization
|
||||||
|
sumOutLastVariable();
|
||||||
|
} else {
|
||||||
|
assert (indexOf (vid) != args_.size());
|
||||||
|
sumOutIndex (indexOf (vid));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutAllExcept (VarId vid)
|
||||||
|
{
|
||||||
|
assert (indexOf (vid) != args_.size());
|
||||||
|
sumOutAllExceptIndex (indexOf (vid));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutAllExcept (const VarIds& vids)
|
||||||
|
{
|
||||||
|
vector<bool> mask (args_.size(), false);
|
||||||
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
|
assert (indexOf (vids[i]) != args_.size());
|
||||||
|
mask[indexOf (vids[i])] = true;
|
||||||
|
}
|
||||||
|
sumOutArgs (mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutAllExceptIndex (size_t idx)
|
||||||
|
{
|
||||||
|
assert (idx < args_.size());
|
||||||
|
vector<bool> mask (args_.size(), false);
|
||||||
|
mask[idx] = true;
|
||||||
|
sumOutArgs (mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::multiply (Factor& g)
|
||||||
|
{
|
||||||
|
if (args_.size() == 0) {
|
||||||
|
clone (g);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
TFactor<VarId>::multiply (g);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
Factor::getLabel (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "f(" ;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (i != 0) ss << "," ;
|
||||||
|
ss << Var (args_[i], ranges_[i]).label();
|
||||||
|
}
|
||||||
|
ss << ")" ;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::print (void) const
|
||||||
|
{
|
||||||
|
Vars vars;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||||
|
}
|
||||||
|
vector<string> jointStrings = Util::getStateLines (vars);
|
||||||
|
for (size_t i = 0; i < params_.size(); i++) {
|
||||||
|
// cout << "[" << distId_ << "] " ;
|
||||||
|
cout << "f(" << jointStrings[i] << ")" ;
|
||||||
|
cout << " = " << params_[i] << endl;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
for (size_t i = 0; i < vars.size(); i++) {
|
||||||
|
delete vars[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutFirstVariable (void)
|
||||||
|
{
|
||||||
|
size_t sep = params_.size() / 2;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
std::transform (
|
||||||
|
params_.begin(), params_.begin() + sep,
|
||||||
|
params_.begin() + sep, params_.begin(),
|
||||||
|
Util::logSum);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
std::transform (
|
||||||
|
params_.begin(), params_.begin() + sep,
|
||||||
|
params_.begin() + sep, params_.begin(),
|
||||||
|
std::plus<double>());
|
||||||
|
}
|
||||||
|
params_.resize (sep);
|
||||||
|
args_.erase (args_.begin());
|
||||||
|
ranges_.erase (ranges_.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutLastVariable (void)
|
||||||
|
{
|
||||||
|
Params::iterator first1 = params_.begin();
|
||||||
|
Params::iterator first2 = params_.begin();
|
||||||
|
Params::iterator last = params_.end();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
while (first2 != last) {
|
||||||
|
// the arguments can be swaped, but that is ok
|
||||||
|
*first1++ = Util::logSum (*first2++, *first2++);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
while (first2 != last) {
|
||||||
|
*first1++ = (*first2++) + (*first2++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params_.resize (params_.size() / 2);
|
||||||
|
args_.pop_back();
|
||||||
|
ranges_.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::sumOutArgs (const vector<bool>& mask)
|
||||||
|
{
|
||||||
|
assert (mask.size() == args_.size());
|
||||||
|
size_t new_size = 1;
|
||||||
|
Ranges oldRanges = ranges_;
|
||||||
|
args_.clear();
|
||||||
|
ranges_.clear();
|
||||||
|
for (unsigned i = 0; i < mask.size(); i++) {
|
||||||
|
if (mask[i]) {
|
||||||
|
new_size *= ranges_[i];
|
||||||
|
args_.push_back (args_[i]);
|
||||||
|
ranges_.push_back (ranges_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Params newps (new_size, LogAware::addIdenty());
|
||||||
|
Params::const_iterator first = params_.begin();
|
||||||
|
Params::const_iterator last = params_.end();
|
||||||
|
MapIndexer indexer (oldRanges, mask);
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
while (first != last) {
|
||||||
|
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
while (first != last) {
|
||||||
|
newps[indexer] += *first++;
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params_ = newps;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::clone (const Factor& g)
|
||||||
|
{
|
||||||
|
args_ = g.arguments();
|
||||||
|
ranges_ = g.ranges();
|
||||||
|
params_ = g.params();
|
||||||
|
distId_ = g.distId();
|
||||||
|
}
|
||||||
|
|
294
packages/CLPBN/horus2/Factor.h
Normal file
294
packages/CLPBN/horus2/Factor.h
Normal file
@ -0,0 +1,294 @@
|
|||||||
|
#ifndef HORUS_FACTOR_H
|
||||||
|
#define HORUS_FACTOR_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "Var.h"
|
||||||
|
#include "Indexer.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class TFactor
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const vector<T>& arguments (void) const { return args_; }
|
||||||
|
|
||||||
|
vector<T>& arguments (void) { return args_; }
|
||||||
|
|
||||||
|
const Ranges& ranges (void) const { return ranges_; }
|
||||||
|
|
||||||
|
const Params& params (void) const { return params_; }
|
||||||
|
|
||||||
|
Params& params (void) { return params_; }
|
||||||
|
|
||||||
|
size_t nrArguments (void) const { return args_.size(); }
|
||||||
|
|
||||||
|
size_t size (void) const { return params_.size(); }
|
||||||
|
|
||||||
|
unsigned distId (void) const { return distId_; }
|
||||||
|
|
||||||
|
void setDistId (unsigned id) { distId_ = id; }
|
||||||
|
|
||||||
|
void normalize (void) { LogAware::normalize (params_); }
|
||||||
|
|
||||||
|
void randomize (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < params_.size(); ++i) {
|
||||||
|
params_[i] = (double) std::rand() / RAND_MAX;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void setParams (const Params& newParams)
|
||||||
|
{
|
||||||
|
params_ = newParams;
|
||||||
|
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t indexOf (const T& t) const
|
||||||
|
{
|
||||||
|
return Util::indexOf (args_, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
const T& argument (size_t idx) const
|
||||||
|
{
|
||||||
|
assert (idx < args_.size());
|
||||||
|
return args_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
T& argument (size_t idx)
|
||||||
|
{
|
||||||
|
assert (idx < args_.size());
|
||||||
|
return args_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned range (size_t idx) const
|
||||||
|
{
|
||||||
|
assert (idx < ranges_.size());
|
||||||
|
return ranges_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
void multiply (TFactor<T>& g)
|
||||||
|
{
|
||||||
|
if (args_ == g.arguments()) {
|
||||||
|
// optimization
|
||||||
|
Globals::logDomain
|
||||||
|
? params_ += g.params()
|
||||||
|
: params_ *= g.params();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
unsigned range_prod = 1;
|
||||||
|
bool share_arguments = false;
|
||||||
|
const vector<T>& g_args = g.arguments();
|
||||||
|
const Ranges& g_ranges = g.ranges();
|
||||||
|
const Params& g_params = g.params();
|
||||||
|
for (size_t i = 0; i < g_args.size(); i++) {
|
||||||
|
size_t idx = indexOf (g_args[i]);
|
||||||
|
if (idx == args_.size()) {
|
||||||
|
range_prod *= g_ranges[i];
|
||||||
|
args_.push_back (g_args[i]);
|
||||||
|
ranges_.push_back (g_ranges[i]);
|
||||||
|
} else {
|
||||||
|
share_arguments = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (share_arguments == false) {
|
||||||
|
// optimization
|
||||||
|
cartesianProduct (g_params.begin(), g_params.end());
|
||||||
|
} else {
|
||||||
|
extend (range_prod);
|
||||||
|
Params::iterator it = params_.begin();
|
||||||
|
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (; indexer.valid(); ++it, ++indexer) {
|
||||||
|
*it += g_params[indexer];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; indexer.valid(); ++it, ++indexer) {
|
||||||
|
*it *= g_params[indexer];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void sumOutIndex (size_t idx)
|
||||||
|
{
|
||||||
|
assert (idx < args_.size());
|
||||||
|
assert (args_.size() > 1);
|
||||||
|
size_t new_size = params_.size() / ranges_[idx];
|
||||||
|
Params newps (new_size, LogAware::addIdenty());
|
||||||
|
Params::const_iterator first = params_.begin();
|
||||||
|
Params::const_iterator last = params_.end();
|
||||||
|
MapIndexer indexer (ranges_, idx);
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (; first != last; ++indexer) {
|
||||||
|
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; first != last; ++indexer) {
|
||||||
|
newps[indexer] += *first++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params_ = newps;
|
||||||
|
args_.erase (args_.begin() + idx);
|
||||||
|
ranges_.erase (ranges_.begin() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void absorveEvidence (const T& arg, unsigned obsIdx)
|
||||||
|
{
|
||||||
|
size_t idx = indexOf (arg);
|
||||||
|
assert (idx != args_.size());
|
||||||
|
assert (obsIdx < ranges_[idx]);
|
||||||
|
Params newps;
|
||||||
|
newps.reserve (params_.size() / ranges_[idx]);
|
||||||
|
Indexer indexer (ranges_);
|
||||||
|
for (unsigned i = 0; i < obsIdx; ++i) {
|
||||||
|
indexer.incrementDimension (idx);
|
||||||
|
}
|
||||||
|
while (indexer.valid()) {
|
||||||
|
newps.push_back (params_[indexer]);
|
||||||
|
indexer.incrementExceptDimension (idx);
|
||||||
|
}
|
||||||
|
params_ = newps;
|
||||||
|
args_.erase (args_.begin() + idx);
|
||||||
|
ranges_.erase (ranges_.begin() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void reorderArguments (const vector<T> new_args)
|
||||||
|
{
|
||||||
|
assert (new_args.size() == args_.size());
|
||||||
|
if (new_args == args_) {
|
||||||
|
return; // already on the desired order
|
||||||
|
}
|
||||||
|
Ranges new_ranges;
|
||||||
|
for (size_t i = 0; i < new_args.size(); i++) {
|
||||||
|
size_t idx = indexOf (new_args[i]);
|
||||||
|
assert (idx != args_.size());
|
||||||
|
new_ranges.push_back (ranges_[idx]);
|
||||||
|
}
|
||||||
|
Params newps;
|
||||||
|
newps.reserve (params_.size());
|
||||||
|
MapIndexer indexer (new_args, new_ranges, args_, ranges_);
|
||||||
|
for (; indexer.valid(); ++indexer) {
|
||||||
|
newps.push_back (params_[indexer]);
|
||||||
|
}
|
||||||
|
params_ = newps;
|
||||||
|
args_ = new_args;
|
||||||
|
ranges_ = new_ranges;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contains (const T& arg) const
|
||||||
|
{
|
||||||
|
return Util::contains (args_, arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contains (const vector<T>& args) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
|
if (contains (args[i]) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
double& operator[] (size_t idx)
|
||||||
|
{
|
||||||
|
assert (idx < params_.size());
|
||||||
|
return params_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
protected:
|
||||||
|
vector<T> args_;
|
||||||
|
Ranges ranges_;
|
||||||
|
Params params_;
|
||||||
|
unsigned distId_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void extend (unsigned range_prod)
|
||||||
|
{
|
||||||
|
Params backup = params_;
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (backup.size() * range_prod);
|
||||||
|
Params::const_iterator first = backup.begin();
|
||||||
|
Params::const_iterator last = backup.end();
|
||||||
|
for (; first != last; ++first) {
|
||||||
|
for (unsigned reps = 0; reps < range_prod; ++reps) {
|
||||||
|
params_.push_back (*first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cartesianProduct (
|
||||||
|
Params::const_iterator first2,
|
||||||
|
Params::const_iterator last2)
|
||||||
|
{
|
||||||
|
Params backup = params_;
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (params_.size() * (last2 - first2));
|
||||||
|
Params::const_iterator first1 = backup.begin();
|
||||||
|
Params::const_iterator last1 = backup.end();
|
||||||
|
Params::const_iterator tmp;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (; first1 != last1; ++first1) {
|
||||||
|
for (tmp = first2; tmp != last2; ++tmp) {
|
||||||
|
params_.push_back ((*first1) + (*tmp));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (; first1 != last1; ++first1) {
|
||||||
|
for (tmp = first2; tmp != last2; ++tmp) {
|
||||||
|
params_.push_back ((*first1) * (*tmp));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Factor : public TFactor<VarId>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Factor (void) { }
|
||||||
|
|
||||||
|
Factor (const Factor&);
|
||||||
|
|
||||||
|
Factor (const VarIds&, const Ranges&, const Params&,
|
||||||
|
unsigned = Util::maxUnsigned());
|
||||||
|
|
||||||
|
Factor (const Vars&, const Params&,
|
||||||
|
unsigned = Util::maxUnsigned());
|
||||||
|
|
||||||
|
void sumOut (VarId);
|
||||||
|
|
||||||
|
void sumOutAllExcept (VarId);
|
||||||
|
|
||||||
|
void sumOutAllExcept (const VarIds&);
|
||||||
|
|
||||||
|
void sumOutAllExceptIndex (size_t idx);
|
||||||
|
|
||||||
|
void multiply (Factor&);
|
||||||
|
|
||||||
|
string getLabel (void) const;
|
||||||
|
|
||||||
|
void print (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void sumOutFirstVariable (void);
|
||||||
|
|
||||||
|
void sumOutLastVariable (void);
|
||||||
|
|
||||||
|
void sumOutArgs (const vector<bool>& mask);
|
||||||
|
|
||||||
|
void clone (const Factor& f);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_FACTOR_H
|
||||||
|
|
454
packages/CLPBN/horus2/FactorGraph.cpp
Normal file
454
packages/CLPBN/horus2/FactorGraph.cpp
Normal file
@ -0,0 +1,454 @@
|
|||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "BayesBall.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||||
|
{
|
||||||
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
addVarNode (new VarNode (varNodes[i]));
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||||
|
addFacNode (facNode);
|
||||||
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bayesFactors_ = fg.bayesianFactors();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||||
|
{
|
||||||
|
std::ifstream is (fileName);
|
||||||
|
if (!is.is_open()) {
|
||||||
|
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
ignoreLines (is);
|
||||||
|
string line;
|
||||||
|
getline (is, line);
|
||||||
|
if (line != "MARKOV") {
|
||||||
|
cerr << "Error: the network must be a MARKOV network." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
// read the number of vars
|
||||||
|
ignoreLines (is);
|
||||||
|
unsigned nrVars;
|
||||||
|
is >> nrVars;
|
||||||
|
// read the range of each var
|
||||||
|
ignoreLines (is);
|
||||||
|
Ranges ranges (nrVars);
|
||||||
|
for (unsigned i = 0; i < nrVars; i++) {
|
||||||
|
is >> ranges[i];
|
||||||
|
}
|
||||||
|
unsigned nrFactors;
|
||||||
|
unsigned nrArgs;
|
||||||
|
unsigned vid;
|
||||||
|
is >> nrFactors;
|
||||||
|
vector<VarIds> factorVarIds;
|
||||||
|
vector<Ranges> factorRanges;
|
||||||
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
is >> nrArgs;
|
||||||
|
factorVarIds.push_back ({ });
|
||||||
|
factorRanges.push_back ({ });
|
||||||
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
|
is >> vid;
|
||||||
|
if (vid >= ranges.size()) {
|
||||||
|
cerr << "Error: invalid variable identifier `" << vid << "'. " ;
|
||||||
|
cerr << "Identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||||
|
cerr << "." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
factorVarIds.back().push_back (vid);
|
||||||
|
factorRanges.back().push_back (ranges[vid]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// read the parameters
|
||||||
|
unsigned nrParams;
|
||||||
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
is >> nrParams;
|
||||||
|
if (nrParams != Util::sizeExpected (factorRanges[i])) {
|
||||||
|
cerr << "Error: invalid number of parameters for factor nº " << i ;
|
||||||
|
cerr << ", " << Util::sizeExpected (factorRanges[i]);
|
||||||
|
cerr << " expected, " << nrParams << " given." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
Params params (nrParams);
|
||||||
|
for (unsigned j = 0; j < nrParams; j++) {
|
||||||
|
is >> params[j];
|
||||||
|
}
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::log (params);
|
||||||
|
}
|
||||||
|
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||||
|
}
|
||||||
|
is.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||||
|
{
|
||||||
|
std::ifstream is (fileName);
|
||||||
|
if (!is.is_open()) {
|
||||||
|
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
ignoreLines (is);
|
||||||
|
unsigned nrFactors;
|
||||||
|
unsigned nrArgs;
|
||||||
|
VarId vid;
|
||||||
|
is >> nrFactors;
|
||||||
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
// read the factor arguments
|
||||||
|
is >> nrArgs;
|
||||||
|
VarIds vids;
|
||||||
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
is >> vid;
|
||||||
|
vids.push_back (vid);
|
||||||
|
}
|
||||||
|
// read ranges
|
||||||
|
Ranges ranges (nrArgs);
|
||||||
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
is >> ranges[j];
|
||||||
|
VarNode* var = getVarNode (vids[j]);
|
||||||
|
if (var != 0 && ranges[j] != var->range()) {
|
||||||
|
cerr << "Error: variable `" << vids[j] << "' appears in two or " ;
|
||||||
|
cerr << "more factors with a different range." << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// read parameters
|
||||||
|
ignoreLines (is);
|
||||||
|
unsigned nNonzeros;
|
||||||
|
is >> nNonzeros;
|
||||||
|
Params params (Util::sizeExpected (ranges), 0);
|
||||||
|
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||||
|
ignoreLines (is);
|
||||||
|
unsigned index;
|
||||||
|
is >> index;
|
||||||
|
ignoreLines (is);
|
||||||
|
double val;
|
||||||
|
is >> val;
|
||||||
|
params[index] = val;
|
||||||
|
}
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::log (params);
|
||||||
|
}
|
||||||
|
std::reverse (vids.begin(), vids.end());
|
||||||
|
Factor f (vids, ranges, params);
|
||||||
|
std::reverse (vids.begin(), vids.end());
|
||||||
|
f.reorderArguments (vids);
|
||||||
|
addFactor (f);
|
||||||
|
}
|
||||||
|
is.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph::~FactorGraph (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||||
|
delete varNodes_[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
delete facNodes_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addFactor (const Factor& factor)
|
||||||
|
{
|
||||||
|
FacNode* fn = new FacNode (factor);
|
||||||
|
addFacNode (fn);
|
||||||
|
const VarIds& vids = fn->factor().arguments();
|
||||||
|
for (size_t i = 0; i < vids.size(); i++) {
|
||||||
|
VarMap::const_iterator it = varMap_.find (vids[i]);
|
||||||
|
if (it != varMap_.end()) {
|
||||||
|
addEdge (it->second, fn);
|
||||||
|
} else {
|
||||||
|
VarNode* vn = new VarNode (vids[i], fn->factor().range (i));
|
||||||
|
addVarNode (vn);
|
||||||
|
addEdge (vn, fn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addVarNode (VarNode* vn)
|
||||||
|
{
|
||||||
|
varNodes_.push_back (vn);
|
||||||
|
vn->setIndex (varNodes_.size() - 1);
|
||||||
|
varMap_.insert (make_pair (vn->varId(), vn));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addFacNode (FacNode* fn)
|
||||||
|
{
|
||||||
|
facNodes_.push_back (fn);
|
||||||
|
fn->setIndex (facNodes_.size() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addEdge (VarNode* vn, FacNode* fn)
|
||||||
|
{
|
||||||
|
vn->addNeighbor (fn);
|
||||||
|
fn->addNeighbor (vn);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
FactorGraph::isTree (void) const
|
||||||
|
{
|
||||||
|
return !containsCycle();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
BayesBallGraph&
|
||||||
|
FactorGraph::getStructure (void)
|
||||||
|
{
|
||||||
|
assert (bayesFactors_);
|
||||||
|
if (structure_.empty()) {
|
||||||
|
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||||
|
structure_.addNode (new BBNode (varNodes_[i]));
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||||
|
for (size_t j = 1; j < vids.size(); j++) {
|
||||||
|
structure_.addEdge (vids[j], vids[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return structure_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::print (void) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||||
|
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||||
|
cout << "label = " << varNodes_[i]->label() << endl;
|
||||||
|
cout << "range = " << varNodes_[i]->range() << endl;
|
||||||
|
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||||
|
cout << "factors = " ;
|
||||||
|
for (size_t j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
||||||
|
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
cout << endl << endl;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
facNodes_[i]->factor().print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::exportToGraphViz (const char* fileName) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
out << "graph \"" << fileName << "\" {" << endl;
|
||||||
|
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||||
|
if (varNodes_[i]->hasEvidence()) {
|
||||||
|
out << '"' << varNodes_[i]->label() << '"' ;
|
||||||
|
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||||
|
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||||
|
out << "\"" << ", shape=box]" << endl;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
const VarNodes& myVars = facNodes_[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < myVars.size(); j++) {
|
||||||
|
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||||
|
out << " -- " ;
|
||||||
|
out << '"' << myVars[j]->label() << '"' << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out << "}" << endl;
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
out << "MARKOV" << endl;
|
||||||
|
out << varNodes_.size() << endl;
|
||||||
|
VarNodes sortedVns = varNodes_;
|
||||||
|
std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId());
|
||||||
|
for (size_t i = 0; i < sortedVns.size(); i++) {
|
||||||
|
out << ((i != 0) ? " " : "") << sortedVns[i]->range();
|
||||||
|
}
|
||||||
|
out << endl << facNodes_.size() << endl;
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
VarIds args = facNodes_[i]->factor().arguments();
|
||||||
|
out << args.size() << " " << Util::elementsToString (args) << endl;
|
||||||
|
}
|
||||||
|
out << endl;
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
Params params = facNodes_[i]->factor().params();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::exp (params);
|
||||||
|
}
|
||||||
|
out << params.size() << endl << " " ;
|
||||||
|
out << Util::elementsToString (params) << endl << endl;
|
||||||
|
}
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
out << facNodes_.size() << endl << endl;
|
||||||
|
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||||
|
Factor f (facNodes_[i]->factor());
|
||||||
|
out << f.nrArguments() << endl;
|
||||||
|
out << Util::elementsToString (f.arguments()) << endl;
|
||||||
|
out << Util::elementsToString (f.ranges()) << endl;
|
||||||
|
VarIds args = f.arguments();
|
||||||
|
std::reverse (args.begin(), args.end());
|
||||||
|
f.reorderArguments (args);
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::exp (f.params());
|
||||||
|
}
|
||||||
|
out << f.size() << endl;
|
||||||
|
for (size_t j = 0; j < f.size(); j++) {
|
||||||
|
out << j << " " << f[j] << endl;
|
||||||
|
}
|
||||||
|
out << endl;
|
||||||
|
}
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||||
|
{
|
||||||
|
string ignoreStr;
|
||||||
|
while (is.peek() == '#' || is.peek() == '\n') {
|
||||||
|
getline (is, ignoreStr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
FactorGraph::containsCycle (void) const
|
||||||
|
{
|
||||||
|
vector<bool> visitedVars (varNodes_.size(), false);
|
||||||
|
vector<bool> visitedFactors (facNodes_.size(), false);
|
||||||
|
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||||
|
int v = varNodes_[i]->getIndex();
|
||||||
|
if (!visitedVars[v]) {
|
||||||
|
if (containsCycle (varNodes_[i], 0, visitedVars, visitedFactors)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
FactorGraph::containsCycle (
|
||||||
|
const VarNode* v,
|
||||||
|
const FacNode* p,
|
||||||
|
vector<bool>& visitedVars,
|
||||||
|
vector<bool>& visitedFactors) const
|
||||||
|
{
|
||||||
|
visitedVars[v->getIndex()] = true;
|
||||||
|
const FacNodes& adjacencies = v->neighbors();
|
||||||
|
for (size_t i = 0; i < adjacencies.size(); i++) {
|
||||||
|
int w = adjacencies[i]->getIndex();
|
||||||
|
if (!visitedFactors[w]) {
|
||||||
|
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (visitedFactors[w] && adjacencies[i] != p) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false; // no cycle detected in this component
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
FactorGraph::containsCycle (
|
||||||
|
const FacNode* v,
|
||||||
|
const VarNode* p,
|
||||||
|
vector<bool>& visitedVars,
|
||||||
|
vector<bool>& visitedFactors) const
|
||||||
|
{
|
||||||
|
visitedFactors[v->getIndex()] = true;
|
||||||
|
const VarNodes& adjacencies = v->neighbors();
|
||||||
|
for (size_t i = 0; i < adjacencies.size(); i++) {
|
||||||
|
int w = adjacencies[i]->getIndex();
|
||||||
|
if (!visitedVars[w]) {
|
||||||
|
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (visitedVars[w] && adjacencies[i] != p) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false; // no cycle detected in this component
|
||||||
|
}
|
||||||
|
|
150
packages/CLPBN/horus2/FactorGraph.h
Normal file
150
packages/CLPBN/horus2/FactorGraph.h
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
#ifndef HORUS_FACTORGRAPH_H
|
||||||
|
#define HORUS_FACTORGRAPH_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "BayesBallGraph.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
class FacNode;
|
||||||
|
|
||||||
|
class VarNode : public Var
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
VarNode (VarId varId, unsigned nrStates,
|
||||||
|
int evidence = Constants::NO_EVIDENCE)
|
||||||
|
: Var (varId, nrStates, evidence) { }
|
||||||
|
|
||||||
|
VarNode (const Var* v) : Var (v) { }
|
||||||
|
|
||||||
|
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
|
||||||
|
|
||||||
|
const FacNodes& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_COPY_AND_ASSIGN (VarNode);
|
||||||
|
|
||||||
|
FacNodes neighs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FacNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FacNode (const Factor& f) : factor_(f), index_(-1) { }
|
||||||
|
|
||||||
|
const Factor& factor (void) const { return factor_; }
|
||||||
|
|
||||||
|
Factor& factor (void) { return factor_; }
|
||||||
|
|
||||||
|
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||||
|
|
||||||
|
const VarNodes& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
|
size_t getIndex (void) const { return index_; }
|
||||||
|
|
||||||
|
void setIndex (size_t index) { index_ = index; }
|
||||||
|
|
||||||
|
string getLabel (void) { return factor_.getLabel(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_COPY_AND_ASSIGN (FacNode);
|
||||||
|
|
||||||
|
VarNodes neighs_;
|
||||||
|
Factor factor_;
|
||||||
|
size_t index_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FactorGraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FactorGraph (void) : bayesFactors_(false) { }
|
||||||
|
|
||||||
|
FactorGraph (const FactorGraph&);
|
||||||
|
|
||||||
|
~FactorGraph (void);
|
||||||
|
|
||||||
|
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||||
|
|
||||||
|
const FacNodes& facNodes (void) const { return facNodes_; }
|
||||||
|
|
||||||
|
void setFactorsAsBayesian (void) { bayesFactors_ = true; }
|
||||||
|
|
||||||
|
bool bayesianFactors (void) const { return bayesFactors_; }
|
||||||
|
|
||||||
|
size_t nrVarNodes (void) const { return varNodes_.size(); }
|
||||||
|
|
||||||
|
size_t nrFacNodes (void) const { return facNodes_.size(); }
|
||||||
|
|
||||||
|
VarNode* getVarNode (VarId vid) const
|
||||||
|
{
|
||||||
|
VarMap::const_iterator it = varMap_.find (vid);
|
||||||
|
return it != varMap_.end() ? it->second : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void readFromUaiFormat (const char*);
|
||||||
|
|
||||||
|
void readFromLibDaiFormat (const char*);
|
||||||
|
|
||||||
|
void addFactor (const Factor& factor);
|
||||||
|
|
||||||
|
void addVarNode (VarNode*);
|
||||||
|
|
||||||
|
void addFacNode (FacNode*);
|
||||||
|
|
||||||
|
void addEdge (VarNode*, FacNode*);
|
||||||
|
|
||||||
|
bool isTree (void) const;
|
||||||
|
|
||||||
|
BayesBallGraph& getStructure (void);
|
||||||
|
|
||||||
|
void print (void) const;
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*) const;
|
||||||
|
|
||||||
|
void exportToUaiFormat (const char*) const;
|
||||||
|
|
||||||
|
void exportToLibDaiFormat (const char*) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||||
|
|
||||||
|
void ignoreLines (std::ifstream&) const;
|
||||||
|
|
||||||
|
bool containsCycle (void) const;
|
||||||
|
|
||||||
|
bool containsCycle (const VarNode*, const FacNode*,
|
||||||
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
|
bool containsCycle (const FacNode*, const VarNode*,
|
||||||
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
|
VarNodes varNodes_;
|
||||||
|
FacNodes facNodes_;
|
||||||
|
|
||||||
|
BayesBallGraph structure_;
|
||||||
|
bool bayesFactors_;
|
||||||
|
|
||||||
|
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||||
|
VarMap varMap_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
struct sortByVarId
|
||||||
|
{
|
||||||
|
bool operator()(VarNode* vn1, VarNode* vn2) {
|
||||||
|
return vn1->varId() < vn2->varId();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif // HORUS_FACTORGRAPH_H
|
||||||
|
|
107
packages/CLPBN/horus2/GroundSolver.cpp
Normal file
107
packages/CLPBN/horus2/GroundSolver.cpp
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
#include "GroundSolver.h"
|
||||||
|
#include "Util.h"
|
||||||
|
#include "BeliefProp.h"
|
||||||
|
#include "CountingBp.h"
|
||||||
|
#include "VarElim.h"
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
GroundSolver::printAnswer (const VarIds& vids)
|
||||||
|
{
|
||||||
|
Vars unobservedVars;
|
||||||
|
VarIds unobservedVids;
|
||||||
|
for (size_t i = 0; i < vids.size(); i++) {
|
||||||
|
VarNode* vn = fg.getVarNode (vids[i]);
|
||||||
|
if (vn->hasEvidence() == false) {
|
||||||
|
unobservedVars.push_back (vn);
|
||||||
|
unobservedVids.push_back (vids[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (unobservedVids.empty() == false) {
|
||||||
|
Params res = solveQuery (unobservedVids);
|
||||||
|
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||||
|
for (size_t i = 0; i < res.size(); i++) {
|
||||||
|
cout << "P(" << stateLines[i] << ") = " ;
|
||||||
|
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
GroundSolver::printAllPosterioris (void)
|
||||||
|
{
|
||||||
|
VarNodes vars = fg.varNodes();
|
||||||
|
std::sort (vars.begin(), vars.end(), sortByVarId());
|
||||||
|
for (size_t i = 0; i < vars.size(); i++) {
|
||||||
|
printAnswer ({vars[i]->varId()});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
GroundSolver::getJointByConditioning (
|
||||||
|
GroundSolverType solverType,
|
||||||
|
FactorGraph fg,
|
||||||
|
const VarIds& jointVarIds) const
|
||||||
|
{
|
||||||
|
VarNodes jointVars;
|
||||||
|
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||||
|
assert (fg.getVarNode (jointVarIds[i]));
|
||||||
|
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
GroundSolver* solver = 0;
|
||||||
|
switch (solverType) {
|
||||||
|
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||||
|
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||||
|
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||||
|
}
|
||||||
|
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||||
|
VarIds observedVids = {jointVars[0]->varId()};
|
||||||
|
|
||||||
|
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||||
|
assert (jointVars[i]->hasEvidence() == false);
|
||||||
|
Params newBeliefs;
|
||||||
|
Vars observedVars;
|
||||||
|
Ranges observedRanges;
|
||||||
|
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||||
|
observedVars.push_back (fg.getVarNode (observedVids[j]));
|
||||||
|
observedRanges.push_back (observedVars.back()->range());
|
||||||
|
}
|
||||||
|
Indexer indexer (observedRanges, false);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||||
|
observedVars[j]->setEvidence (indexer[j]);
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
|
switch (solverType) {
|
||||||
|
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||||
|
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||||
|
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||||
|
}
|
||||||
|
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||||
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
|
newBeliefs.push_back (beliefs[k]);
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
|
||||||
|
int count = -1;
|
||||||
|
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||||
|
if (j % jointVars[i]->range() == 0) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
|
}
|
||||||
|
prevBeliefs = newBeliefs;
|
||||||
|
observedVids.push_back (jointVars[i]->varId());
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
|
return prevBeliefs;
|
||||||
|
}
|
||||||
|
|
36
packages/CLPBN/horus2/GroundSolver.h
Normal file
36
packages/CLPBN/horus2/GroundSolver.h
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
#ifndef HORUS_GROUNDSOLVER_H
|
||||||
|
#define HORUS_GROUNDSOLVER_H
|
||||||
|
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Var.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
class GroundSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||||
|
|
||||||
|
virtual ~GroundSolver() { } // ensure that subclass destructor is called
|
||||||
|
|
||||||
|
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||||
|
|
||||||
|
virtual void printSolverFlags (void) const = 0;
|
||||||
|
|
||||||
|
void printAnswer (const VarIds& vids);
|
||||||
|
|
||||||
|
void printAllPosterioris (void);
|
||||||
|
|
||||||
|
Params getJointByConditioning (GroundSolverType,
|
||||||
|
FactorGraph, const VarIds& jointVarIds) const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const FactorGraph& fg;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_GROUNDSOLVER_H
|
||||||
|
|
146
packages/CLPBN/horus2/Histogram.cpp
Normal file
146
packages/CLPBN/horus2/Histogram.cpp
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "Histogram.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
HistogramSet::HistogramSet (unsigned size, unsigned range)
|
||||||
|
{
|
||||||
|
size_ = size;
|
||||||
|
hist_.resize (range, 0);
|
||||||
|
hist_[0] = size;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
HistogramSet::nextHistogram (void)
|
||||||
|
{
|
||||||
|
for (size_t i = hist_.size() - 1; i-- > 0; ) {
|
||||||
|
if (hist_[i] > 0) {
|
||||||
|
hist_[i] --;
|
||||||
|
hist_[i + 1] = maxCount (i + 1);
|
||||||
|
clearAfter (i + 1);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (std::accumulate (hist_.begin(), hist_.end(), 0)
|
||||||
|
== (int) size_);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
HistogramSet::operator[] (size_t idx) const
|
||||||
|
{
|
||||||
|
assert (idx < hist_.size());
|
||||||
|
return hist_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
HistogramSet::nrHistograms (void) const
|
||||||
|
{
|
||||||
|
return HistogramSet::nrHistograms (size_, hist_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
HistogramSet::reset (void)
|
||||||
|
{
|
||||||
|
std::fill (hist_.begin() + 1, hist_.end(), 0);
|
||||||
|
hist_[0] = size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<Histogram>
|
||||||
|
HistogramSet::getHistograms (unsigned N, unsigned R)
|
||||||
|
{
|
||||||
|
HistogramSet hs (N, R);
|
||||||
|
unsigned H = hs.nrHistograms();
|
||||||
|
vector<Histogram> histograms;
|
||||||
|
histograms.reserve (H);
|
||||||
|
for (unsigned i = 0; i < H; i++) {
|
||||||
|
histograms.push_back (hs.hist_);
|
||||||
|
hs.nextHistogram();
|
||||||
|
}
|
||||||
|
return histograms;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
HistogramSet::nrHistograms (unsigned N, unsigned R)
|
||||||
|
{
|
||||||
|
return Util::nrCombinations (N + R - 1, R - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
size_t
|
||||||
|
HistogramSet::findIndex (
|
||||||
|
const Histogram& h,
|
||||||
|
const vector<Histogram>& hists)
|
||||||
|
{
|
||||||
|
vector<Histogram>::const_iterator it = std::lower_bound (
|
||||||
|
hists.begin(), hists.end(), h, std::greater<Histogram>());
|
||||||
|
assert (it != hists.end() && *it == h);
|
||||||
|
return std::distance (hists.begin(), it);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<double>
|
||||||
|
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||||
|
{
|
||||||
|
HistogramSet hs (N, R);
|
||||||
|
double N_fac = Util::logFactorial (N);
|
||||||
|
unsigned H = hs.nrHistograms();
|
||||||
|
vector<double> numAssigns;
|
||||||
|
numAssigns.reserve (H);
|
||||||
|
for (unsigned h = 0; h < H; h++) {
|
||||||
|
double prod = 0.0;
|
||||||
|
for (unsigned r = 0; r < R; r++) {
|
||||||
|
prod += Util::logFactorial (hs[r]);
|
||||||
|
}
|
||||||
|
double res = N_fac - prod;
|
||||||
|
numAssigns.push_back (Globals::logDomain ? res : std::exp (res));
|
||||||
|
hs.nextHistogram();
|
||||||
|
}
|
||||||
|
return numAssigns;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const HistogramSet& hs)
|
||||||
|
{
|
||||||
|
os << "#" << hs.hist_;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
HistogramSet::maxCount (size_t idx) const
|
||||||
|
{
|
||||||
|
unsigned sum = 0;
|
||||||
|
for (size_t i = 0; i < idx; i++) {
|
||||||
|
sum += hist_[i];
|
||||||
|
}
|
||||||
|
return size_ - sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
HistogramSet::clearAfter (size_t idx)
|
||||||
|
{
|
||||||
|
std::fill (hist_.begin() + idx + 1, hist_.end(), 0);
|
||||||
|
}
|
||||||
|
|
45
packages/CLPBN/horus2/Histogram.h
Normal file
45
packages/CLPBN/horus2/Histogram.h
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
#ifndef HORUS_HISTOGRAM_H
|
||||||
|
#define HORUS_HISTOGRAM_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <ostream>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
typedef vector<unsigned> Histogram;
|
||||||
|
|
||||||
|
class HistogramSet
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
HistogramSet (unsigned, unsigned);
|
||||||
|
|
||||||
|
void nextHistogram (void);
|
||||||
|
|
||||||
|
unsigned operator[] (size_t idx) const;
|
||||||
|
|
||||||
|
unsigned nrHistograms (void) const;
|
||||||
|
|
||||||
|
void reset (void);
|
||||||
|
|
||||||
|
static vector<Histogram> getHistograms (unsigned ,unsigned);
|
||||||
|
|
||||||
|
static unsigned nrHistograms (unsigned, unsigned);
|
||||||
|
|
||||||
|
static size_t findIndex (
|
||||||
|
const Histogram&, const vector<Histogram>&);
|
||||||
|
|
||||||
|
static vector<double> getNumAssigns (unsigned, unsigned);
|
||||||
|
|
||||||
|
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
unsigned maxCount (size_t) const;
|
||||||
|
|
||||||
|
void clearAfter (size_t);
|
||||||
|
|
||||||
|
unsigned size_;
|
||||||
|
Histogram hist_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_HISTOGRAM_H
|
||||||
|
|
87
packages/CLPBN/horus2/Horus.h
Normal file
87
packages/CLPBN/horus2/Horus.h
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
#ifndef HORUS_HORUS_H
|
||||||
|
#define HORUS_HORUS_H
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||||
|
TypeName(const TypeName&); \
|
||||||
|
void operator=(const TypeName&)
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
class Var;
|
||||||
|
class Factor;
|
||||||
|
class VarNode;
|
||||||
|
class FacNode;
|
||||||
|
|
||||||
|
typedef vector<double> Params;
|
||||||
|
typedef unsigned VarId;
|
||||||
|
typedef vector<VarId> VarIds;
|
||||||
|
typedef vector<Var*> Vars;
|
||||||
|
typedef vector<VarNode*> VarNodes;
|
||||||
|
typedef vector<FacNode*> FacNodes;
|
||||||
|
typedef vector<Factor*> Factors;
|
||||||
|
typedef vector<string> States;
|
||||||
|
typedef vector<unsigned> Ranges;
|
||||||
|
typedef unsigned long long ullong;
|
||||||
|
|
||||||
|
|
||||||
|
enum LiftedSolverType
|
||||||
|
{
|
||||||
|
LVE, // generalized counting first-order variable elimination (GC-FOVE)
|
||||||
|
LBP, // lifted first-order belief propagation
|
||||||
|
LKC // lifted first-order knowledge compilation
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
enum GroundSolverType
|
||||||
|
{
|
||||||
|
VE, // variable elimination
|
||||||
|
BP, // belief propagation
|
||||||
|
CBP // counting belief propagation
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace Globals {
|
||||||
|
|
||||||
|
extern bool logDomain;
|
||||||
|
|
||||||
|
// level of debug information
|
||||||
|
extern unsigned verbosity;
|
||||||
|
|
||||||
|
extern LiftedSolverType liftedSolver;
|
||||||
|
extern GroundSolverType groundSolver;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace Constants {
|
||||||
|
|
||||||
|
// show message calculation for belief propagation
|
||||||
|
const bool SHOW_BP_CALCS = false;
|
||||||
|
|
||||||
|
const int NO_EVIDENCE = -1;
|
||||||
|
|
||||||
|
// number of digits to show when printing a parameter
|
||||||
|
const unsigned PRECISION = 6;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace BpOptions
|
||||||
|
{
|
||||||
|
enum Schedule {
|
||||||
|
SEQ_FIXED,
|
||||||
|
SEQ_RANDOM,
|
||||||
|
PARALLEL,
|
||||||
|
MAX_RESIDUAL
|
||||||
|
};
|
||||||
|
extern Schedule schedule;
|
||||||
|
extern double accuracy;
|
||||||
|
extern unsigned maxIter;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // HORUS_HORUS_H
|
||||||
|
|
187
packages/CLPBN/horus2/HorusCli.cpp
Normal file
187
packages/CLPBN/horus2/HorusCli.cpp
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "VarElim.h"
|
||||||
|
#include "BeliefProp.h"
|
||||||
|
#include "CountingBp.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
int readHorusFlags (int, const char* []);
|
||||||
|
void readFactorGraph (FactorGraph&, const char*);
|
||||||
|
VarIds readQueryAndEvidence (FactorGraph&, int, const char* [], int);
|
||||||
|
|
||||||
|
void runSolver (const FactorGraph&, const VarIds&);
|
||||||
|
|
||||||
|
const string USAGE = "usage: ./hcli [HORUS_FLAG=VALUE] \
|
||||||
|
MODEL_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE] ..." ;
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
main (int argc, const char* argv[])
|
||||||
|
{
|
||||||
|
if (argc <= 1) {
|
||||||
|
cerr << "Error: no probabilistic graphical model was given." << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
int idx = readHorusFlags (argc, argv);
|
||||||
|
FactorGraph fg;
|
||||||
|
readFactorGraph (fg, argv[idx]);
|
||||||
|
VarIds queryIds = readQueryAndEvidence (fg, argc, argv, idx + 1);
|
||||||
|
runSolver (fg, queryIds);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
readHorusFlags (int argc, const char* argv[])
|
||||||
|
{
|
||||||
|
int i = 1;
|
||||||
|
for (; i < argc; i++) {
|
||||||
|
const string& arg = argv[i];
|
||||||
|
size_t pos = arg.find ('=');
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
string leftArg = arg.substr (0, pos);
|
||||||
|
string rightArg = arg.substr (pos + 1);
|
||||||
|
if (leftArg.empty()) {
|
||||||
|
cerr << "Error: missing left argument." << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
if (rightArg.empty()) {
|
||||||
|
cerr << "Error: missing right argument." << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
Util::setHorusFlag (leftArg, rightArg);
|
||||||
|
}
|
||||||
|
return i + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
readFactorGraph (FactorGraph& fg, const char* s)
|
||||||
|
{
|
||||||
|
string fileName (s);
|
||||||
|
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||||
|
if (extension == "uai") {
|
||||||
|
fg.readFromUaiFormat (fileName.c_str());
|
||||||
|
} else if (extension == "fg") {
|
||||||
|
fg.readFromLibDaiFormat (fileName.c_str());
|
||||||
|
} else {
|
||||||
|
cerr << "Error: the probabilistic graphical model must be " ;
|
||||||
|
cerr << "defined either in a UAI or libDAI file." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarIds
|
||||||
|
readQueryAndEvidence (
|
||||||
|
FactorGraph& fg,
|
||||||
|
int argc,
|
||||||
|
const char* argv[],
|
||||||
|
int start)
|
||||||
|
{
|
||||||
|
VarIds queryIds;
|
||||||
|
for (int i = start; i < argc; i++) {
|
||||||
|
const string& arg = argv[i];
|
||||||
|
if (arg.find ('=') == std::string::npos) {
|
||||||
|
if (Util::isInteger (arg) == false) {
|
||||||
|
cerr << "Error: `" << arg << "' " ;
|
||||||
|
cerr << "is not a variable id." ;
|
||||||
|
cerr << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
VarId vid = Util::stringToUnsigned (arg);
|
||||||
|
VarNode* queryVar = fg.getVarNode (vid);
|
||||||
|
if (queryVar == false) {
|
||||||
|
cerr << "Error: unknow variable with id " ;
|
||||||
|
cerr << "`" << vid << "'." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
queryIds.push_back (vid);
|
||||||
|
} else {
|
||||||
|
size_t pos = arg.find ('=');
|
||||||
|
string leftArg = arg.substr (0, pos);
|
||||||
|
string rightArg = arg.substr (pos + 1);
|
||||||
|
if (leftArg.empty()) {
|
||||||
|
cerr << "Error: missing left argument." << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
if (Util::isInteger (leftArg) == false) {
|
||||||
|
cerr << "Error: `" << leftArg << "' " ;
|
||||||
|
cerr << "is not a variable id." << endl ;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
VarId vid = Util::stringToUnsigned (leftArg);
|
||||||
|
VarNode* observedVar = fg.getVarNode (vid);
|
||||||
|
if (observedVar == false) {
|
||||||
|
cerr << "Error: unknow variable with id " ;
|
||||||
|
cerr << "`" << vid << "'." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
if (rightArg.empty()) {
|
||||||
|
cerr << "Error: missing right argument." << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
if (Util::isInteger (rightArg) == false) {
|
||||||
|
cerr << "Error: `" << rightArg << "' " ;
|
||||||
|
cerr << "is not a state index." << endl ;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
unsigned stateIdx = Util::stringToUnsigned (rightArg);
|
||||||
|
if (observedVar->isValidState (stateIdx) == false) {
|
||||||
|
cerr << "Error: `" << stateIdx << "' " ;
|
||||||
|
cerr << "is not a valid state index for variable with id " ;
|
||||||
|
cerr << "`" << vid << "'." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
observedVar->setEvidence (stateIdx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return queryIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||||
|
{
|
||||||
|
GroundSolver* solver = 0;
|
||||||
|
switch (Globals::groundSolver) {
|
||||||
|
case GroundSolverType::VE:
|
||||||
|
solver = new VarElim (fg);
|
||||||
|
break;
|
||||||
|
case GroundSolverType::BP:
|
||||||
|
solver = new BeliefProp (fg);
|
||||||
|
break;
|
||||||
|
case GroundSolverType::CBP:
|
||||||
|
solver = new CountingBp (fg);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
solver->printSolverFlags();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
if (queryIds.empty()) {
|
||||||
|
solver->printAllPosterioris();
|
||||||
|
} else {
|
||||||
|
solver->printAnswer (queryIds);
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
|
}
|
||||||
|
|
570
packages/CLPBN/horus2/HorusYap.cpp
Normal file
570
packages/CLPBN/horus2/HorusYap.cpp
Normal file
@ -0,0 +1,570 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include <YapInterface.h>
|
||||||
|
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "LiftedOperations.h"
|
||||||
|
#include "LiftedVe.h"
|
||||||
|
#include "VarElim.h"
|
||||||
|
#include "LiftedBp.h"
|
||||||
|
#include "CountingBp.h"
|
||||||
|
#include "BeliefProp.h"
|
||||||
|
#include "LiftedKc.h"
|
||||||
|
#include "ElimGraph.h"
|
||||||
|
#include "BayesBall.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||||
|
|
||||||
|
Parfactor* readParfactor (YAP_Term);
|
||||||
|
|
||||||
|
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||||
|
|
||||||
|
vector<unsigned> readUnsignedList (YAP_Term list);
|
||||||
|
|
||||||
|
Params readParameters (YAP_Term);
|
||||||
|
|
||||||
|
YAP_Term fillAnswersPrologList (vector<Params>& results);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
createLiftedNetwork (void)
|
||||||
|
{
|
||||||
|
Parfactors parfactors;
|
||||||
|
YAP_Term parfactorList = YAP_ARG1;
|
||||||
|
while (parfactorList != YAP_TermNil()) {
|
||||||
|
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
|
||||||
|
parfactors.push_back (readParfactor (pfTerm));
|
||||||
|
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||||
|
}
|
||||||
|
|
||||||
|
// LiftedUtils::printSymbolDictionary();
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printHeader ("INITIAL PARFACTORS");
|
||||||
|
for (size_t i = 0; i < parfactors.size(); i++) {
|
||||||
|
parfactors[i]->print();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||||
|
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printHeader ("SHATTERED PARFACTORS");
|
||||||
|
pfList->print();
|
||||||
|
}
|
||||||
|
|
||||||
|
// read evidence
|
||||||
|
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||||
|
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
|
||||||
|
|
||||||
|
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
|
||||||
|
|
||||||
|
YAP_Int p = (YAP_Int) (net);
|
||||||
|
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
createGroundNetwork (void)
|
||||||
|
{
|
||||||
|
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
if (factorsType == "bayes") {
|
||||||
|
fg->setFactorsAsBayesian();
|
||||||
|
}
|
||||||
|
YAP_Term factorList = YAP_ARG2;
|
||||||
|
while (factorList != YAP_TermNil()) {
|
||||||
|
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||||
|
// read the var ids
|
||||||
|
VarIds varIds = readUnsignedList (YAP_ArgOfTerm (1, factor));
|
||||||
|
// read the ranges
|
||||||
|
Ranges ranges = readUnsignedList (YAP_ArgOfTerm (2, factor));
|
||||||
|
// read the parameters
|
||||||
|
Params params = readParameters (YAP_ArgOfTerm (3, factor));
|
||||||
|
// read dist id
|
||||||
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (4, factor));
|
||||||
|
fg->addFactor (Factor (varIds, ranges, params, distId));
|
||||||
|
factorList = YAP_TailOfTerm (factorList);
|
||||||
|
}
|
||||||
|
unsigned nrObservedVars = 0;
|
||||||
|
YAP_Term evidenceList = YAP_ARG3;
|
||||||
|
while (evidenceList != YAP_TermNil()) {
|
||||||
|
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
|
||||||
|
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||||
|
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||||
|
assert (fg->getVarNode (vid));
|
||||||
|
fg->getVarNode (vid)->setEvidence (ev);
|
||||||
|
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||||
|
nrObservedVars ++;
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
cout << "factor graph contains " ;
|
||||||
|
cout << fg->nrVarNodes() << " variables " ;
|
||||||
|
cout << "(" << nrObservedVars << " observed) and " ;
|
||||||
|
cout << fg->nrFacNodes() << " factors " << endl;
|
||||||
|
}
|
||||||
|
YAP_Int p = (YAP_Int) (fg);
|
||||||
|
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
runLiftedSolver (void)
|
||||||
|
{
|
||||||
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
ParfactorList pfListCopy (*network->first);
|
||||||
|
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
|
||||||
|
|
||||||
|
LiftedSolver* solver = 0;
|
||||||
|
switch (Globals::liftedSolver) {
|
||||||
|
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
|
||||||
|
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
|
||||||
|
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
solver->printSolverFlags();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
YAP_Term taskList = YAP_ARG2;
|
||||||
|
vector<Params> results;
|
||||||
|
while (taskList != YAP_TermNil()) {
|
||||||
|
Grounds queryVars;
|
||||||
|
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||||
|
while (jointList != YAP_TermNil()) {
|
||||||
|
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
||||||
|
if (YAP_IsAtomTerm (ground)) {
|
||||||
|
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||||
|
queryVars.push_back (Ground (LiftedUtils::getSymbol (name)));
|
||||||
|
} else {
|
||||||
|
assert (YAP_IsApplTerm (ground));
|
||||||
|
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||||
|
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||||
|
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||||
|
Symbol functor = LiftedUtils::getSymbol (name);
|
||||||
|
Symbols args;
|
||||||
|
for (unsigned i = 1; i <= arity; i++) {
|
||||||
|
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||||
|
assert (YAP_IsAtomTerm (ti));
|
||||||
|
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||||
|
args.push_back (LiftedUtils::getSymbol (arg));
|
||||||
|
}
|
||||||
|
queryVars.push_back (Ground (functor, args));
|
||||||
|
}
|
||||||
|
jointList = YAP_TailOfTerm (jointList);
|
||||||
|
}
|
||||||
|
results.push_back (solver->solveQuery (queryVars));
|
||||||
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete solver;
|
||||||
|
|
||||||
|
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
runGroundSolver (void)
|
||||||
|
{
|
||||||
|
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
|
||||||
|
vector<VarIds> tasks;
|
||||||
|
YAP_Term taskList = YAP_ARG2;
|
||||||
|
while (taskList != YAP_TermNil()) {
|
||||||
|
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
|
||||||
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* mfg = fg;
|
||||||
|
if (fg->bayesianFactors()) {
|
||||||
|
std::set<VarId> vids;
|
||||||
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
|
Util::addToSet (vids, tasks[i]);
|
||||||
|
}
|
||||||
|
mfg = BayesBall::getMinimalFactorGraph (
|
||||||
|
*fg, VarIds (vids.begin(), vids.end()));
|
||||||
|
}
|
||||||
|
|
||||||
|
GroundSolver* solver = 0;
|
||||||
|
CountingBp::checkForIdenticalFactors = false;
|
||||||
|
switch (Globals::groundSolver) {
|
||||||
|
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
|
||||||
|
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
|
||||||
|
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
solver->printSolverFlags();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<Params> results;
|
||||||
|
results.reserve (tasks.size());
|
||||||
|
for (size_t i = 0; i < tasks.size(); i++) {
|
||||||
|
results.push_back (solver->solveQuery (tasks[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
delete solver;
|
||||||
|
if (fg->bayesianFactors()) {
|
||||||
|
delete mfg;
|
||||||
|
}
|
||||||
|
|
||||||
|
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
setParfactorsParams (void)
|
||||||
|
{
|
||||||
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
ParfactorList* pfList = network->first;
|
||||||
|
YAP_Term distIdsList = YAP_ARG2;
|
||||||
|
YAP_Term paramsList = YAP_ARG3;
|
||||||
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
|
while (distIdsList != YAP_TermNil()) {
|
||||||
|
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||||
|
YAP_HeadOfTerm (distIdsList));
|
||||||
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
|
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
|
||||||
|
distIdsList = YAP_TailOfTerm (distIdsList);
|
||||||
|
paramsList = YAP_TailOfTerm (paramsList);
|
||||||
|
}
|
||||||
|
ParfactorList::iterator it = pfList->begin();
|
||||||
|
while (it != pfList->end()) {
|
||||||
|
assert (Util::contains (paramsMap, (*it)->distId()));
|
||||||
|
(*it)->setParams (paramsMap[(*it)->distId()]);
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
setFactorsParams (void)
|
||||||
|
{
|
||||||
|
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
YAP_Term distIdsList = YAP_ARG2;
|
||||||
|
YAP_Term paramsList = YAP_ARG3;
|
||||||
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
|
while (distIdsList != YAP_TermNil()) {
|
||||||
|
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||||
|
YAP_HeadOfTerm (distIdsList));
|
||||||
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
|
paramsMap[distId] = readParameters (YAP_HeadOfTerm (paramsList));
|
||||||
|
distIdsList = YAP_TailOfTerm (distIdsList);
|
||||||
|
paramsList = YAP_TailOfTerm (paramsList);
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg->facNodes();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
|
assert (Util::contains (paramsMap, distId));
|
||||||
|
facNodes[i]->factor().setParams (paramsMap[distId]);
|
||||||
|
}
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
setVarsInformation (void)
|
||||||
|
{
|
||||||
|
Var::clearVarsInfo();
|
||||||
|
vector<string> labels;
|
||||||
|
YAP_Term labelsL = YAP_ARG1;
|
||||||
|
while (labelsL != YAP_TermNil()) {
|
||||||
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
|
||||||
|
labels.push_back ((char*) YAP_AtomName (atom));
|
||||||
|
labelsL = YAP_TailOfTerm (labelsL);
|
||||||
|
}
|
||||||
|
unsigned count = 0;
|
||||||
|
YAP_Term stateNamesL = YAP_ARG2;
|
||||||
|
while (stateNamesL != YAP_TermNil()) {
|
||||||
|
States states;
|
||||||
|
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
|
||||||
|
while (namesL != YAP_TermNil()) {
|
||||||
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
|
||||||
|
states.push_back ((char*) YAP_AtomName (atom));
|
||||||
|
namesL = YAP_TailOfTerm (namesL);
|
||||||
|
}
|
||||||
|
Var::addVarInfo (count, labels[count], states);
|
||||||
|
count ++;
|
||||||
|
stateNamesL = YAP_TailOfTerm (stateNamesL);
|
||||||
|
}
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
setHorusFlag (void)
|
||||||
|
{
|
||||||
|
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||||
|
string value;
|
||||||
|
if (key == "verbosity") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << (int) YAP_IntOfTerm (YAP_ARG2);
|
||||||
|
ss >> value;
|
||||||
|
} else if (key == "accuracy") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
|
||||||
|
ss >> value;
|
||||||
|
} else if (key == "max_iter") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << (int) YAP_IntOfTerm (YAP_ARG2);
|
||||||
|
ss >> value;
|
||||||
|
} else {
|
||||||
|
value = ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||||
|
}
|
||||||
|
return Util::setHorusFlag (key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
freeGroundNetwork (void)
|
||||||
|
{
|
||||||
|
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
freeLiftedNetwork (void)
|
||||||
|
{
|
||||||
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
delete network->first;
|
||||||
|
delete network->second;
|
||||||
|
delete network;
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor*
|
||||||
|
readParfactor (YAP_Term pfTerm)
|
||||||
|
{
|
||||||
|
// read dist id
|
||||||
|
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||||
|
|
||||||
|
// read the ranges
|
||||||
|
Ranges ranges;
|
||||||
|
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
|
||||||
|
while (rangeList != YAP_TermNil()) {
|
||||||
|
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||||
|
ranges.push_back (range);
|
||||||
|
rangeList = YAP_TailOfTerm (rangeList);
|
||||||
|
}
|
||||||
|
|
||||||
|
// read parametric random vars
|
||||||
|
ProbFormulas formulas;
|
||||||
|
unsigned count = 0;
|
||||||
|
unordered_map<YAP_Term, LogVar> lvMap;
|
||||||
|
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||||
|
while (pvList != YAP_TermNil()) {
|
||||||
|
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||||
|
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||||
|
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||||
|
Symbol functor = LiftedUtils::getSymbol (name);
|
||||||
|
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||||
|
} else {
|
||||||
|
LogVars logVars;
|
||||||
|
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||||
|
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||||
|
Symbol functor = LiftedUtils::getSymbol (name);
|
||||||
|
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||||
|
for (unsigned i = 1; i <= arity; i++) {
|
||||||
|
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||||
|
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||||
|
if (it != lvMap.end()) {
|
||||||
|
logVars.push_back (it->second);
|
||||||
|
} else {
|
||||||
|
unsigned newLv = lvMap.size();
|
||||||
|
lvMap[ti] = newLv;
|
||||||
|
logVars.push_back (newLv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
pvList = YAP_TailOfTerm (pvList);
|
||||||
|
}
|
||||||
|
|
||||||
|
// read the parameters
|
||||||
|
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||||
|
|
||||||
|
// read the constraint
|
||||||
|
Tuples tuples;
|
||||||
|
if (lvMap.size() >= 1) {
|
||||||
|
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
|
||||||
|
while (tupleList != YAP_TermNil()) {
|
||||||
|
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||||
|
assert (YAP_IsApplTerm (term));
|
||||||
|
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
||||||
|
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||||
|
assert (lvMap.size() == arity);
|
||||||
|
Tuple tuple (arity);
|
||||||
|
for (unsigned i = 1; i <= arity; i++) {
|
||||||
|
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||||
|
if (YAP_IsAtomTerm (ti) == false) {
|
||||||
|
cerr << "Error: the constraint contains free variables." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||||
|
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||||
|
}
|
||||||
|
tuples.push_back (tuple);
|
||||||
|
tupleList = YAP_TailOfTerm (tupleList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new Parfactor (formulas, params, tuples, distId);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
readLiftedEvidence (
|
||||||
|
YAP_Term observedList,
|
||||||
|
ObservedFormulas& obsFormulas)
|
||||||
|
{
|
||||||
|
while (observedList != YAP_TermNil()) {
|
||||||
|
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||||
|
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||||
|
Symbol functor;
|
||||||
|
Symbols args;
|
||||||
|
if (YAP_IsAtomTerm (ground)) {
|
||||||
|
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||||
|
functor = LiftedUtils::getSymbol (name);
|
||||||
|
} else {
|
||||||
|
assert (YAP_IsApplTerm (ground));
|
||||||
|
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||||
|
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||||
|
functor = LiftedUtils::getSymbol (name);
|
||||||
|
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||||
|
for (unsigned i = 1; i <= arity; i++) {
|
||||||
|
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||||
|
assert (YAP_IsAtomTerm (ti));
|
||||||
|
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||||
|
args.push_back (LiftedUtils::getSymbol (arg));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||||
|
bool found = false;
|
||||||
|
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||||
|
if (obsFormulas[i].functor() == functor &&
|
||||||
|
obsFormulas[i].arity() == args.size() &&
|
||||||
|
obsFormulas[i].evidence() == evidence) {
|
||||||
|
obsFormulas[i].addTuple (args);
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (found == false) {
|
||||||
|
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||||
|
}
|
||||||
|
observedList = YAP_TailOfTerm (observedList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<unsigned>
|
||||||
|
readUnsignedList (YAP_Term list)
|
||||||
|
{
|
||||||
|
vector<unsigned> vec;
|
||||||
|
while (list != YAP_TermNil()) {
|
||||||
|
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
|
||||||
|
list = YAP_TailOfTerm (list);
|
||||||
|
}
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
readParameters (YAP_Term paramL)
|
||||||
|
{
|
||||||
|
Params params;
|
||||||
|
assert (YAP_IsPairTerm (paramL));
|
||||||
|
while (paramL != YAP_TermNil()) {
|
||||||
|
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||||
|
paramL = YAP_TailOfTerm (paramL);
|
||||||
|
}
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::log (params);
|
||||||
|
}
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
YAP_Term
|
||||||
|
fillAnswersPrologList (vector<Params>& results)
|
||||||
|
{
|
||||||
|
YAP_Term list = YAP_TermNil();
|
||||||
|
for (size_t i = results.size(); i-- > 0; ) {
|
||||||
|
const Params& beliefs = results[i];
|
||||||
|
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||||
|
for (size_t j = beliefs.size(); j-- > 0; ) {
|
||||||
|
YAP_Int sl1 = YAP_InitSlot (list);
|
||||||
|
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||||
|
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||||
|
list = YAP_GetFromSlot (sl1);
|
||||||
|
YAP_RecoverSlots (1);
|
||||||
|
}
|
||||||
|
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||||
|
}
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
extern "C" void
|
||||||
|
init_predicates (void)
|
||||||
|
{
|
||||||
|
YAP_UserCPredicate ("cpp_create_lifted_network",
|
||||||
|
createLiftedNetwork, 3);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_create_ground_network",
|
||||||
|
createGroundNetwork, 4);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_run_lifted_solver",
|
||||||
|
runLiftedSolver, 3);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_run_ground_solver",
|
||||||
|
runGroundSolver, 3);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_set_parfactors_params",
|
||||||
|
setParfactorsParams, 3);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_set_factors_params",
|
||||||
|
setFactorsParams, 3);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_set_vars_information",
|
||||||
|
setVarsInformation, 2);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_set_horus_flag",
|
||||||
|
setHorusFlag, 2);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_free_lifted_network",
|
||||||
|
freeLiftedNetwork, 1);
|
||||||
|
|
||||||
|
YAP_UserCPredicate ("cpp_free_ground_network",
|
||||||
|
freeGroundNetwork, 1);
|
||||||
|
}
|
||||||
|
|
258
packages/CLPBN/horus2/Indexer.h
Normal file
258
packages/CLPBN/horus2/Indexer.h
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
#ifndef HORUS_INDEXER_H
|
||||||
|
#define HORUS_INDEXER_H
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
class Indexer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Indexer (const Ranges& ranges, bool calcOffsets = true)
|
||||||
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||||
|
size_(Util::sizeExpected (ranges))
|
||||||
|
{
|
||||||
|
if (calcOffsets) {
|
||||||
|
calculateOffsets();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void increment (void)
|
||||||
|
{
|
||||||
|
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||||
|
indices_[i] ++;
|
||||||
|
if (indices_[i] != ranges_[i]) {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
indices_[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
index_ ++;
|
||||||
|
}
|
||||||
|
|
||||||
|
void incrementDimension (size_t dim)
|
||||||
|
{
|
||||||
|
assert (dim < ranges_.size());
|
||||||
|
assert (ranges_.size() == offsets_.size());
|
||||||
|
assert (indices_[dim] < ranges_[dim]);
|
||||||
|
indices_[dim] ++;
|
||||||
|
index_ += offsets_[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
void incrementExceptDimension (size_t dim)
|
||||||
|
{
|
||||||
|
assert (ranges_.size() == offsets_.size());
|
||||||
|
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||||
|
if (i != dim) {
|
||||||
|
indices_[i] ++;
|
||||||
|
index_ += offsets_[i];
|
||||||
|
if (indices_[i] != ranges_[i]) {
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
indices_[i] = 0;
|
||||||
|
index_ -= offsets_[i] * ranges_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
index_ = size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
Indexer& operator++ (void)
|
||||||
|
{
|
||||||
|
increment();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator size_t (void) const
|
||||||
|
{
|
||||||
|
return index_;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned operator[] (size_t dim) const
|
||||||
|
{
|
||||||
|
assert (valid());
|
||||||
|
assert (dim < ranges_.size());
|
||||||
|
return indices_[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
bool valid (void) const
|
||||||
|
{
|
||||||
|
return index_ < size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset (void)
|
||||||
|
{
|
||||||
|
std::fill (indices_.begin(), indices_.end(), 0);
|
||||||
|
index_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void resetDimension (size_t dim)
|
||||||
|
{
|
||||||
|
indices_[dim] = 0;
|
||||||
|
index_ -= offsets_[dim] * ranges_[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size (void) const
|
||||||
|
{
|
||||||
|
return size_ ;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::ostream& operator<< (std::ostream&, const Indexer&);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void calculateOffsets (void)
|
||||||
|
{
|
||||||
|
size_t prod = 1;
|
||||||
|
offsets_.resize (ranges_.size());
|
||||||
|
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||||
|
offsets_[i] = prod;
|
||||||
|
prod *= ranges_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t index_;
|
||||||
|
Ranges indices_;
|
||||||
|
const Ranges& ranges_;
|
||||||
|
size_t size_;
|
||||||
|
vector<size_t> offsets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline std::ostream&
|
||||||
|
operator<< (std::ostream& os, const Indexer& indexer)
|
||||||
|
{
|
||||||
|
os << "(" ;
|
||||||
|
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||||
|
os << ") " ;
|
||||||
|
os << indexer.indices_;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class MapIndexer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
MapIndexer (const Ranges& ranges, const vector<bool>& mask)
|
||||||
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||||
|
valid_(true)
|
||||||
|
{
|
||||||
|
size_t prod = 1;
|
||||||
|
offsets_.resize (ranges.size(), 0);
|
||||||
|
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||||
|
if (mask[i]) {
|
||||||
|
offsets_[i] = prod;
|
||||||
|
prod *= ranges[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (ranges.size() == mask.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
MapIndexer (const Ranges& ranges, size_t dim)
|
||||||
|
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||||
|
valid_(true)
|
||||||
|
{
|
||||||
|
size_t prod = 1;
|
||||||
|
offsets_.resize (ranges.size(), 0);
|
||||||
|
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||||
|
if (i != dim) {
|
||||||
|
offsets_[i] = prod;
|
||||||
|
prod *= ranges[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
MapIndexer (
|
||||||
|
const vector<T>& allArgs,
|
||||||
|
const Ranges& allRanges,
|
||||||
|
const vector<T>& wantedArgs,
|
||||||
|
const Ranges& wantedRanges)
|
||||||
|
: index_(0), indices_(allArgs.size(), 0), ranges_(allRanges),
|
||||||
|
valid_(true)
|
||||||
|
{
|
||||||
|
size_t prod = 1;
|
||||||
|
vector<size_t> offsets (wantedRanges.size());
|
||||||
|
for (size_t i = wantedRanges.size(); i-- > 0; ) {
|
||||||
|
offsets[i] = prod;
|
||||||
|
prod *= wantedRanges[i];
|
||||||
|
}
|
||||||
|
offsets_.reserve (allArgs.size());
|
||||||
|
for (size_t i = 0; i < allArgs.size(); i++) {
|
||||||
|
size_t idx = Util::indexOf (wantedArgs, allArgs[i]);
|
||||||
|
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MapIndexer& operator++ (void)
|
||||||
|
{
|
||||||
|
assert (valid_);
|
||||||
|
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||||
|
indices_[i] ++;
|
||||||
|
index_ += offsets_[i];
|
||||||
|
if (indices_[i] != ranges_[i]) {
|
||||||
|
return *this;
|
||||||
|
} else {
|
||||||
|
indices_[i] = 0;
|
||||||
|
index_ -= offsets_[i] * ranges_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
valid_ = false;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator size_t (void) const
|
||||||
|
{
|
||||||
|
assert (valid());
|
||||||
|
return index_;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned operator[] (size_t dim) const
|
||||||
|
{
|
||||||
|
assert (valid());
|
||||||
|
assert (dim < ranges_.size());
|
||||||
|
return indices_[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
bool valid (void) const
|
||||||
|
{
|
||||||
|
return valid_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset (void)
|
||||||
|
{
|
||||||
|
std::fill (indices_.begin(), indices_.end(), 0);
|
||||||
|
index_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t index_;
|
||||||
|
Ranges indices_;
|
||||||
|
const Ranges& ranges_;
|
||||||
|
bool valid_;
|
||||||
|
vector<size_t> offsets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline std::ostream&
|
||||||
|
operator<< (std::ostream &os, const MapIndexer& indexer)
|
||||||
|
{
|
||||||
|
os << "(" ;
|
||||||
|
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||||
|
os << ") " ;
|
||||||
|
os << indexer.indices_;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif // HORUS_INDEXER_H
|
||||||
|
|
234
packages/CLPBN/horus2/LiftedBp.cpp
Normal file
234
packages/CLPBN/horus2/LiftedBp.cpp
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
#include "LiftedBp.h"
|
||||||
|
#include "WeightedBp.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "LiftedOperations.h"
|
||||||
|
|
||||||
|
|
||||||
|
LiftedBp::LiftedBp (const ParfactorList& parfactorList)
|
||||||
|
: LiftedSolver (parfactorList)
|
||||||
|
{
|
||||||
|
refineParfactors();
|
||||||
|
createFactorGraph();
|
||||||
|
solver_ = new WeightedBp (*fg_, getWeights());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LiftedBp::~LiftedBp (void)
|
||||||
|
{
|
||||||
|
delete solver_;
|
||||||
|
delete fg_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
LiftedBp::solveQuery (const Grounds& query)
|
||||||
|
{
|
||||||
|
assert (query.empty() == false);
|
||||||
|
Params res;
|
||||||
|
vector<PrvGroup> groups = getQueryGroups (query);
|
||||||
|
if (query.size() == 1) {
|
||||||
|
res = solver_->getPosterioriOf (groups[0]);
|
||||||
|
} else {
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
size_t idx = pfList_.size();
|
||||||
|
size_t count = 0;
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if ((*it)->containsGrounds (query)) {
|
||||||
|
idx = count;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
++ count;
|
||||||
|
}
|
||||||
|
if (idx == pfList_.size()) {
|
||||||
|
res = getJointByConditioning (pfList_, query);
|
||||||
|
} else {
|
||||||
|
VarIds queryVids;
|
||||||
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
|
queryVids.push_back (groups[i]);
|
||||||
|
}
|
||||||
|
res = solver_->getFactorJoint (fg_->facNodes()[idx], queryVids);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedBp::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "lifted bp [" ;
|
||||||
|
ss << "schedule=" ;
|
||||||
|
typedef BpOptions::Schedule Sch;
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
|
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
|
case Sch::PARALLEL: ss << "parallel"; break;
|
||||||
|
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
|
}
|
||||||
|
ss << ",max_iter=" << BpOptions::maxIter;
|
||||||
|
ss << ",accuracy=" << BpOptions::accuracy;
|
||||||
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedBp::refineParfactors (void)
|
||||||
|
{
|
||||||
|
pfList_ = parfactorList;
|
||||||
|
while (iterate() == false);
|
||||||
|
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printHeader ("AFTER REFINEMENT");
|
||||||
|
pfList_.print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
LiftedBp::iterate (void)
|
||||||
|
{
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
const ProbFormulas& args = (*it)->arguments();
|
||||||
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
|
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||||
|
if ((*it)->constr()->isCountNormalized (lvs) == false) {
|
||||||
|
Parfactors pfs = LiftedOperations::countNormalize (*it, lvs);
|
||||||
|
it = pfList_.removeAndDelete (it);
|
||||||
|
pfList_.add (pfs);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<PrvGroup>
|
||||||
|
LiftedBp::getQueryGroups (const Grounds& query)
|
||||||
|
{
|
||||||
|
vector<PrvGroup> queryGroups;
|
||||||
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
for (; it != pfList_.end(); ++it) {
|
||||||
|
if ((*it)->containsGround (query[i])) {
|
||||||
|
queryGroups.push_back ((*it)->findGroup (query[i]));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (queryGroups.size() == query.size());
|
||||||
|
return queryGroups;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedBp::createFactorGraph (void)
|
||||||
|
{
|
||||||
|
fg_ = new FactorGraph();
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
for (; it != pfList_.end(); ++it) {
|
||||||
|
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||||
|
VarIds varIds;
|
||||||
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
|
varIds.push_back (groups[i]);
|
||||||
|
}
|
||||||
|
fg_->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<vector<unsigned>>
|
||||||
|
LiftedBp::getWeights (void) const
|
||||||
|
{
|
||||||
|
vector<vector<unsigned>> weights;
|
||||||
|
weights.reserve (pfList_.size());
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
for (; it != pfList_.end(); ++it) {
|
||||||
|
const ProbFormulas& args = (*it)->arguments();
|
||||||
|
weights.push_back ({ });
|
||||||
|
weights.back().reserve (args.size());
|
||||||
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
|
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||||
|
weights.back().push_back ((*it)->constr()->getConditionalCount (lvs));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
LiftedBp::rangeOfGround (const Ground& gr)
|
||||||
|
{
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if ((*it)->containsGround (gr)) {
|
||||||
|
PrvGroup prvGroup = (*it)->findGroup (gr);
|
||||||
|
return (*it)->range ((*it)->indexOfGroup (prvGroup));
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return std::numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
LiftedBp::getJointByConditioning (
|
||||||
|
const ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
LiftedBp solver (pfList);
|
||||||
|
Params prevBeliefs = solver.solveQuery ({query[0]});
|
||||||
|
Grounds obsGrounds = {query[0]};
|
||||||
|
for (size_t i = 1; i < query.size(); i++) {
|
||||||
|
Params newBeliefs;
|
||||||
|
vector<ObservedFormula> obsFs;
|
||||||
|
Ranges obsRanges;
|
||||||
|
for (size_t j = 0; j < obsGrounds.size(); j++) {
|
||||||
|
obsFs.push_back (ObservedFormula (
|
||||||
|
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
|
||||||
|
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
|
||||||
|
}
|
||||||
|
Indexer indexer (obsRanges, false);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
for (size_t j = 0; j < obsFs.size(); j++) {
|
||||||
|
obsFs[j].setEvidence (indexer[j]);
|
||||||
|
}
|
||||||
|
ParfactorList tempPfList (pfList);
|
||||||
|
LiftedOperations::absorveEvidence (tempPfList, obsFs);
|
||||||
|
LiftedBp solver (tempPfList);
|
||||||
|
Params beliefs = solver.solveQuery ({query[i]});
|
||||||
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
|
newBeliefs.push_back (beliefs[k]);
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
int count = -1;
|
||||||
|
unsigned range = rangeOfGround (query[i]);
|
||||||
|
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||||
|
if (j % range == 0) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
|
}
|
||||||
|
prevBeliefs = newBeliefs;
|
||||||
|
obsGrounds.push_back (query[i]);
|
||||||
|
}
|
||||||
|
return prevBeliefs;
|
||||||
|
}
|
||||||
|
|
43
packages/CLPBN/horus2/LiftedBp.h
Normal file
43
packages/CLPBN/horus2/LiftedBp.h
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
#ifndef HORUS_LIFTEDBP_H
|
||||||
|
#define HORUS_LIFTEDBP_H
|
||||||
|
|
||||||
|
#include "LiftedSolver.h"
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
class FactorGraph;
|
||||||
|
class WeightedBp;
|
||||||
|
|
||||||
|
class LiftedBp : public LiftedSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LiftedBp (const ParfactorList& pfList);
|
||||||
|
|
||||||
|
~LiftedBp (void);
|
||||||
|
|
||||||
|
Params solveQuery (const Grounds&);
|
||||||
|
|
||||||
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void refineParfactors (void);
|
||||||
|
|
||||||
|
bool iterate (void);
|
||||||
|
|
||||||
|
vector<PrvGroup> getQueryGroups (const Grounds&);
|
||||||
|
|
||||||
|
void createFactorGraph (void);
|
||||||
|
|
||||||
|
vector<vector<unsigned>> getWeights (void) const;
|
||||||
|
|
||||||
|
unsigned rangeOfGround (const Ground&);
|
||||||
|
|
||||||
|
Params getJointByConditioning (const ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
|
ParfactorList pfList_;
|
||||||
|
WeightedBp* solver_;
|
||||||
|
FactorGraph* fg_;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_LIFTEDBP_H
|
||||||
|
|
1309
packages/CLPBN/horus2/LiftedKc.cpp
Normal file
1309
packages/CLPBN/horus2/LiftedKc.cpp
Normal file
File diff suppressed because it is too large
Load Diff
300
packages/CLPBN/horus2/LiftedKc.h
Normal file
300
packages/CLPBN/horus2/LiftedKc.h
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
#ifndef HORUS_LIFTEDKC_H
|
||||||
|
#define HORUS_LIFTEDKC_H
|
||||||
|
|
||||||
|
|
||||||
|
#include "LiftedWCNF.h"
|
||||||
|
#include "LiftedSolver.h"
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
|
||||||
|
enum CircuitNodeType {
|
||||||
|
OR_NODE,
|
||||||
|
AND_NODE,
|
||||||
|
SET_OR_NODE,
|
||||||
|
SET_AND_NODE,
|
||||||
|
INC_EXC_NODE,
|
||||||
|
LEAF_NODE,
|
||||||
|
SMOOTH_NODE,
|
||||||
|
TRUE_NODE,
|
||||||
|
COMPILATION_FAILED_NODE
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CircuitNode (void) { }
|
||||||
|
|
||||||
|
virtual ~CircuitNode (void) { }
|
||||||
|
|
||||||
|
virtual double weight (void) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class OrNode : public CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
OrNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||||
|
|
||||||
|
~OrNode (void);
|
||||||
|
|
||||||
|
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||||
|
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
CircuitNode* leftBranch_;
|
||||||
|
CircuitNode* rightBranch_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class AndNode : public CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
AndNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||||
|
|
||||||
|
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
|
||||||
|
: CircuitNode(), leftBranch_(leftBranch), rightBranch_(rightBranch) { }
|
||||||
|
|
||||||
|
~AndNode (void);
|
||||||
|
|
||||||
|
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||||
|
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
CircuitNode* leftBranch_;
|
||||||
|
CircuitNode* rightBranch_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SetOrNode : public CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
SetOrNode (unsigned nrGroundings)
|
||||||
|
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||||
|
|
||||||
|
~SetOrNode (void);
|
||||||
|
|
||||||
|
CircuitNode** follow (void) { return &follow_; }
|
||||||
|
|
||||||
|
static unsigned nrPositives (void) { return nrPos_; }
|
||||||
|
|
||||||
|
static unsigned nrNegatives (void) { return nrNeg_; }
|
||||||
|
|
||||||
|
static bool isSet (void) { return nrPos_ >= 0; }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
CircuitNode* follow_;
|
||||||
|
unsigned nrGroundings_;
|
||||||
|
static int nrPos_;
|
||||||
|
static int nrNeg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SetAndNode : public CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
SetAndNode (unsigned nrGroundings)
|
||||||
|
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||||
|
|
||||||
|
~SetAndNode (void);
|
||||||
|
|
||||||
|
CircuitNode** follow (void) { return &follow_; }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
CircuitNode* follow_;
|
||||||
|
unsigned nrGroundings_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class IncExcNode : public CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
IncExcNode (void)
|
||||||
|
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
|
||||||
|
|
||||||
|
~IncExcNode (void);
|
||||||
|
|
||||||
|
CircuitNode** plus1Branch (void) { return &plus1Branch_; }
|
||||||
|
CircuitNode** plus2Branch (void) { return &plus2Branch_; }
|
||||||
|
CircuitNode** minusBranch (void) { return &minusBranch_; }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
CircuitNode* plus1Branch_;
|
||||||
|
CircuitNode* plus2Branch_;
|
||||||
|
CircuitNode* minusBranch_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LeafNode : public CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
|
||||||
|
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
|
||||||
|
|
||||||
|
~LeafNode (void);
|
||||||
|
|
||||||
|
const Clause* clause (void) const { return clause_; }
|
||||||
|
|
||||||
|
Clause* clause (void) { return clause_; }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Clause* clause_;
|
||||||
|
const LiftedWCNF& lwcnf_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SmoothNode : public CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
|
||||||
|
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
|
||||||
|
|
||||||
|
~SmoothNode (void);
|
||||||
|
|
||||||
|
const Clauses& clauses (void) const { return clauses_; }
|
||||||
|
|
||||||
|
Clauses clauses (void) { return clauses_; }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Clauses clauses_;
|
||||||
|
const LiftedWCNF& lwcnf_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TrueNode : public CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
TrueNode (void) : CircuitNode() { }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CompilationFailedNode : public CircuitNode
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CompilationFailedNode (void) : CircuitNode() { }
|
||||||
|
|
||||||
|
double weight (void) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LiftedCircuit
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LiftedCircuit (const LiftedWCNF* lwcnf);
|
||||||
|
|
||||||
|
~LiftedCircuit (void);
|
||||||
|
|
||||||
|
bool isCompilationSucceeded (void) const;
|
||||||
|
|
||||||
|
double getWeightedModelCount (void) const;
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*);
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
void compile (CircuitNode** follow, Clauses& clauses);
|
||||||
|
|
||||||
|
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
|
||||||
|
|
||||||
|
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
|
||||||
|
|
||||||
|
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
|
||||||
|
|
||||||
|
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
|
||||||
|
|
||||||
|
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
|
||||||
|
|
||||||
|
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
|
||||||
|
LogVars& rootLogVars);
|
||||||
|
|
||||||
|
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
|
||||||
|
|
||||||
|
void shatterCountedLogVars (Clauses& clauses);
|
||||||
|
|
||||||
|
bool shatterCountedLogVarsAux (Clauses& clauses);
|
||||||
|
|
||||||
|
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
|
||||||
|
|
||||||
|
bool independentClause (Clause& clause, Clauses& otherClauses) const;
|
||||||
|
|
||||||
|
bool independentLiteral (const Literal& lit,
|
||||||
|
const Literals& otherLits) const;
|
||||||
|
|
||||||
|
LitLvTypesSet smoothCircuit (CircuitNode* node);
|
||||||
|
|
||||||
|
void createSmoothNode (const LitLvTypesSet& lids,
|
||||||
|
CircuitNode** prev);
|
||||||
|
|
||||||
|
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
|
||||||
|
|
||||||
|
bool containsTypes (const LogVarTypes& typesA,
|
||||||
|
const LogVarTypes& typesB) const;
|
||||||
|
|
||||||
|
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
|
||||||
|
|
||||||
|
void exportToGraphViz (CircuitNode* node, ofstream&);
|
||||||
|
|
||||||
|
void printClauses (CircuitNode* node, ofstream&,
|
||||||
|
string extraOptions = "");
|
||||||
|
|
||||||
|
string escapeNode (const CircuitNode* node) const;
|
||||||
|
|
||||||
|
string getExplanationString (CircuitNode* node);
|
||||||
|
|
||||||
|
CircuitNode* root_;
|
||||||
|
const LiftedWCNF* lwcnf_;
|
||||||
|
bool compilationSucceeded_;
|
||||||
|
Clauses backupClauses_;
|
||||||
|
unordered_map<CircuitNode*, Clauses> originClausesMap_;
|
||||||
|
unordered_map<CircuitNode*, string> explanationMap_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LiftedKc : public LiftedSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LiftedKc (const ParfactorList& pfList)
|
||||||
|
: LiftedSolver(pfList) { }
|
||||||
|
|
||||||
|
~LiftedKc (void);
|
||||||
|
|
||||||
|
Params solveQuery (const Grounds&);
|
||||||
|
|
||||||
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
LiftedWCNF* lwcnf_;
|
||||||
|
LiftedCircuit* circuit_;
|
||||||
|
ParfactorList pfList_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_LIFTEDKC_H
|
||||||
|
|
271
packages/CLPBN/horus2/LiftedOperations.cpp
Normal file
271
packages/CLPBN/horus2/LiftedOperations.cpp
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
#include "LiftedOperations.h"
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedOperations::shatterAgainstQuery (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
|
if (query[i].isAtom()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
bool found = false;
|
||||||
|
Parfactors newPfs;
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
if ((*it)->containsGround (query[i])) {
|
||||||
|
found = true;
|
||||||
|
std::pair<ConstraintTree*, ConstraintTree*> split;
|
||||||
|
LogVars queryLvs (
|
||||||
|
(*it)->constr()->logVars().begin(),
|
||||||
|
(*it)->constr()->logVars().begin() + query[i].arity());
|
||||||
|
split = (*it)->constr()->split (query[i].args());
|
||||||
|
ConstraintTree* commCt = split.first;
|
||||||
|
ConstraintTree* exclCt = split.second;
|
||||||
|
newPfs.push_back (new Parfactor (*it, commCt));
|
||||||
|
if (exclCt->empty() == false) {
|
||||||
|
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||||
|
} else {
|
||||||
|
delete exclCt;
|
||||||
|
}
|
||||||
|
it = pfList.removeAndDelete (it);
|
||||||
|
} else {
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (found == false) {
|
||||||
|
cerr << "Error: could not find a parfactor with ground " ;
|
||||||
|
cerr << "`" << query[i] << "'." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
pfList.add (newPfs);
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printAsteriskLine();
|
||||||
|
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||||
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
|
cout << " -> " << query[i] << endl;
|
||||||
|
}
|
||||||
|
Util::printAsteriskLine();
|
||||||
|
pfList.print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedOperations::runWeakBayesBall (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
queue<PrvGroup> todo; // groups to process
|
||||||
|
set<PrvGroup> done; // processed or in queue
|
||||||
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
PrvGroup group = (*it)->findGroup (query[i]);
|
||||||
|
if (group != numeric_limits<PrvGroup>::max()) {
|
||||||
|
todo.push (group);
|
||||||
|
done.insert (group);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
set<Parfactor*> requiredPfs;
|
||||||
|
while (todo.empty() == false) {
|
||||||
|
PrvGroup group = todo.front();
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
if (Util::contains (requiredPfs, *it) == false &&
|
||||||
|
(*it)->containsGroup (group)) {
|
||||||
|
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||||
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
|
if (Util::contains (done, groups[i]) == false) {
|
||||||
|
todo.push (groups[i]);
|
||||||
|
done.insert (groups[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
requiredPfs.insert (*it);
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
todo.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
bool foundNotRequired = false;
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
if (Util::contains (requiredPfs, *it) == false) {
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
if (foundNotRequired == false) {
|
||||||
|
Util::printHeader ("PARFACTORS TO DISCARD");
|
||||||
|
foundNotRequired = true;
|
||||||
|
}
|
||||||
|
(*it)->print();
|
||||||
|
}
|
||||||
|
it = pfList.removeAndDelete (it);
|
||||||
|
} else {
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedOperations::absorveEvidence (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
ObservedFormulas& obsFormulas)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||||
|
Parfactors newPfs;
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
Parfactor* pf = *it;
|
||||||
|
it = pfList.remove (it);
|
||||||
|
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||||
|
if (absorvedPfs.empty() == false) {
|
||||||
|
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||||
|
// just remove pf;
|
||||||
|
} else {
|
||||||
|
Util::addToVector (newPfs, absorvedPfs);
|
||||||
|
}
|
||||||
|
delete pf;
|
||||||
|
} else {
|
||||||
|
it = pfList.insertShattered (it, pf);
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pfList.add (newPfs);
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||||
|
Util::printAsteriskLine();
|
||||||
|
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||||
|
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||||
|
cout << " -> " << obsFormulas[i] << endl;
|
||||||
|
}
|
||||||
|
Util::printAsteriskLine();
|
||||||
|
pfList.print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactors
|
||||||
|
LiftedOperations::countNormalize (
|
||||||
|
Parfactor* g,
|
||||||
|
const LogVarSet& set)
|
||||||
|
{
|
||||||
|
Parfactors normPfs;
|
||||||
|
if (set.empty()) {
|
||||||
|
normPfs.push_back (new Parfactor (*g));
|
||||||
|
} else {
|
||||||
|
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||||
|
for (size_t i = 0; i < normCts.size(); i++) {
|
||||||
|
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return normPfs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor
|
||||||
|
LiftedOperations::calcGroundMultiplication (Parfactor pf)
|
||||||
|
{
|
||||||
|
LogVarSet lvs = pf.constr()->logVarSet();
|
||||||
|
lvs -= pf.constr()->singletons();
|
||||||
|
Parfactors newPfs = {new Parfactor (pf)};
|
||||||
|
for (size_t i = 0; i < lvs.size(); i++) {
|
||||||
|
Parfactors pfs = newPfs;
|
||||||
|
newPfs.clear();
|
||||||
|
for (size_t j = 0; j < pfs.size(); j++) {
|
||||||
|
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
|
||||||
|
if (countedLv) {
|
||||||
|
pfs[j]->fullExpand (lvs[i]);
|
||||||
|
newPfs.push_back (pfs[j]);
|
||||||
|
} else {
|
||||||
|
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
|
||||||
|
for (size_t k = 0; k < cts.size(); k++) {
|
||||||
|
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
|
||||||
|
}
|
||||||
|
delete pfs[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ParfactorList pfList (newPfs);
|
||||||
|
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
|
||||||
|
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
|
||||||
|
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
|
||||||
|
}
|
||||||
|
return Parfactor (*groundShatteredPfs[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactors
|
||||||
|
LiftedOperations::absorve (
|
||||||
|
ObservedFormula& obsFormula,
|
||||||
|
Parfactor* g)
|
||||||
|
{
|
||||||
|
Parfactors absorvedPfs;
|
||||||
|
const ProbFormulas& formulas = g->arguments();
|
||||||
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
|
if (obsFormula.functor() == formulas[i].functor() &&
|
||||||
|
obsFormula.arity() == formulas[i].arity()) {
|
||||||
|
|
||||||
|
if (obsFormula.isAtom()) {
|
||||||
|
if (formulas.size() > 1) {
|
||||||
|
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||||
|
} else {
|
||||||
|
// hack to erase parfactor g
|
||||||
|
absorvedPfs.push_back (0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
g->constr()->moveToTop (formulas[i].logVars());
|
||||||
|
std::pair<ConstraintTree*, ConstraintTree*> res;
|
||||||
|
res = g->constr()->split (
|
||||||
|
formulas[i].logVars(),
|
||||||
|
&(obsFormula.constr()),
|
||||||
|
obsFormula.constr().logVars());
|
||||||
|
ConstraintTree* commCt = res.first;
|
||||||
|
ConstraintTree* exclCt = res.second;
|
||||||
|
|
||||||
|
if (commCt->empty() == false) {
|
||||||
|
if (formulas.size() > 1) {
|
||||||
|
LogVarSet excl = g->exclusiveLogVars (i);
|
||||||
|
Parfactor tempPf (g, commCt);
|
||||||
|
Parfactors countNormPfs = LiftedOperations::countNormalize (
|
||||||
|
&tempPf, excl);
|
||||||
|
for (size_t j = 0; j < countNormPfs.size(); j++) {
|
||||||
|
countNormPfs[j]->absorveEvidence (
|
||||||
|
formulas[i], obsFormula.evidence());
|
||||||
|
absorvedPfs.push_back (countNormPfs[j]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
delete commCt;
|
||||||
|
}
|
||||||
|
if (exclCt->empty() == false) {
|
||||||
|
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||||
|
} else {
|
||||||
|
delete exclCt;
|
||||||
|
}
|
||||||
|
if (absorvedPfs.empty()) {
|
||||||
|
// hack to erase parfactor g
|
||||||
|
absorvedPfs.push_back (0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
delete commCt;
|
||||||
|
delete exclCt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return absorvedPfs;
|
||||||
|
}
|
||||||
|
|
27
packages/CLPBN/horus2/LiftedOperations.h
Normal file
27
packages/CLPBN/horus2/LiftedOperations.h
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
#ifndef HORUS_LIFTEDOPERATIONS_H
|
||||||
|
#define HORUS_LIFTEDOPERATIONS_H
|
||||||
|
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
class LiftedOperations
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static void shatterAgainstQuery (
|
||||||
|
ParfactorList& pfList, const Grounds& query);
|
||||||
|
|
||||||
|
static void runWeakBayesBall (
|
||||||
|
ParfactorList& pfList, const Grounds&);
|
||||||
|
|
||||||
|
static void absorveEvidence (
|
||||||
|
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||||
|
|
||||||
|
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||||
|
|
||||||
|
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_LIFTEDOPERATIONS_H
|
||||||
|
|
27
packages/CLPBN/horus2/LiftedSolver.h
Normal file
27
packages/CLPBN/horus2/LiftedSolver.h
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
#ifndef HORUS_LIFTEDSOLVER_H
|
||||||
|
#define HORUS_LIFTEDSOLVER_H
|
||||||
|
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
class LiftedSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LiftedSolver (const ParfactorList& pfList)
|
||||||
|
: parfactorList(pfList) { }
|
||||||
|
|
||||||
|
virtual ~LiftedSolver() { } // ensure that subclass destructor is called
|
||||||
|
|
||||||
|
virtual Params solveQuery (const Grounds& query) = 0;
|
||||||
|
|
||||||
|
virtual void printSolverFlags (void) const = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const ParfactorList& parfactorList;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_LIFTEDSOLVER_H
|
||||||
|
|
131
packages/CLPBN/horus2/LiftedUtils.cpp
Normal file
131
packages/CLPBN/horus2/LiftedUtils.cpp
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "LiftedUtils.h"
|
||||||
|
#include "ConstraintTree.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace LiftedUtils {
|
||||||
|
|
||||||
|
|
||||||
|
unordered_map<string, unsigned> symbolDict;
|
||||||
|
|
||||||
|
|
||||||
|
Symbol
|
||||||
|
getSymbol (const string& symbolName)
|
||||||
|
{
|
||||||
|
unordered_map<string, unsigned>::iterator it
|
||||||
|
= symbolDict.find (symbolName);
|
||||||
|
if (it != symbolDict.end()) {
|
||||||
|
return it->second;
|
||||||
|
} else {
|
||||||
|
symbolDict[symbolName] = symbolDict.size() - 1;
|
||||||
|
return symbolDict.size() - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
printSymbolDictionary (void)
|
||||||
|
{
|
||||||
|
unordered_map<string, unsigned>::const_iterator it
|
||||||
|
= symbolDict.begin();
|
||||||
|
while (it != symbolDict.end()) {
|
||||||
|
cout << it->first << " -> " << it->second << endl;
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const Symbol& s)
|
||||||
|
{
|
||||||
|
unordered_map<string, unsigned>::const_iterator it
|
||||||
|
= LiftedUtils::symbolDict.begin();
|
||||||
|
while (it != LiftedUtils::symbolDict.end() && it->second != s) {
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
assert (it != LiftedUtils::symbolDict.end());
|
||||||
|
os << it->first;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const LogVar& X)
|
||||||
|
{
|
||||||
|
const string labels[] = {
|
||||||
|
"A", "B", "C", "D", "E", "F",
|
||||||
|
"G", "H", "I", "J", "K", "M" };
|
||||||
|
(X >= 12) ? os << "X_" << X.id_ : os << labels[X];
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const Tuple& t)
|
||||||
|
{
|
||||||
|
os << "(" ;
|
||||||
|
for (size_t i = 0; i < t.size(); i++) {
|
||||||
|
os << ((i != 0) ? "," : "") << t[i];
|
||||||
|
}
|
||||||
|
os << ")" ;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const Ground& gr)
|
||||||
|
{
|
||||||
|
os << gr.functor();
|
||||||
|
os << "(" ;
|
||||||
|
for (size_t i = 0; i < gr.args().size(); i++) {
|
||||||
|
if (i != 0) os << ", " ;
|
||||||
|
os << gr.args()[i];
|
||||||
|
}
|
||||||
|
os << ")" ;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LogVars
|
||||||
|
Substitution::getDiscardedLogVars (void) const
|
||||||
|
{
|
||||||
|
LogVars discardedLvs;
|
||||||
|
set<LogVar> doneLvs;
|
||||||
|
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||||
|
it = subs_.begin();
|
||||||
|
while (it != subs_.end()) {
|
||||||
|
if (Util::contains (doneLvs, it->second)) {
|
||||||
|
discardedLvs.push_back (it->first);
|
||||||
|
} else {
|
||||||
|
doneLvs.insert (it->second);
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return discardedLvs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const Substitution& theta)
|
||||||
|
{
|
||||||
|
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||||
|
os << "[" ;
|
||||||
|
it = theta.subs_.begin();
|
||||||
|
while (it != theta.subs_.end()) {
|
||||||
|
if (it != theta.subs_.begin()) os << ", " ;
|
||||||
|
os << it->first << "->" << it->second ;
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
os << "]" ;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
166
packages/CLPBN/horus2/LiftedUtils.h
Normal file
166
packages/CLPBN/horus2/LiftedUtils.h
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
#ifndef HORUS_LIFTEDUTILS_H
|
||||||
|
#define HORUS_LIFTEDUTILS_H
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
|
||||||
|
#include "TinySet.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
class Symbol
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Symbol (void) : id_(Util::maxUnsigned()) { }
|
||||||
|
|
||||||
|
Symbol (unsigned id) : id_(id) { }
|
||||||
|
|
||||||
|
operator unsigned (void) const { return id_; }
|
||||||
|
|
||||||
|
bool valid (void) const { return id_ != Util::maxUnsigned(); }
|
||||||
|
|
||||||
|
static Symbol invalid (void) { return Symbol(); }
|
||||||
|
|
||||||
|
friend ostream& operator<< (ostream &os, const Symbol& s);
|
||||||
|
|
||||||
|
private:
|
||||||
|
unsigned id_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class LogVar
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LogVar (void) : id_(Util::maxUnsigned()) { }
|
||||||
|
|
||||||
|
LogVar (unsigned id) : id_(id) { }
|
||||||
|
|
||||||
|
operator unsigned (void) const { return id_; }
|
||||||
|
|
||||||
|
LogVar& operator++ (void)
|
||||||
|
{
|
||||||
|
assert (valid());
|
||||||
|
id_ ++;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool valid (void) const
|
||||||
|
{
|
||||||
|
return id_ != Util::maxUnsigned();
|
||||||
|
}
|
||||||
|
|
||||||
|
friend ostream& operator<< (ostream &os, const LogVar& X);
|
||||||
|
|
||||||
|
private:
|
||||||
|
unsigned id_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
template <> struct hash<Symbol> {
|
||||||
|
size_t operator() (const Symbol& s) const {
|
||||||
|
return std::hash<unsigned>() (s);
|
||||||
|
}};
|
||||||
|
|
||||||
|
template <> struct hash<LogVar> {
|
||||||
|
size_t operator() (const LogVar& X) const {
|
||||||
|
return std::hash<unsigned>() (X);
|
||||||
|
}};
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
typedef vector<Symbol> Symbols;
|
||||||
|
typedef vector<Symbol> Tuple;
|
||||||
|
typedef vector<Tuple> Tuples;
|
||||||
|
typedef vector<LogVar> LogVars;
|
||||||
|
typedef TinySet<Symbol> SymbolSet;
|
||||||
|
typedef TinySet<LogVar> LogVarSet;
|
||||||
|
typedef TinySet<Tuple> TupleSet;
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const Tuple& t);
|
||||||
|
|
||||||
|
|
||||||
|
namespace LiftedUtils {
|
||||||
|
Symbol getSymbol (const string&);
|
||||||
|
void printSymbolDictionary (void);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Ground
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Ground (Symbol f) : functor_(f) { }
|
||||||
|
|
||||||
|
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
||||||
|
|
||||||
|
Symbol functor (void) const { return functor_; }
|
||||||
|
|
||||||
|
Symbols args (void) const { return args_; }
|
||||||
|
|
||||||
|
size_t arity (void) const { return args_.size(); }
|
||||||
|
|
||||||
|
bool isAtom (void) const { return args_.size() == 0; }
|
||||||
|
|
||||||
|
friend ostream& operator<< (ostream &os, const Ground& gr);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Symbol functor_;
|
||||||
|
Symbols args_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<Ground> Grounds;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Substitution
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
void add (LogVar X_old, LogVar X_new)
|
||||||
|
{
|
||||||
|
assert (Util::contains (subs_, X_old) == false);
|
||||||
|
subs_.insert (make_pair (X_old, X_new));
|
||||||
|
}
|
||||||
|
|
||||||
|
void rename (LogVar X_old, LogVar X_new)
|
||||||
|
{
|
||||||
|
assert (Util::contains (subs_, X_old));
|
||||||
|
subs_.find (X_old)->second = X_new;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogVar newNameFor (LogVar X) const
|
||||||
|
{
|
||||||
|
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||||
|
it = subs_.find (X);
|
||||||
|
if (it != subs_.end()) {
|
||||||
|
return subs_.find (X)->second;
|
||||||
|
}
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool containsReplacementFor (LogVar X) const
|
||||||
|
{
|
||||||
|
return Util::contains (subs_, X);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t nrReplacements (void) const { return subs_.size(); }
|
||||||
|
|
||||||
|
LogVars getDiscardedLogVars (void) const;
|
||||||
|
|
||||||
|
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||||
|
|
||||||
|
private:
|
||||||
|
unordered_map<LogVar, LogVar> subs_;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#endif // HORUS_LIFTEDUTILS_H
|
||||||
|
|
728
packages/CLPBN/horus2/LiftedVe.cpp
Normal file
728
packages/CLPBN/horus2/LiftedVe.cpp
Normal file
@ -0,0 +1,728 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
|
#include "LiftedVe.h"
|
||||||
|
#include "LiftedOperations.h"
|
||||||
|
#include "Histogram.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
vector<LiftedOperator*>
|
||||||
|
LiftedOperator::getValidOps (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
vector<LiftedOperator*> validOps;
|
||||||
|
vector<ProductOperator*> multOps;
|
||||||
|
|
||||||
|
multOps = ProductOperator::getValidOps (pfList);
|
||||||
|
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
|
||||||
|
|
||||||
|
if (Globals::verbosity > 1 || multOps.empty()) {
|
||||||
|
vector<SumOutOperator*> sumOutOps;
|
||||||
|
vector<CountingOperator*> countOps;
|
||||||
|
vector<GroundOperator*> groundOps;
|
||||||
|
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
||||||
|
countOps = CountingOperator::getValidOps (pfList);
|
||||||
|
groundOps = GroundOperator::getValidOps (pfList);
|
||||||
|
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
|
||||||
|
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
|
||||||
|
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
return validOps;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedOperator::printValidOps (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
vector<LiftedOperator*> validOps;
|
||||||
|
validOps = LiftedOperator::getValidOps (pfList, query);
|
||||||
|
for (size_t i = 0; i < validOps.size(); i++) {
|
||||||
|
cout << "-> " << validOps[i]->toString();
|
||||||
|
delete validOps[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<ParfactorList::iterator>
|
||||||
|
LiftedOperator::getParfactorsWithGroup (
|
||||||
|
ParfactorList& pfList, PrvGroup group)
|
||||||
|
{
|
||||||
|
vector<ParfactorList::iterator> iters;
|
||||||
|
ParfactorList::iterator pflIt = pfList.begin();
|
||||||
|
while (pflIt != pfList.end()) {
|
||||||
|
if ((*pflIt)->containsGroup (group)) {
|
||||||
|
iters.push_back (pflIt);
|
||||||
|
}
|
||||||
|
++ pflIt;
|
||||||
|
}
|
||||||
|
return iters;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
ProductOperator::getLogCost (void)
|
||||||
|
{
|
||||||
|
return std::log (0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ProductOperator::apply (void)
|
||||||
|
{
|
||||||
|
Parfactor* g1 = *g1_;
|
||||||
|
Parfactor* g2 = *g2_;
|
||||||
|
g1->multiply (*g2);
|
||||||
|
pfList_.remove (g1_);
|
||||||
|
pfList_.removeAndDelete (g2_);
|
||||||
|
pfList_.addShattered (g1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<ProductOperator*>
|
||||||
|
ProductOperator::getValidOps (ParfactorList& pfList)
|
||||||
|
{
|
||||||
|
vector<ProductOperator*> validOps;
|
||||||
|
ParfactorList::iterator it1 = pfList.begin();
|
||||||
|
ParfactorList::iterator penultimate = -- pfList.end();
|
||||||
|
set<Parfactor*> pfs;
|
||||||
|
while (it1 != penultimate) {
|
||||||
|
if (Util::contains (pfs, *it1)) {
|
||||||
|
++ it1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
ParfactorList::iterator it2 = it1;
|
||||||
|
++ it2;
|
||||||
|
while (it2 != pfList.end()) {
|
||||||
|
if (Util::contains (pfs, *it2)) {
|
||||||
|
++ it2;
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
if (validOp (*it1, *it2)) {
|
||||||
|
pfs.insert (*it1);
|
||||||
|
pfs.insert (*it2);
|
||||||
|
validOps.push_back (new ProductOperator (
|
||||||
|
it1, it2, pfList));
|
||||||
|
if (Globals::verbosity < 2) {
|
||||||
|
return validOps;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ it2;
|
||||||
|
}
|
||||||
|
++ it1;
|
||||||
|
}
|
||||||
|
return validOps;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
ProductOperator::toString (void)
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "just multiplicate " ;
|
||||||
|
ss << (*g1_)->getAllGroups();
|
||||||
|
ss << " x " ;
|
||||||
|
ss << (*g2_)->getAllGroups();
|
||||||
|
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
TinySet<PrvGroup> g1_gs (g1->getAllGroups());
|
||||||
|
TinySet<PrvGroup> g2_gs (g2->getAllGroups());
|
||||||
|
if (g1_gs.contains (g2_gs) || g2_gs.contains (g1_gs)) {
|
||||||
|
TinySet<PrvGroup> intersect = g1_gs & g2_gs;
|
||||||
|
for (size_t i = 0; i < intersect.size(); i++) {
|
||||||
|
if (g1->nrFormulasWithGroup (intersect[i]) != 1 ||
|
||||||
|
g2->nrFormulasWithGroup (intersect[i]) != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t idx1 = g1->indexOfGroup (intersect[i]);
|
||||||
|
size_t idx2 = g2->indexOfGroup (intersect[i]);
|
||||||
|
if (g1->range (idx1) != g2->range (idx2)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Parfactor::canMultiply (g1, g2);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
SumOutOperator::getLogCost (void)
|
||||||
|
{
|
||||||
|
TinySet<PrvGroup> groupSet;
|
||||||
|
ParfactorList::const_iterator pfIter = pfList_.begin();
|
||||||
|
unsigned nrProdFactors = 0;
|
||||||
|
while (pfIter != pfList_.end()) {
|
||||||
|
if ((*pfIter)->containsGroup (group_)) {
|
||||||
|
vector<PrvGroup> groups = (*pfIter)->getAllGroups();
|
||||||
|
groupSet |= TinySet<PrvGroup> (groups);
|
||||||
|
++ nrProdFactors;
|
||||||
|
}
|
||||||
|
++ pfIter;
|
||||||
|
}
|
||||||
|
if (nrProdFactors == 1) {
|
||||||
|
// best possible case
|
||||||
|
return std::log (0.0);
|
||||||
|
}
|
||||||
|
double cost = 1.0;
|
||||||
|
for (size_t i = 0; i < groupSet.size(); i++) {
|
||||||
|
pfIter = pfList_.begin();
|
||||||
|
while (pfIter != pfList_.end()) {
|
||||||
|
if ((*pfIter)->containsGroup (groupSet[i])) {
|
||||||
|
size_t idx = (*pfIter)->indexOfGroup (groupSet[i]);
|
||||||
|
cost *= (*pfIter)->range (idx);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ pfIter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::log (cost);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
SumOutOperator::apply (void)
|
||||||
|
{
|
||||||
|
vector<ParfactorList::iterator> iters;
|
||||||
|
iters = getParfactorsWithGroup (pfList_, group_);
|
||||||
|
Parfactor* product = *(iters[0]);
|
||||||
|
pfList_.remove (iters[0]);
|
||||||
|
for (size_t i = 1; i < iters.size(); i++) {
|
||||||
|
product->multiply (**(iters[i]));
|
||||||
|
pfList_.removeAndDelete (iters[i]);
|
||||||
|
}
|
||||||
|
if (product->nrArguments() == 1) {
|
||||||
|
delete product;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
size_t fIdx = product->indexOfGroup (group_);
|
||||||
|
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
||||||
|
if (product->constr()->isCountNormalized (excl)) {
|
||||||
|
product->sumOutIndex (fIdx);
|
||||||
|
pfList_.addShattered (product);
|
||||||
|
} else {
|
||||||
|
Parfactors pfs = LiftedOperations::countNormalize (product, excl);
|
||||||
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
|
pfs[i]->sumOutIndex (fIdx);
|
||||||
|
pfList_.add (pfs[i]);
|
||||||
|
}
|
||||||
|
delete product;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<SumOutOperator*>
|
||||||
|
SumOutOperator::getValidOps (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
vector<SumOutOperator*> validOps;
|
||||||
|
set<PrvGroup> allGroups;
|
||||||
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
const ProbFormulas& formulas = (*it)->arguments();
|
||||||
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
|
allGroups.insert (formulas[i].group());
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
set<PrvGroup>::const_iterator groupIt = allGroups.begin();
|
||||||
|
while (groupIt != allGroups.end()) {
|
||||||
|
if (validOp (*groupIt, pfList, query)) {
|
||||||
|
validOps.push_back (new SumOutOperator (*groupIt, pfList));
|
||||||
|
}
|
||||||
|
++ groupIt;
|
||||||
|
}
|
||||||
|
return validOps;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
SumOutOperator::toString (void)
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
vector<ParfactorList::iterator> pfIters;
|
||||||
|
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||||
|
size_t idx = (*pfIters[0])->indexOfGroup (group_);
|
||||||
|
ProbFormula f = (*pfIters[0])->argument (idx);
|
||||||
|
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||||
|
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||||
|
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||||
|
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
SumOutOperator::validOp (
|
||||||
|
PrvGroup group,
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
vector<ParfactorList::iterator> pfIters;
|
||||||
|
pfIters = getParfactorsWithGroup (pfList, group);
|
||||||
|
if (isToEliminate (*pfIters[0], group, query) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int range = -1;
|
||||||
|
for (size_t i = 0; i < pfIters.size(); i++) {
|
||||||
|
if ((*pfIters[i])->nrFormulasWithGroup (group) > 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t fIdx = (*pfIters[i])->indexOfGroup (group);
|
||||||
|
if ((*pfIters[i])->argument (fIdx).contains (
|
||||||
|
(*pfIters[i])->elimLogVars()) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (range == -1) {
|
||||||
|
range = (*pfIters[i])->range (fIdx);
|
||||||
|
} else if ((int)(*pfIters[i])->range (fIdx) != range) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
SumOutOperator::isToEliminate (
|
||||||
|
Parfactor* g,
|
||||||
|
PrvGroup group,
|
||||||
|
const Grounds& query)
|
||||||
|
{
|
||||||
|
size_t fIdx = g->indexOfGroup (group);
|
||||||
|
const ProbFormula& formula = g->argument (fIdx);
|
||||||
|
bool toElim = true;
|
||||||
|
for (size_t i = 0; i < query.size(); i++) {
|
||||||
|
if (formula.functor() == query[i].functor() &&
|
||||||
|
formula.arity() == query[i].arity()) {
|
||||||
|
g->constr()->moveToTop (formula.logVars());
|
||||||
|
if (g->constr()->containsTuple (query[i].args())) {
|
||||||
|
toElim = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return toElim;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
CountingOperator::getLogCost (void)
|
||||||
|
{
|
||||||
|
double cost = 0.0;
|
||||||
|
size_t fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||||
|
unsigned range = (*pfIter_)->range (fIdx);
|
||||||
|
unsigned size = (*pfIter_)->size() / range;
|
||||||
|
TinySet<unsigned> counts;
|
||||||
|
counts = (*pfIter_)->constr()->getConditionalCounts (X_);
|
||||||
|
for (size_t i = 0; i < counts.size(); i++) {
|
||||||
|
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
||||||
|
}
|
||||||
|
PrvGroup group = (*pfIter_)->argument (fIdx).group();
|
||||||
|
size_t lvIndex = Util::indexOf (
|
||||||
|
(*pfIter_)->argument (fIdx).logVars(), X_);
|
||||||
|
assert (lvIndex != (*pfIter_)->argument (fIdx).logVars().size());
|
||||||
|
ParfactorList::iterator pfIter = pfList_.begin();
|
||||||
|
while (pfIter != pfList_.end()) {
|
||||||
|
if (pfIter != pfIter_) {
|
||||||
|
size_t fIdx2 = (*pfIter)->indexOfGroup (group);
|
||||||
|
if (fIdx2 != (*pfIter)->nrArguments()) {
|
||||||
|
LogVar Y = ((*pfIter)->argument (fIdx2).logVars()[lvIndex]);
|
||||||
|
if ((*pfIter)->canCountConvert (Y) == false) {
|
||||||
|
// the real cost should be the cost of grounding Y
|
||||||
|
cost *= 10.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ pfIter;
|
||||||
|
}
|
||||||
|
return std::log (cost);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingOperator::apply (void)
|
||||||
|
{
|
||||||
|
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
||||||
|
(*pfIter_)->countConvert (X_);
|
||||||
|
} else {
|
||||||
|
Parfactor* pf = *pfIter_;
|
||||||
|
pfList_.remove (pfIter_);
|
||||||
|
Parfactors pfs = LiftedOperations::countNormalize (pf, X_);
|
||||||
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
|
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||||
|
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
|
||||||
|
pfs[i]->countedLogVars() | X_);
|
||||||
|
if (condCount > 1 && cartProduct) {
|
||||||
|
pfs[i]->countConvert (X_);
|
||||||
|
}
|
||||||
|
pfList_.add (pfs[i]);
|
||||||
|
}
|
||||||
|
delete pf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<CountingOperator*>
|
||||||
|
CountingOperator::getValidOps (ParfactorList& pfList)
|
||||||
|
{
|
||||||
|
vector<CountingOperator*> validOps;
|
||||||
|
ParfactorList::iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
LogVarSet candidates = (*it)->uncountedLogVars();
|
||||||
|
for (size_t i = 0; i < candidates.size(); i++) {
|
||||||
|
if (validOp (*it, candidates[i])) {
|
||||||
|
validOps.push_back (new CountingOperator (
|
||||||
|
it, candidates[i], pfList));
|
||||||
|
} else {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return validOps;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
CountingOperator::toString (void)
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "count convert " << X_ << " in " ;
|
||||||
|
ss << (*pfIter_)->getLabel();
|
||||||
|
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||||
|
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
|
||||||
|
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||||
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
|
ss << " º " << pfs[i]->getLabel() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
|
delete pfs[i];
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||||
|
{
|
||||||
|
if (g->nrFormulas (X) != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t fIdx = g->indexOfLogVar (X);
|
||||||
|
if (g->argument (fIdx).isCounting()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool countNormalized = g->constr()->isCountNormalized (X);
|
||||||
|
if (countNormalized) {
|
||||||
|
return g->canCountConvert (X);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
GroundOperator::getLogCost (void)
|
||||||
|
{
|
||||||
|
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||||
|
affectedFormulas = getAffectedFormulas();
|
||||||
|
// cout << "affected formulas: " ;
|
||||||
|
// for (size_t i = 0; i < affectedFormulas.size(); i++) {
|
||||||
|
// cout << affectedFormulas[i].first << ":" ;
|
||||||
|
// cout << affectedFormulas[i].second << " " ;
|
||||||
|
// }
|
||||||
|
// cout << "cost =" ;
|
||||||
|
double totalCost = std::log (0.0);
|
||||||
|
ParfactorList::iterator pflIt = pfList_.begin();
|
||||||
|
while (pflIt != pfList_.end()) {
|
||||||
|
Parfactor* pf = *pflIt;
|
||||||
|
double reps = 0.0;
|
||||||
|
double pfSize = std::log (pf->size());
|
||||||
|
bool willBeAffected = false;
|
||||||
|
LogVarSet lvsToGround;
|
||||||
|
for (size_t i = 0; i < affectedFormulas.size(); i++) {
|
||||||
|
size_t fIdx = pf->indexOfGroup (affectedFormulas[i].first);
|
||||||
|
if (fIdx != pf->nrArguments()) {
|
||||||
|
ProbFormula f = pf->argument (fIdx);
|
||||||
|
LogVar X = f.logVars()[affectedFormulas[i].second];
|
||||||
|
bool isCountingLv = pf->countedLogVars().contains (X);
|
||||||
|
if (isCountingLv) {
|
||||||
|
unsigned nrHists = pf->range (fIdx);
|
||||||
|
unsigned nrSymbols = pf->constr()->getConditionalCount (X);
|
||||||
|
unsigned range = pf->argument (fIdx).range();
|
||||||
|
double power = std::log (range) * nrSymbols;
|
||||||
|
pfSize = (pfSize - std::log (nrHists)) + power;
|
||||||
|
} else {
|
||||||
|
if (lvsToGround.contains (X) == false) {
|
||||||
|
reps += std::log (pf->constr()->nrSymbols (X));
|
||||||
|
lvsToGround.insert (X);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
willBeAffected = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (willBeAffected) {
|
||||||
|
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
||||||
|
double pfCost = reps + pfSize;
|
||||||
|
totalCost = Util::logSum (totalCost, pfCost);
|
||||||
|
}
|
||||||
|
++ pflIt;
|
||||||
|
}
|
||||||
|
// cout << endl;
|
||||||
|
return totalCost + 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
GroundOperator::apply (void)
|
||||||
|
{
|
||||||
|
ParfactorList::iterator pfIter;
|
||||||
|
pfIter = getParfactorsWithGroup (pfList_, group_).front();
|
||||||
|
Parfactor* pf = *pfIter;
|
||||||
|
size_t idx = pf->indexOfGroup (group_);
|
||||||
|
ProbFormula f = pf->argument (idx);
|
||||||
|
LogVar X = f.logVars()[lvIndex_];
|
||||||
|
bool countedLv = pf->countedLogVars().contains (X);
|
||||||
|
pfList_.remove (pfIter);
|
||||||
|
if (countedLv) {
|
||||||
|
pf->fullExpand (X);
|
||||||
|
pfList_.add (pf);
|
||||||
|
} else {
|
||||||
|
ConstraintTrees cts = pf->constr()->ground (X);
|
||||||
|
for (size_t i = 0; i < cts.size(); i++) {
|
||||||
|
pfList_.add (new Parfactor (pf, cts[i]));
|
||||||
|
}
|
||||||
|
delete pf;
|
||||||
|
}
|
||||||
|
ParfactorList::iterator pflIt = pfList_.begin();
|
||||||
|
while (pflIt != pfList_.end()) {
|
||||||
|
(*pflIt)->simplifyGrounds();
|
||||||
|
++ pflIt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<GroundOperator*>
|
||||||
|
GroundOperator::getValidOps (ParfactorList& pfList)
|
||||||
|
{
|
||||||
|
vector<GroundOperator*> validOps;
|
||||||
|
set<PrvGroup> allGroups;
|
||||||
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
const ProbFormulas& formulas = (*it)->arguments();
|
||||||
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
|
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
||||||
|
const LogVars& lvs = formulas[i].logVars();
|
||||||
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
|
if ((*it)->constr()->isSingleton (lvs[j]) == false) {
|
||||||
|
validOps.push_back (new GroundOperator (
|
||||||
|
formulas[i].group(), j, pfList));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
allGroups.insert (formulas[i].group());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return validOps;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
GroundOperator::toString (void)
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
vector<ParfactorList::iterator> pfIters;
|
||||||
|
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||||
|
Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
|
||||||
|
size_t idx = pf->indexOfGroup (group_);
|
||||||
|
ProbFormula f = pf->argument (idx);
|
||||||
|
LogVar lv = f.logVars()[lvIndex_];
|
||||||
|
TupleSet tupleSet = pf->constr()->tupleSet ({lv});
|
||||||
|
string pos = "th";
|
||||||
|
if (lvIndex_ == 0) {
|
||||||
|
pos = "st" ;
|
||||||
|
} else if (lvIndex_ == 1) {
|
||||||
|
pos = "nd" ;
|
||||||
|
} else if (lvIndex_ == 2) {
|
||||||
|
pos = "rd" ;
|
||||||
|
}
|
||||||
|
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
|
||||||
|
ss << f.functor() << "/" << f.arity();
|
||||||
|
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||||
|
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<pair<PrvGroup, unsigned>>
|
||||||
|
GroundOperator::getAffectedFormulas (void)
|
||||||
|
{
|
||||||
|
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||||
|
affectedFormulas.push_back (make_pair (group_, lvIndex_));
|
||||||
|
queue<pair<PrvGroup, unsigned>> q;
|
||||||
|
q.push (make_pair (group_, lvIndex_));
|
||||||
|
while (q.empty() == false) {
|
||||||
|
pair<PrvGroup, unsigned> front = q.front();
|
||||||
|
ParfactorList::iterator pflIt = pfList_.begin();
|
||||||
|
while (pflIt != pfList_.end()) {
|
||||||
|
size_t idx = (*pflIt)->indexOfGroup (front.first);
|
||||||
|
if (idx != (*pflIt)->nrArguments()) {
|
||||||
|
ProbFormula f = (*pflIt)->argument (idx);
|
||||||
|
LogVar X = f.logVars()[front.second];
|
||||||
|
const ProbFormulas& fs = (*pflIt)->arguments();
|
||||||
|
for (size_t i = 0; i < fs.size(); i++) {
|
||||||
|
if (i != idx && fs[i].contains (X)) {
|
||||||
|
pair<PrvGroup, unsigned> pair = make_pair (
|
||||||
|
fs[i].group(), fs[i].indexOf (X));
|
||||||
|
if (Util::contains (affectedFormulas, pair) == false) {
|
||||||
|
q.push (pair);
|
||||||
|
affectedFormulas.push_back (pair);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ pflIt;
|
||||||
|
}
|
||||||
|
q.pop();
|
||||||
|
}
|
||||||
|
return affectedFormulas;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
LiftedVe::solveQuery (const Grounds& query)
|
||||||
|
{
|
||||||
|
assert (query.empty() == false);
|
||||||
|
pfList_ = parfactorList;
|
||||||
|
runSolver (query);
|
||||||
|
(*pfList_.begin())->normalize();
|
||||||
|
Params params = (*pfList_.begin())->params();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::exp (params);
|
||||||
|
}
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedVe::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "lve [" ;
|
||||||
|
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedVe::runSolver (const Grounds& query)
|
||||||
|
{
|
||||||
|
largestCost_ = std::log (0);
|
||||||
|
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||||
|
LiftedOperations::runWeakBayesBall (pfList_, query);
|
||||||
|
while (true) {
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printDashedLine();
|
||||||
|
pfList_.print();
|
||||||
|
if (Globals::verbosity > 3) {
|
||||||
|
LiftedOperator::printValidOps (pfList_, query);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LiftedOperator* op = getBestOperation (query);
|
||||||
|
if (op == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "best operation: " << op->toString();
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op->apply();
|
||||||
|
delete op;
|
||||||
|
}
|
||||||
|
assert (pfList_.size() > 0);
|
||||||
|
if (pfList_.size() > 1) {
|
||||||
|
ParfactorList::iterator pfIter = pfList_.begin();
|
||||||
|
++ pfIter;
|
||||||
|
while (pfIter != pfList_.end()) {
|
||||||
|
(*pfList_.begin())->multiply (**pfIter);
|
||||||
|
++ pfIter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
cout << "largest cost = " << std::exp (largestCost_) << endl;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
(*pfList_.begin())->simplifyGrounds();
|
||||||
|
(*pfList_.begin())->reorderAccordingGrounds (query);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LiftedOperator*
|
||||||
|
LiftedVe::getBestOperation (const Grounds& query)
|
||||||
|
{
|
||||||
|
double bestCost = 0.0;
|
||||||
|
LiftedOperator* bestOp = 0;
|
||||||
|
vector<LiftedOperator*> validOps;
|
||||||
|
validOps = LiftedOperator::getValidOps (pfList_, query);
|
||||||
|
for (size_t i = 0; i < validOps.size(); i++) {
|
||||||
|
double cost = validOps[i]->getLogCost();
|
||||||
|
if ((bestOp == 0) || (cost < bestCost)) {
|
||||||
|
bestOp = validOps[i];
|
||||||
|
bestCost = cost;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (bestCost > largestCost_) {
|
||||||
|
largestCost_ = bestCost;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < validOps.size(); i++) {
|
||||||
|
if (validOps[i] != bestOp) {
|
||||||
|
delete validOps[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestOp;
|
||||||
|
}
|
||||||
|
|
155
packages/CLPBN/horus2/LiftedVe.h
Normal file
155
packages/CLPBN/horus2/LiftedVe.h
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
#ifndef HORUS_LIFTEDVE_H
|
||||||
|
#define HORUS_LIFTEDVE_H
|
||||||
|
|
||||||
|
#include "LiftedSolver.h"
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
|
||||||
|
class LiftedOperator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~LiftedOperator (void) { }
|
||||||
|
|
||||||
|
virtual double getLogCost (void) = 0;
|
||||||
|
|
||||||
|
virtual void apply (void) = 0;
|
||||||
|
|
||||||
|
virtual string toString (void) = 0;
|
||||||
|
|
||||||
|
static vector<LiftedOperator*> getValidOps (
|
||||||
|
ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
|
static void printValidOps (ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
|
static vector<ParfactorList::iterator> getParfactorsWithGroup (
|
||||||
|
ParfactorList&, PrvGroup group);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ProductOperator : public LiftedOperator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ProductOperator (
|
||||||
|
ParfactorList::iterator g1, ParfactorList::iterator g2,
|
||||||
|
ParfactorList& pfList) : g1_(g1), g2_(g2), pfList_(pfList) { }
|
||||||
|
|
||||||
|
double getLogCost (void);
|
||||||
|
|
||||||
|
void apply (void);
|
||||||
|
|
||||||
|
static vector<ProductOperator*> getValidOps (ParfactorList&);
|
||||||
|
|
||||||
|
string toString (void);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static bool validOp (Parfactor*, Parfactor*);
|
||||||
|
|
||||||
|
ParfactorList::iterator g1_;
|
||||||
|
ParfactorList::iterator g2_;
|
||||||
|
ParfactorList& pfList_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SumOutOperator : public LiftedOperator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
SumOutOperator (PrvGroup group, ParfactorList& pfList)
|
||||||
|
: group_(group), pfList_(pfList) { }
|
||||||
|
|
||||||
|
double getLogCost (void);
|
||||||
|
|
||||||
|
void apply (void);
|
||||||
|
|
||||||
|
static vector<SumOutOperator*> getValidOps (
|
||||||
|
ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
|
string toString (void);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
|
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
|
||||||
|
|
||||||
|
PrvGroup group_;
|
||||||
|
ParfactorList& pfList_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CountingOperator : public LiftedOperator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CountingOperator (
|
||||||
|
ParfactorList::iterator pfIter,
|
||||||
|
LogVar X,
|
||||||
|
ParfactorList& pfList)
|
||||||
|
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||||
|
|
||||||
|
double getLogCost (void);
|
||||||
|
|
||||||
|
void apply (void);
|
||||||
|
|
||||||
|
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||||
|
|
||||||
|
string toString (void);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static bool validOp (Parfactor*, LogVar);
|
||||||
|
|
||||||
|
ParfactorList::iterator pfIter_;
|
||||||
|
LogVar X_;
|
||||||
|
ParfactorList& pfList_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class GroundOperator : public LiftedOperator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
GroundOperator (
|
||||||
|
PrvGroup group,
|
||||||
|
unsigned lvIndex,
|
||||||
|
ParfactorList& pfList)
|
||||||
|
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
|
||||||
|
|
||||||
|
double getLogCost (void);
|
||||||
|
|
||||||
|
void apply (void);
|
||||||
|
|
||||||
|
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||||
|
|
||||||
|
string toString (void);
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
|
||||||
|
|
||||||
|
PrvGroup group_;
|
||||||
|
unsigned lvIndex_;
|
||||||
|
ParfactorList& pfList_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LiftedVe : public LiftedSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LiftedVe (const ParfactorList& pfList)
|
||||||
|
: LiftedSolver(pfList) { }
|
||||||
|
|
||||||
|
Params solveQuery (const Grounds&);
|
||||||
|
|
||||||
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void runSolver (const Grounds&);
|
||||||
|
|
||||||
|
LiftedOperator* getBestOperation (const Grounds&);
|
||||||
|
|
||||||
|
ParfactorList pfList_;
|
||||||
|
double largestCost_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_LIFTEDVE_H
|
||||||
|
|
658
packages/CLPBN/horus2/LiftedWCNF.cpp
Normal file
658
packages/CLPBN/horus2/LiftedWCNF.cpp
Normal file
@ -0,0 +1,658 @@
|
|||||||
|
#include "LiftedWCNF.h"
|
||||||
|
#include "ConstraintTree.h"
|
||||||
|
#include "Indexer.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const
|
||||||
|
{
|
||||||
|
if (logVars_.size() == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
LogVarSet lvs (logVars_);
|
||||||
|
lvs -= ipgLogVars;
|
||||||
|
return constr.singletons().contains (lvs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
size_t
|
||||||
|
Literal::indexOfLogVar (LogVar X) const
|
||||||
|
{
|
||||||
|
return Util::indexOf (logVars_, X);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
Literal::toString (
|
||||||
|
LogVarSet ipgLogVars,
|
||||||
|
LogVarSet posCountedLvs,
|
||||||
|
LogVarSet negCountedLvs) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
negated_ ? ss << "¬" : ss << "" ;
|
||||||
|
ss << "λ" ;
|
||||||
|
ss << lid_ ;
|
||||||
|
if (logVars_.empty() == false) {
|
||||||
|
ss << "(" ;
|
||||||
|
for (size_t i = 0; i < logVars_.size(); i++) {
|
||||||
|
if (i != 0) ss << ",";
|
||||||
|
if (posCountedLvs.contains (logVars_[i])) {
|
||||||
|
ss << "+" << logVars_[i];
|
||||||
|
} else if (negCountedLvs.contains (logVars_[i])) {
|
||||||
|
ss << "-" << logVars_[i];
|
||||||
|
} else if (ipgLogVars.contains (logVars_[i])) {
|
||||||
|
LogVar X = logVars_[i];
|
||||||
|
const string labels[] = {
|
||||||
|
"a", "b", "c", "d", "e", "f",
|
||||||
|
"g", "h", "i", "j", "k", "m" };
|
||||||
|
(X >= 12) ? ss << "x_" << X : ss << labels[X];
|
||||||
|
} else {
|
||||||
|
ss << logVars_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss << ")" ;
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::ostream&
|
||||||
|
operator<< (ostream &os, const Literal& lit)
|
||||||
|
{
|
||||||
|
os << lit.toString();
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Clause::addLiteralComplemented (const Literal& lit)
|
||||||
|
{
|
||||||
|
assert (constr_.logVarSet().contains (lit.logVars()));
|
||||||
|
literals_.push_back (lit);
|
||||||
|
literals_.back().complement();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Clause::containsLiteral (LiteralId lid) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < literals_.size(); i++) {
|
||||||
|
if (literals_[i].lid() == lid) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Clause::containsPositiveLiteral (
|
||||||
|
LiteralId lid,
|
||||||
|
const LogVarTypes& types) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < literals_.size(); i++) {
|
||||||
|
if (literals_[i].lid() == lid
|
||||||
|
&& literals_[i].isPositive()
|
||||||
|
&& logVarTypes (i) == types) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Clause::containsNegativeLiteral (
|
||||||
|
LiteralId lid,
|
||||||
|
const LogVarTypes& types) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < literals_.size(); i++) {
|
||||||
|
if (literals_[i].lid() == lid
|
||||||
|
&& literals_[i].isNegative()
|
||||||
|
&& logVarTypes (i) == types) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Clause::removeLiterals (LiteralId lid)
|
||||||
|
{
|
||||||
|
size_t i = 0;
|
||||||
|
while (i != literals_.size()) {
|
||||||
|
if (literals_[i].lid() == lid) {
|
||||||
|
removeLiteral (i);
|
||||||
|
} else {
|
||||||
|
i ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Clause::removePositiveLiterals (
|
||||||
|
LiteralId lid,
|
||||||
|
const LogVarTypes& types)
|
||||||
|
{
|
||||||
|
size_t i = 0;
|
||||||
|
while (i != literals_.size()) {
|
||||||
|
if (literals_[i].lid() == lid
|
||||||
|
&& literals_[i].isPositive()
|
||||||
|
&& logVarTypes (i) == types) {
|
||||||
|
removeLiteral (i);
|
||||||
|
} else {
|
||||||
|
i ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Clause::removeNegativeLiterals (
|
||||||
|
LiteralId lid,
|
||||||
|
const LogVarTypes& types)
|
||||||
|
{
|
||||||
|
size_t i = 0;
|
||||||
|
while (i != literals_.size()) {
|
||||||
|
if (literals_[i].lid() == lid
|
||||||
|
&& literals_[i].isNegative()
|
||||||
|
&& logVarTypes (i) == types) {
|
||||||
|
removeLiteral (i);
|
||||||
|
} else {
|
||||||
|
i ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Clause::isCountedLogVar (LogVar X) const
|
||||||
|
{
|
||||||
|
assert (constr_.logVarSet().contains (X));
|
||||||
|
return posCountedLvs_.contains (X)
|
||||||
|
|| negCountedLvs_.contains (X);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Clause::isPositiveCountedLogVar (LogVar X) const
|
||||||
|
{
|
||||||
|
assert (constr_.logVarSet().contains (X));
|
||||||
|
return posCountedLvs_.contains (X);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Clause::isNegativeCountedLogVar (LogVar X) const
|
||||||
|
{
|
||||||
|
assert (constr_.logVarSet().contains (X));
|
||||||
|
return negCountedLvs_.contains (X);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Clause::isIpgLogVar (LogVar X) const
|
||||||
|
{
|
||||||
|
assert (constr_.logVarSet().contains (X));
|
||||||
|
return ipgLvs_.contains (X);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TinySet<LiteralId>
|
||||||
|
Clause::lidSet (void) const
|
||||||
|
{
|
||||||
|
TinySet<LiteralId> lidSet;
|
||||||
|
for (size_t i = 0; i < literals_.size(); i++) {
|
||||||
|
lidSet.insert (literals_[i].lid());
|
||||||
|
}
|
||||||
|
return lidSet;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LogVarSet
|
||||||
|
Clause::ipgCandidates (void) const
|
||||||
|
{
|
||||||
|
LogVarSet candidates;
|
||||||
|
LogVarSet allLvs = constr_.logVarSet();
|
||||||
|
allLvs -= ipgLvs_;
|
||||||
|
allLvs -= posCountedLvs_;
|
||||||
|
allLvs -= negCountedLvs_;
|
||||||
|
for (size_t i = 0; i < allLvs.size(); i++) {
|
||||||
|
bool valid = true;
|
||||||
|
for (size_t j = 0; j < literals_.size(); j++) {
|
||||||
|
if (Util::contains (literals_[j].logVars(), allLvs[i]) == false) {
|
||||||
|
valid = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (valid) {
|
||||||
|
candidates.insert (allLvs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return candidates;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LogVarTypes
|
||||||
|
Clause::logVarTypes (size_t litIdx) const
|
||||||
|
{
|
||||||
|
LogVarTypes types;
|
||||||
|
const LogVars& lvs = literals_[litIdx].logVars();
|
||||||
|
for (size_t i = 0; i < lvs.size(); i++) {
|
||||||
|
if (posCountedLvs_.contains (lvs[i])) {
|
||||||
|
types.push_back (LogVarType::POS_LV);
|
||||||
|
} else if (negCountedLvs_.contains (lvs[i])) {
|
||||||
|
types.push_back (LogVarType::NEG_LV);
|
||||||
|
} else {
|
||||||
|
types.push_back (LogVarType::FULL_LV);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return types;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Clause::removeLiteral (size_t litIdx)
|
||||||
|
{
|
||||||
|
LogVarSet lvsToRemove = literals_[litIdx].logVarSet()
|
||||||
|
- getLogVarSetExcluding (litIdx);
|
||||||
|
ipgLvs_ -= lvsToRemove;
|
||||||
|
posCountedLvs_ -= lvsToRemove;
|
||||||
|
negCountedLvs_ -= lvsToRemove;
|
||||||
|
constr_.remove (lvsToRemove);
|
||||||
|
literals_.erase (literals_.begin() + litIdx);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Clause::independentClauses (Clause& c1, Clause& c2)
|
||||||
|
{
|
||||||
|
const Literals& lits1 = c1.literals();
|
||||||
|
const Literals& lits2 = c2.literals();
|
||||||
|
for (size_t i = 0; i < lits1.size(); i++) {
|
||||||
|
for (size_t j = 0; j < lits2.size(); j++) {
|
||||||
|
if (lits1[i].lid() == lits2[j].lid()
|
||||||
|
&& c1.logVarTypes (i) == c2.logVarTypes (j)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Clauses
|
||||||
|
Clause::copyClauses (const Clauses& clauses)
|
||||||
|
{
|
||||||
|
Clauses copy;
|
||||||
|
copy.reserve (clauses.size());
|
||||||
|
for (size_t i = 0; i < clauses.size(); i++) {
|
||||||
|
copy.push_back (new Clause (*clauses[i]));
|
||||||
|
}
|
||||||
|
return copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Clause::printClauses (const Clauses& clauses)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < clauses.size(); i++) {
|
||||||
|
cout << *clauses[i] << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Clause::deleteClauses (Clauses& clauses)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < clauses.size(); i++) {
|
||||||
|
delete clauses[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::ostream&
|
||||||
|
operator<< (ostream &os, const Clause& clause)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < clause.literals_.size(); i++) {
|
||||||
|
if (i != 0) os << " v " ;
|
||||||
|
os << clause.literals_[i].toString (clause.ipgLvs_,
|
||||||
|
clause.posCountedLvs_, clause.negCountedLvs_);
|
||||||
|
}
|
||||||
|
if (clause.constr_.empty() == false) {
|
||||||
|
ConstraintTree copy (clause.constr_);
|
||||||
|
copy.moveToTop (copy.logVarSet().elements());
|
||||||
|
os << " | " << copy.tupleSet();
|
||||||
|
}
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LogVarSet
|
||||||
|
Clause::getLogVarSetExcluding (size_t idx) const
|
||||||
|
{
|
||||||
|
LogVarSet lvs;
|
||||||
|
for (size_t i = 0; i < literals_.size(); i++) {
|
||||||
|
if (i != idx) {
|
||||||
|
lvs |= literals_[i].logVars();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lvs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::ostream&
|
||||||
|
operator<< (std::ostream &os, const LitLvTypes& lit)
|
||||||
|
{
|
||||||
|
os << lit.lid_ << "<" ;
|
||||||
|
for (size_t i = 0; i < lit.lvTypes_.size(); i++) {
|
||||||
|
switch (lit.lvTypes_[i]) {
|
||||||
|
case LogVarType::FULL_LV: os << "F" ; break;
|
||||||
|
case LogVarType::POS_LV: os << "P" ; break;
|
||||||
|
case LogVarType::NEG_LV: os << "N" ; break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << ">" ;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||||
|
: freeLiteralId_(0), pfList_(pfList)
|
||||||
|
{
|
||||||
|
addIndicatorClauses (pfList);
|
||||||
|
addParameterClauses (pfList);
|
||||||
|
|
||||||
|
/*
|
||||||
|
// INCLUSION-EXCLUSION TEST
|
||||||
|
clauses_.clear();
|
||||||
|
vector<vector<string>> names = {
|
||||||
|
{"a1","b1"},{"a2","b2"}
|
||||||
|
};
|
||||||
|
Clause* c1 = new Clause (names);
|
||||||
|
c1->addLiteral (Literal (0, LogVars() = {0}));
|
||||||
|
c1->addLiteral (Literal (1, LogVars() = {1}));
|
||||||
|
clauses_.push_back(c1);
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
// INDEPENDENT PARTIAL GROUND TEST
|
||||||
|
clauses_.clear();
|
||||||
|
vector<vector<string>> names = {
|
||||||
|
{"a1","b1"},{"a2","b2"}
|
||||||
|
};
|
||||||
|
Clause* c1 = new Clause (names);
|
||||||
|
c1->addLiteral (Literal (0, LogVars() = {0,1}));
|
||||||
|
c1->addLiteral (Literal (1, LogVars() = {0,1}));
|
||||||
|
clauses_.push_back(c1);
|
||||||
|
Clause* c2 = new Clause (names);
|
||||||
|
c2->addLiteral (Literal (2, LogVars() = {0}));
|
||||||
|
c2->addLiteral (Literal (1, LogVars() = {0,1}));
|
||||||
|
clauses_.push_back(c2);
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
// ATOM-COUNTING TEST
|
||||||
|
clauses_.clear();
|
||||||
|
vector<vector<string>> names = {
|
||||||
|
{"p1","p1"},{"p1","p2"},{"p1","p3"},
|
||||||
|
{"p2","p1"},{"p2","p2"},{"p2","p3"},
|
||||||
|
{"p3","p1"},{"p3","p2"},{"p3","p3"}
|
||||||
|
};
|
||||||
|
Clause* c1 = new Clause (names);
|
||||||
|
c1->addLiteral (Literal (0, LogVars() = {0}));
|
||||||
|
c1->addLiteralComplemented (Literal (1, {0,1}));
|
||||||
|
clauses_.push_back(c1);
|
||||||
|
Clause* c2 = new Clause (names);
|
||||||
|
c2->addLiteral (Literal (0, LogVars()={0}));
|
||||||
|
c2->addLiteralComplemented (Literal (1, {1,0}));
|
||||||
|
clauses_.push_back(c2);
|
||||||
|
*/
|
||||||
|
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "FORMULA INDICATORS:" << endl;
|
||||||
|
printFormulaIndicators();
|
||||||
|
cout << endl;
|
||||||
|
cout << "WEIGHTED INDICATORS:" << endl;
|
||||||
|
printWeights();
|
||||||
|
cout << endl;
|
||||||
|
cout << "CLAUSES:" << endl;
|
||||||
|
printClauses();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LiftedWCNF::~LiftedWCNF (void)
|
||||||
|
{
|
||||||
|
Clause::deleteClauses (clauses_);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||||
|
{
|
||||||
|
weights_[lid] = make_pair (posW, negW);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
LiftedWCNF::posWeight (LiteralId lid) const
|
||||||
|
{
|
||||||
|
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||||
|
it = weights_.find (lid);
|
||||||
|
return it != weights_.end() ? it->second.first : LogAware::one();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
LiftedWCNF::negWeight (LiteralId lid) const
|
||||||
|
{
|
||||||
|
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||||
|
it = weights_.find (lid);
|
||||||
|
return it != weights_.end() ? it->second.second : LogAware::one();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<LiteralId>
|
||||||
|
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
|
||||||
|
{
|
||||||
|
assert (Util::contains (map_, prvGroup));
|
||||||
|
return map_[prvGroup];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Clause*
|
||||||
|
LiftedWCNF::createClause (LiteralId lid) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < clauses_.size(); i++) {
|
||||||
|
const Literals& literals = clauses_[i]->literals();
|
||||||
|
for (size_t j = 0; j < literals.size(); j++) {
|
||||||
|
if (literals[j].lid() == lid) {
|
||||||
|
ConstraintTree ct = clauses_[i]->constr().projectedCopy (
|
||||||
|
literals[j].logVars());
|
||||||
|
Clause* c = new Clause (ct);
|
||||||
|
c->addLiteral (literals[j]);
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LiteralId
|
||||||
|
LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
|
||||||
|
{
|
||||||
|
assert (Util::contains (map_, prvGroup));
|
||||||
|
return map_[prvGroup][range];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
|
||||||
|
{
|
||||||
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
const ProbFormulas& formulas = (*it)->arguments();
|
||||||
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
|
if (Util::contains (map_, formulas[i].group()) == false) {
|
||||||
|
ConstraintTree tempConstr = (*it)->constr()->projectedCopy(
|
||||||
|
formulas[i].logVars());
|
||||||
|
Clause* clause = new Clause (tempConstr);
|
||||||
|
vector<LiteralId> lids;
|
||||||
|
for (size_t j = 0; j < formulas[i].range(); j++) {
|
||||||
|
clause->addLiteral (Literal (freeLiteralId_, formulas[i].logVars()));
|
||||||
|
lids.push_back (freeLiteralId_);
|
||||||
|
freeLiteralId_ ++;
|
||||||
|
}
|
||||||
|
clauses_.push_back (clause);
|
||||||
|
for (size_t j = 0; j < formulas[i].range() - 1; j++) {
|
||||||
|
for (size_t k = j + 1; k < formulas[i].range(); k++) {
|
||||||
|
ConstraintTree tempConstr2 = (*it)->constr()->projectedCopy (
|
||||||
|
formulas[i].logVars());
|
||||||
|
Clause* clause2 = new Clause (tempConstr2);
|
||||||
|
clause2->addLiteralComplemented (Literal (clause->literals()[j]));
|
||||||
|
clause2->addLiteralComplemented (Literal (clause->literals()[k]));
|
||||||
|
clauses_.push_back (clause2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
map_[formulas[i].group()] = lids;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||||
|
{
|
||||||
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
Indexer indexer ((*it)->ranges());
|
||||||
|
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||||
|
while (indexer.valid()) {
|
||||||
|
LiteralId paramVarLid = freeLiteralId_;
|
||||||
|
// λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un
|
||||||
|
//
|
||||||
|
// ¬λu1 ... ¬λun v θxi|u1,...,un -> clause1
|
||||||
|
// ¬θxi|u1,...,un v λu1 -> tempClause
|
||||||
|
// ¬θxi|u1,...,un v λu2 -> tempClause
|
||||||
|
double posWeight = (**it)[indexer];
|
||||||
|
addWeight (paramVarLid, posWeight, LogAware::one());
|
||||||
|
|
||||||
|
Clause* clause1 = new Clause (*(*it)->constr());
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
|
LiteralId lid = getLiteralId (groups[i], indexer[i]);
|
||||||
|
|
||||||
|
clause1->addLiteralComplemented (
|
||||||
|
Literal (lid, (*it)->argument(i).logVars()));
|
||||||
|
|
||||||
|
ConstraintTree ct = *(*it)->constr();
|
||||||
|
Clause* tempClause = new Clause (ct);
|
||||||
|
tempClause->addLiteralComplemented (Literal (
|
||||||
|
paramVarLid, (*it)->constr()->logVars()));
|
||||||
|
tempClause->addLiteral (Literal (lid, (*it)->argument(i).logVars()));
|
||||||
|
clauses_.push_back (tempClause);
|
||||||
|
}
|
||||||
|
clause1->addLiteral (Literal (paramVarLid, (*it)->constr()->logVars()));
|
||||||
|
clauses_.push_back (clause1);
|
||||||
|
freeLiteralId_ ++;
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedWCNF::printFormulaIndicators (void) const
|
||||||
|
{
|
||||||
|
if (map_.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
set<PrvGroup> allGroups;
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
const ProbFormulas& formulas = (*it)->arguments();
|
||||||
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
|
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
||||||
|
allGroups.insert (formulas[i].group());
|
||||||
|
cout << formulas[i] << " | " ;
|
||||||
|
ConstraintTree tempCt = (*it)->constr()->projectedCopy (
|
||||||
|
formulas[i].logVars());
|
||||||
|
cout << tempCt.tupleSet();
|
||||||
|
cout << " indicators => " ;
|
||||||
|
vector<LiteralId> indicators =
|
||||||
|
(map_.find (formulas[i].group()))->second;
|
||||||
|
cout << indicators << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedWCNF::printWeights (void) const
|
||||||
|
{
|
||||||
|
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||||
|
it = weights_.begin();
|
||||||
|
while (it != weights_.end()) {
|
||||||
|
cout << "λ" << it->first << " weights: " ;
|
||||||
|
cout << it->second.first << " " << it->second.second;
|
||||||
|
cout << endl;
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedWCNF::printClauses (void) const
|
||||||
|
{
|
||||||
|
Clause::printClauses (clauses_);
|
||||||
|
}
|
||||||
|
|
239
packages/CLPBN/horus2/LiftedWCNF.h
Normal file
239
packages/CLPBN/horus2/LiftedWCNF.h
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
#ifndef HORUS_LIFTEDWCNF_H
|
||||||
|
#define HORUS_LIFTEDWCNF_H
|
||||||
|
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
typedef long LiteralId;
|
||||||
|
|
||||||
|
class ConstraintTree;
|
||||||
|
|
||||||
|
|
||||||
|
enum LogVarType
|
||||||
|
{
|
||||||
|
FULL_LV,
|
||||||
|
POS_LV,
|
||||||
|
NEG_LV
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<LogVarType> LogVarTypes;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Literal
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Literal (LiteralId lid, const LogVars& lvs) :
|
||||||
|
lid_(lid), logVars_(lvs), negated_(false) { }
|
||||||
|
|
||||||
|
Literal (const Literal& lit, bool negated) :
|
||||||
|
lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
|
||||||
|
|
||||||
|
LiteralId lid (void) const { return lid_; }
|
||||||
|
|
||||||
|
LogVars logVars (void) const { return logVars_; }
|
||||||
|
|
||||||
|
size_t nrLogVars (void) const { return logVars_.size(); }
|
||||||
|
|
||||||
|
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||||
|
|
||||||
|
void complement (void) { negated_ = !negated_; }
|
||||||
|
|
||||||
|
bool isPositive (void) const { return negated_ == false; }
|
||||||
|
|
||||||
|
bool isNegative (void) const { return negated_; }
|
||||||
|
|
||||||
|
bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const;
|
||||||
|
|
||||||
|
size_t indexOfLogVar (LogVar X) const;
|
||||||
|
|
||||||
|
string toString (LogVarSet ipgLogVars = LogVarSet(),
|
||||||
|
LogVarSet posCountedLvs = LogVarSet(),
|
||||||
|
LogVarSet negCountedLvs = LogVarSet()) const;
|
||||||
|
|
||||||
|
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
|
||||||
|
|
||||||
|
private:
|
||||||
|
LiteralId lid_;
|
||||||
|
LogVars logVars_;
|
||||||
|
bool negated_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<Literal> Literals;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Clause
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Clause (const ConstraintTree& ct = ConstraintTree({})) : constr_(ct) { }
|
||||||
|
|
||||||
|
Clause (vector<vector<string>> names) : constr_(ConstraintTree (names)) { }
|
||||||
|
|
||||||
|
void addLiteral (const Literal& l) { literals_.push_back (l); }
|
||||||
|
|
||||||
|
const Literals& literals (void) const { return literals_; }
|
||||||
|
|
||||||
|
Literals& literals (void) { return literals_; }
|
||||||
|
|
||||||
|
size_t nrLiterals (void) const { return literals_.size(); }
|
||||||
|
|
||||||
|
const ConstraintTree& constr (void) const { return constr_; }
|
||||||
|
|
||||||
|
ConstraintTree constr (void) { return constr_; }
|
||||||
|
|
||||||
|
bool isUnit (void) const { return literals_.size() == 1; }
|
||||||
|
|
||||||
|
LogVarSet ipgLogVars (void) const { return ipgLvs_; }
|
||||||
|
|
||||||
|
void addIpgLogVar (LogVar X) { ipgLvs_.insert (X); }
|
||||||
|
|
||||||
|
void addPosCountedLogVar (LogVar X) { posCountedLvs_.insert (X); }
|
||||||
|
|
||||||
|
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
|
||||||
|
|
||||||
|
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; }
|
||||||
|
|
||||||
|
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; }
|
||||||
|
|
||||||
|
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); }
|
||||||
|
|
||||||
|
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); }
|
||||||
|
|
||||||
|
void addLiteralComplemented (const Literal& lit);
|
||||||
|
|
||||||
|
bool containsLiteral (LiteralId lid) const;
|
||||||
|
|
||||||
|
bool containsPositiveLiteral (LiteralId lid, const LogVarTypes&) const;
|
||||||
|
|
||||||
|
bool containsNegativeLiteral (LiteralId lid, const LogVarTypes&) const;
|
||||||
|
|
||||||
|
void removeLiterals (LiteralId lid);
|
||||||
|
|
||||||
|
void removePositiveLiterals (LiteralId lid, const LogVarTypes&);
|
||||||
|
|
||||||
|
void removeNegativeLiterals (LiteralId lid, const LogVarTypes&);
|
||||||
|
|
||||||
|
bool isCountedLogVar (LogVar X) const;
|
||||||
|
|
||||||
|
bool isPositiveCountedLogVar (LogVar X) const;
|
||||||
|
|
||||||
|
bool isNegativeCountedLogVar (LogVar X) const;
|
||||||
|
|
||||||
|
bool isIpgLogVar (LogVar X) const;
|
||||||
|
|
||||||
|
TinySet<LiteralId> lidSet (void) const;
|
||||||
|
|
||||||
|
LogVarSet ipgCandidates (void) const;
|
||||||
|
|
||||||
|
LogVarTypes logVarTypes (size_t litIdx) const;
|
||||||
|
|
||||||
|
void removeLiteral (size_t litIdx);
|
||||||
|
|
||||||
|
static bool independentClauses (Clause& c1, Clause& c2);
|
||||||
|
|
||||||
|
static vector<Clause*> copyClauses (const vector<Clause*>& clauses);
|
||||||
|
|
||||||
|
static void printClauses (const vector<Clause*>& clauses);
|
||||||
|
|
||||||
|
static void deleteClauses (vector<Clause*>& clauses);
|
||||||
|
|
||||||
|
friend std::ostream& operator<< (ostream &os, const Clause& clause);
|
||||||
|
|
||||||
|
private:
|
||||||
|
LogVarSet getLogVarSetExcluding (size_t idx) const;
|
||||||
|
|
||||||
|
Literals literals_;
|
||||||
|
LogVarSet ipgLvs_;
|
||||||
|
LogVarSet posCountedLvs_;
|
||||||
|
LogVarSet negCountedLvs_;
|
||||||
|
ConstraintTree constr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<Clause*> Clauses;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LitLvTypes
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
struct CompareLitLvTypes
|
||||||
|
{
|
||||||
|
bool operator() (
|
||||||
|
const LitLvTypes& types1,
|
||||||
|
const LitLvTypes& types2) const
|
||||||
|
{
|
||||||
|
if (types1.lid_ < types2.lid_) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (types1.lid_ == types2.lid_) {
|
||||||
|
return types1.lvTypes_ < types2.lvTypes_;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
|
||||||
|
lid_(lid), lvTypes_(lvTypes) { }
|
||||||
|
|
||||||
|
LiteralId lid (void) const { return lid_; }
|
||||||
|
|
||||||
|
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
|
||||||
|
|
||||||
|
void setAllFullLogVars (void) {
|
||||||
|
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
|
||||||
|
|
||||||
|
friend std::ostream& operator<< (std::ostream &os, const LitLvTypes& lit);
|
||||||
|
|
||||||
|
private:
|
||||||
|
LiteralId lid_;
|
||||||
|
LogVarTypes lvTypes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LiftedWCNF
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LiftedWCNF (const ParfactorList& pfList);
|
||||||
|
|
||||||
|
~LiftedWCNF (void);
|
||||||
|
|
||||||
|
const Clauses& clauses (void) const { return clauses_; }
|
||||||
|
|
||||||
|
void addWeight (LiteralId lid, double posW, double negW);
|
||||||
|
|
||||||
|
double posWeight (LiteralId lid) const;
|
||||||
|
|
||||||
|
double negWeight (LiteralId lid) const;
|
||||||
|
|
||||||
|
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
|
||||||
|
|
||||||
|
Clause* createClause (LiteralId lid) const;
|
||||||
|
|
||||||
|
void printFormulaIndicators (void) const;
|
||||||
|
|
||||||
|
void printWeights (void) const;
|
||||||
|
|
||||||
|
void printClauses (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
|
||||||
|
|
||||||
|
void addIndicatorClauses (const ParfactorList& pfList);
|
||||||
|
|
||||||
|
void addParameterClauses (const ParfactorList& pfList);
|
||||||
|
|
||||||
|
Clauses clauses_;
|
||||||
|
LiteralId freeLiteralId_;
|
||||||
|
const ParfactorList& pfList_;
|
||||||
|
unordered_map<PrvGroup, vector<LiteralId>> map_;
|
||||||
|
unordered_map<LiteralId, std::pair<double,double>> weights_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_LIFTEDWCNF_H
|
||||||
|
|
942
packages/CLPBN/horus2/Parfactor.cpp
Normal file
942
packages/CLPBN/horus2/Parfactor.cpp
Normal file
@ -0,0 +1,942 @@
|
|||||||
|
|
||||||
|
#include "Parfactor.h"
|
||||||
|
#include "Histogram.h"
|
||||||
|
#include "Indexer.h"
|
||||||
|
#include "Util.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor::Parfactor (
|
||||||
|
const ProbFormulas& formulas,
|
||||||
|
const Params& params,
|
||||||
|
const Tuples& tuples,
|
||||||
|
unsigned distId)
|
||||||
|
{
|
||||||
|
args_ = formulas;
|
||||||
|
params_ = params;
|
||||||
|
distId_ = distId;
|
||||||
|
|
||||||
|
LogVars logVars;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
ranges_.push_back (args_[i].range());
|
||||||
|
const LogVars& lvs = args_[i].logVars();
|
||||||
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
|
if (Util::contains (logVars, lvs[j]) == false) {
|
||||||
|
logVars.push_back (lvs[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LogVar newLv = logVars.size();
|
||||||
|
constr_ = new ConstraintTree (logVars, tuples);
|
||||||
|
// Change formulas like f(X,X), X in {(p1),(p2),...}
|
||||||
|
// to be like f(X,Y), (X,Y) in {(p1,p1),(p2,p2),...}.
|
||||||
|
// This will simplify shattering on the constraint tree.
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
LogVarSet lvSet;
|
||||||
|
LogVars& lvs = args_[i].logVars();
|
||||||
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
|
if (lvSet.contains (lvs[j]) == false) {
|
||||||
|
lvSet |= lvs[j];
|
||||||
|
} else {
|
||||||
|
constr_->cloneLogVar (lvs[j], newLv);
|
||||||
|
lvs[j] = newLv;
|
||||||
|
++ newLv;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
||||||
|
{
|
||||||
|
args_ = g->arguments();
|
||||||
|
params_ = g->params();
|
||||||
|
ranges_ = g->ranges();
|
||||||
|
distId_ = g->distId();
|
||||||
|
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||||
|
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
||||||
|
{
|
||||||
|
args_ = g->arguments();
|
||||||
|
params_ = g->params();
|
||||||
|
ranges_ = g->ranges();
|
||||||
|
distId_ = g->distId();
|
||||||
|
constr_ = constr;
|
||||||
|
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor::Parfactor (const Parfactor& g)
|
||||||
|
{
|
||||||
|
args_ = g.arguments();
|
||||||
|
params_ = g.params();
|
||||||
|
ranges_ = g.ranges();
|
||||||
|
distId_ = g.distId();
|
||||||
|
constr_ = new ConstraintTree (*g.constr());
|
||||||
|
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor::~Parfactor (void)
|
||||||
|
{
|
||||||
|
delete constr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LogVarSet
|
||||||
|
Parfactor::countedLogVars (void) const
|
||||||
|
{
|
||||||
|
LogVarSet set;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i].isCounting()) {
|
||||||
|
set.insert (args_[i].countedLogVar());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return set;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LogVarSet
|
||||||
|
Parfactor::uncountedLogVars (void) const
|
||||||
|
{
|
||||||
|
return constr_->logVarSet() - countedLogVars();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LogVarSet
|
||||||
|
Parfactor::elimLogVars (void) const
|
||||||
|
{
|
||||||
|
LogVarSet requiredToElim = constr_->logVarSet();
|
||||||
|
requiredToElim -= constr_->singletons();
|
||||||
|
requiredToElim -= countedLogVars();
|
||||||
|
return requiredToElim;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LogVarSet
|
||||||
|
Parfactor::exclusiveLogVars (size_t fIdx) const
|
||||||
|
{
|
||||||
|
assert (fIdx < args_.size());
|
||||||
|
LogVarSet remaining;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (i != fIdx) {
|
||||||
|
remaining |= args_[i].logVarSet();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args_[fIdx].logVarSet() - remaining;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::sumOutIndex (size_t fIdx)
|
||||||
|
{
|
||||||
|
assert (fIdx < args_.size());
|
||||||
|
assert (args_[fIdx].contains (elimLogVars()));
|
||||||
|
|
||||||
|
if (args_[fIdx].isCounting()) {
|
||||||
|
unsigned N = constr_->getConditionalCount (
|
||||||
|
args_[fIdx].countedLogVar());
|
||||||
|
unsigned R = args_[fIdx].range();
|
||||||
|
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||||
|
Indexer indexer (ranges_, fIdx);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
params_[indexer] += numAssigns[ indexer[fIdx] ];
|
||||||
|
} else {
|
||||||
|
params_[indexer] *= numAssigns[ indexer[fIdx] ];
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||||
|
unsigned exp;
|
||||||
|
if (args_[fIdx].isCounting()) {
|
||||||
|
// counting log vars were already raised on counting conversion
|
||||||
|
exp = constr_->getConditionalCount (excl - args_[fIdx].countedLogVar());
|
||||||
|
} else {
|
||||||
|
exp = constr_->getConditionalCount (excl);
|
||||||
|
}
|
||||||
|
constr_->remove (excl);
|
||||||
|
|
||||||
|
TFactor<ProbFormula>::sumOutIndex (fIdx);
|
||||||
|
LogAware::pow (params_, exp);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::multiply (Parfactor& g)
|
||||||
|
{
|
||||||
|
alignAndExponentiate (this, &g);
|
||||||
|
TFactor<ProbFormula>::multiply (g);
|
||||||
|
constr_->join (g.constr(), true);
|
||||||
|
simplifyGrounds();
|
||||||
|
assert (constr_->isCartesianProduct (countedLogVars()));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::canCountConvert (LogVar X)
|
||||||
|
{
|
||||||
|
if (nrFormulas (X) != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t fIdx = indexOfLogVar (X);
|
||||||
|
if (args_[fIdx].isCounting()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (constr_->isCountNormalized (X) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (constr_->getConditionalCount (X) == 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (constr_->isCartesianProduct (countedLogVars() | X) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::countConvert (LogVar X)
|
||||||
|
{
|
||||||
|
size_t fIdx = indexOfLogVar (X);
|
||||||
|
assert (constr_->isCountNormalized (X));
|
||||||
|
assert (constr_->getConditionalCount (X) > 1);
|
||||||
|
assert (canCountConvert (X));
|
||||||
|
|
||||||
|
unsigned N = constr_->getConditionalCount (X);
|
||||||
|
unsigned R = ranges_[fIdx];
|
||||||
|
unsigned H = HistogramSet::nrHistograms (N, R);
|
||||||
|
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||||
|
|
||||||
|
Indexer indexer (ranges_);
|
||||||
|
vector<Params> sumout (params_.size() / R);
|
||||||
|
unsigned count = 0;
|
||||||
|
while (indexer.valid()) {
|
||||||
|
sumout[count].reserve (R);
|
||||||
|
for (unsigned r = 0; r < R; r++) {
|
||||||
|
sumout[count].push_back (params_[indexer]);
|
||||||
|
indexer.incrementDimension (fIdx);
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
indexer.resetDimension (fIdx);
|
||||||
|
indexer.incrementExceptDimension (fIdx);
|
||||||
|
}
|
||||||
|
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (sumout.size() * H);
|
||||||
|
|
||||||
|
ranges_[fIdx] = H;
|
||||||
|
MapIndexer mapIndexer (ranges_, fIdx);
|
||||||
|
while (mapIndexer.valid()) {
|
||||||
|
double prod = LogAware::multIdenty();
|
||||||
|
size_t i = mapIndexer;
|
||||||
|
unsigned h = mapIndexer[fIdx];
|
||||||
|
for (unsigned r = 0; r < R; r++) {
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||||
|
} else {
|
||||||
|
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params_.push_back (prod);
|
||||||
|
++ mapIndexer;
|
||||||
|
}
|
||||||
|
args_[fIdx].setCountedLogVar (X);
|
||||||
|
simplifyCountingFormulas (fIdx);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
|
||||||
|
{
|
||||||
|
size_t fIdx = indexOfLogVar (X);
|
||||||
|
assert (fIdx != args_.size());
|
||||||
|
assert (args_[fIdx].isCounting());
|
||||||
|
|
||||||
|
unsigned N1 = constr_->getConditionalCount (X_new1);
|
||||||
|
unsigned N2 = constr_->getConditionalCount (X_new2);
|
||||||
|
unsigned N = N1 + N2;
|
||||||
|
unsigned R = args_[fIdx].range();
|
||||||
|
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
||||||
|
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
||||||
|
|
||||||
|
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||||
|
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||||
|
vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
|
||||||
|
|
||||||
|
vector<unsigned> sumIndexes;
|
||||||
|
sumIndexes.reserve (H1 * H2);
|
||||||
|
for (unsigned i = 0; i < H1; i++) {
|
||||||
|
for (unsigned j = 0; j < H2; j++) {
|
||||||
|
Histogram hist = histograms1[i];
|
||||||
|
hist += histograms2[j];
|
||||||
|
sumIndexes.push_back (HistogramSet::findIndex (hist, histograms));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expandPotential (fIdx, H1 * H2, sumIndexes);
|
||||||
|
|
||||||
|
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
|
||||||
|
args_[fIdx].rename (X, X_new1);
|
||||||
|
args_[fIdx + 1].rename (X, X_new2);
|
||||||
|
if (H1 == 2) {
|
||||||
|
args_[fIdx].clearCountedLogVar();
|
||||||
|
}
|
||||||
|
if (H2 == 2) {
|
||||||
|
args_[fIdx + 1].clearCountedLogVar();
|
||||||
|
}
|
||||||
|
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
|
||||||
|
ranges_[fIdx] = H1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::fullExpand (LogVar X)
|
||||||
|
{
|
||||||
|
size_t fIdx = indexOfLogVar (X);
|
||||||
|
assert (fIdx != args_.size());
|
||||||
|
assert (args_[fIdx].isCounting());
|
||||||
|
|
||||||
|
unsigned N = constr_->getConditionalCount (X);
|
||||||
|
unsigned R = args_[fIdx].range();
|
||||||
|
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||||
|
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||||
|
assert (ranges_[fIdx] == originHists.size());
|
||||||
|
vector<unsigned> sumIndexes;
|
||||||
|
sumIndexes.reserve (N * R);
|
||||||
|
|
||||||
|
Ranges expandRanges (N, R);
|
||||||
|
Indexer indexer (expandRanges);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
vector<unsigned> hist (R, 0);
|
||||||
|
for (unsigned n = 0; n < N; n++) {
|
||||||
|
hist += expandHists[indexer[n]];
|
||||||
|
}
|
||||||
|
sumIndexes.push_back (HistogramSet::findIndex (hist, originHists));
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
|
||||||
|
expandPotential (fIdx, std::pow (R, N), sumIndexes);
|
||||||
|
|
||||||
|
ProbFormula f = args_[fIdx];
|
||||||
|
args_.erase (args_.begin() + fIdx);
|
||||||
|
ranges_.erase (ranges_.begin() + fIdx);
|
||||||
|
LogVars newLvs = constr_->expand (X);
|
||||||
|
assert (newLvs.size() == N);
|
||||||
|
for (unsigned i = 0 ; i < N; i++) {
|
||||||
|
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
|
||||||
|
newFormula.rename (X, newLvs[i]);
|
||||||
|
args_.insert (args_.begin() + fIdx + i, newFormula);
|
||||||
|
ranges_.insert (ranges_.begin() + fIdx + i, R);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::reorderAccordingGrounds (const Grounds& grounds)
|
||||||
|
{
|
||||||
|
ProbFormulas newFormulas;
|
||||||
|
for (size_t i = 0; i < grounds.size(); i++) {
|
||||||
|
for (size_t j = 0; j < args_.size(); j++) {
|
||||||
|
if (grounds[i].functor() == args_[j].functor() &&
|
||||||
|
grounds[i].arity() == args_[j].arity()) {
|
||||||
|
constr_->moveToTop (args_[j].logVars());
|
||||||
|
if (constr_->containsTuple (grounds[i].args())) {
|
||||||
|
newFormulas.push_back (args_[j]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (newFormulas.size() == i + 1);
|
||||||
|
}
|
||||||
|
reorderArguments (newFormulas);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
||||||
|
{
|
||||||
|
size_t fIdx = indexOf (formula);
|
||||||
|
assert (fIdx != args_.size());
|
||||||
|
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||||
|
assert (args_[fIdx].isCounting() == false);
|
||||||
|
assert (constr_->isCountNormalized (excl));
|
||||||
|
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||||
|
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||||
|
constr_->remove (excl);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::setNewGroups (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
args_[i].setGroup (ProbFormula::getNewGroup());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::applySubstitution (const Substitution& theta)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
LogVars& lvs = args_[i].logVars();
|
||||||
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
|
lvs[j] = theta.newNameFor (lvs[j]);
|
||||||
|
}
|
||||||
|
if (args_[i].isCounting()) {
|
||||||
|
LogVar clv = args_[i].countedLogVar();
|
||||||
|
args_[i].setCountedLogVar (theta.newNameFor (clv));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
constr_->applySubstitution (theta);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
size_t
|
||||||
|
Parfactor::indexOfGround (const Ground& ground) const
|
||||||
|
{
|
||||||
|
size_t idx = args_.size();
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i].functor() == ground.functor() &&
|
||||||
|
args_[i].arity() == ground.arity()) {
|
||||||
|
constr_->moveToTop (args_[i].logVars());
|
||||||
|
if (constr_->containsTuple (ground.args())) {
|
||||||
|
idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
PrvGroup
|
||||||
|
Parfactor::findGroup (const Ground& ground) const
|
||||||
|
{
|
||||||
|
size_t idx = indexOfGround (ground);
|
||||||
|
return idx == args_.size()
|
||||||
|
? numeric_limits<PrvGroup>::max()
|
||||||
|
: args_[idx].group();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::containsGround (const Ground& ground) const
|
||||||
|
{
|
||||||
|
return findGroup (ground) != numeric_limits<PrvGroup>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::containsGrounds (const Grounds& grounds) const
|
||||||
|
{
|
||||||
|
Tuple tuple;
|
||||||
|
LogVars tupleLvs;
|
||||||
|
for (size_t i = 0; i < grounds.size(); i++) {
|
||||||
|
size_t idx = indexOfGround (grounds[i]);
|
||||||
|
if (idx == args_.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
LogVars lvs = args_[idx].logVars();
|
||||||
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
|
if (Util::contains (tupleLvs, lvs[j]) == false) {
|
||||||
|
tuple.push_back (grounds[i].args()[j]);
|
||||||
|
tupleLvs.push_back (lvs[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
constr_->moveToTop (tupleLvs);
|
||||||
|
return constr_->containsTuple (tuple);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::containsGroup (PrvGroup group) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i].group() == group) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::containsGroups (vector<PrvGroup> groups) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
|
if (containsGroup (groups[i]) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
Parfactor::nrFormulas (LogVar X) const
|
||||||
|
{
|
||||||
|
unsigned count = 0;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i].contains (X)) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
Parfactor::indexOfLogVar (LogVar X) const
|
||||||
|
{
|
||||||
|
size_t idx = args_.size();
|
||||||
|
assert (nrFormulas (X) == 1);
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i].contains (X)) {
|
||||||
|
idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
Parfactor::indexOfGroup (PrvGroup group) const
|
||||||
|
{
|
||||||
|
size_t pos = args_.size();
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i].group() == group) {
|
||||||
|
pos = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
Parfactor::nrFormulasWithGroup (PrvGroup group) const
|
||||||
|
{
|
||||||
|
unsigned count = 0;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i].group() == group) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<PrvGroup>
|
||||||
|
Parfactor::getAllGroups (void) const
|
||||||
|
{
|
||||||
|
vector<PrvGroup> groups (args_.size());
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
groups[i] = args_[i].group();
|
||||||
|
}
|
||||||
|
return groups;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
Parfactor::getLabel (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "phi(" ;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (i != 0) ss << "," ;
|
||||||
|
ss << args_[i];
|
||||||
|
}
|
||||||
|
ss << ")" ;
|
||||||
|
ConstraintTree copy (*constr_);
|
||||||
|
copy.moveToTop (copy.logVarSet().elements());
|
||||||
|
ss << "|" << copy.tupleSet();
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::print (bool printParams) const
|
||||||
|
{
|
||||||
|
cout << "Formulas: " ;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (i != 0) cout << ", " ;
|
||||||
|
cout << args_[i];
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
if (args_[0].group() != Util::maxUnsigned()) {
|
||||||
|
vector<string> groups;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
||||||
|
}
|
||||||
|
cout << "Groups: " << groups << endl;
|
||||||
|
}
|
||||||
|
cout << "LogVars: " << constr_->logVarSet() << endl;
|
||||||
|
cout << "Ranges: " << ranges_ << endl;
|
||||||
|
if (printParams == false) {
|
||||||
|
cout << "Params: " ;
|
||||||
|
if (params_.size() <= 32) {
|
||||||
|
cout.precision(10);
|
||||||
|
cout << params_ << endl;
|
||||||
|
} else {
|
||||||
|
cout << "|" << params_.size() << "|" << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ConstraintTree copy (*constr_);
|
||||||
|
copy.moveToTop (copy.logVarSet().elements());
|
||||||
|
cout << "Tuples: " << copy.tupleSet() << endl;
|
||||||
|
if (printParams) {
|
||||||
|
printParameters();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::printParameters (void) const
|
||||||
|
{
|
||||||
|
vector<string> jointStrings;
|
||||||
|
Indexer indexer (ranges_);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
stringstream ss;
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (i != 0) ss << ", " ;
|
||||||
|
if (args_[i].isCounting()) {
|
||||||
|
unsigned N = constr_->getConditionalCount (
|
||||||
|
args_[i].countedLogVar());
|
||||||
|
HistogramSet hs (N, args_[i].range());
|
||||||
|
unsigned c = 0;
|
||||||
|
while (c < indexer[i]) {
|
||||||
|
hs.nextHistogram();
|
||||||
|
c ++;
|
||||||
|
}
|
||||||
|
ss << hs;
|
||||||
|
} else {
|
||||||
|
ss << indexer[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
jointStrings.push_back (ss.str());
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < params_.size(); i++) {
|
||||||
|
cout << "f(" << jointStrings[i] << ")" ;
|
||||||
|
cout << " = " << params_[i] << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::printProjections (void) const
|
||||||
|
{
|
||||||
|
ConstraintTree copy (*constr_);
|
||||||
|
|
||||||
|
LogVarSet Xs = copy.logVarSet();
|
||||||
|
for (size_t i = 0; i < Xs.size(); i++) {
|
||||||
|
cout << "-> projection of " << Xs[i] << ": " ;
|
||||||
|
cout << copy.tupleSet ({Xs[i]}) << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::expandPotential (
|
||||||
|
size_t fIdx,
|
||||||
|
unsigned newRange,
|
||||||
|
const vector<unsigned>& sumIndexes)
|
||||||
|
{
|
||||||
|
ullong newSize = (params_.size() / ranges_[fIdx]) * newRange;
|
||||||
|
if (newSize > params_.max_size()) {
|
||||||
|
cerr << "Error: an overflow occurred when performing expansion." ;
|
||||||
|
cerr << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
|
||||||
|
Params backup = params_;
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (newSize);
|
||||||
|
|
||||||
|
size_t prod = 1;
|
||||||
|
vector<size_t> offsets (ranges_.size());
|
||||||
|
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||||
|
offsets[i] = prod;
|
||||||
|
prod *= ranges_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t index = 0;
|
||||||
|
ranges_[fIdx] = newRange;
|
||||||
|
vector<unsigned> indices (ranges_.size(), 0);
|
||||||
|
for (size_t k = 0; k < newSize; k++) {
|
||||||
|
assert (index < backup.size());
|
||||||
|
params_.push_back (backup[index]);
|
||||||
|
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||||
|
indices[i] ++;
|
||||||
|
if (i == fIdx) {
|
||||||
|
if (indices[i] != ranges_[i]) {
|
||||||
|
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||||
|
index += diff * offsets[i];
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
// last index contains the old range minus 1
|
||||||
|
index -= sumIndexes.back() * offsets[i];
|
||||||
|
indices[i] = 0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (indices[i] != ranges_[i]) {
|
||||||
|
index += offsets[i];
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
index -= (ranges_[i] - 1) * offsets[i];
|
||||||
|
indices[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::simplifyCountingFormulas (size_t fIdx)
|
||||||
|
{
|
||||||
|
// check if we can simplify the parfactor
|
||||||
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
|
if (i != fIdx &&
|
||||||
|
args_[i].isCounting() &&
|
||||||
|
args_[i].group() == args_[fIdx].group()) {
|
||||||
|
// if they only differ in the name of the counting log var
|
||||||
|
if ((args_[i].logVarSet() - args_[i].countedLogVar()) ==
|
||||||
|
(args_[fIdx].logVarSet()) - args_[fIdx].countedLogVar() &&
|
||||||
|
ranges_[i] == ranges_[fIdx]) {
|
||||||
|
simplifyParfactor (fIdx, i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::simplifyGrounds (void)
|
||||||
|
{
|
||||||
|
if (args_.size() == 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
LogVarSet singletons = constr_->singletons();
|
||||||
|
for (long i = 0; i < (long)args_.size() - 1; i++) {
|
||||||
|
for (size_t j = i + 1; j < args_.size(); j++) {
|
||||||
|
if (args_[i].group() == args_[j].group() &&
|
||||||
|
singletons.contains (args_[i].logVarSet()) &&
|
||||||
|
singletons.contains (args_[j].logVarSet())) {
|
||||||
|
simplifyParfactor (i, j);
|
||||||
|
i --;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::canMultiply (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
|
||||||
|
LogVarSet Xs_1 (res.first);
|
||||||
|
LogVarSet Xs_2 (res.second);
|
||||||
|
LogVarSet Y_1 = g1->logVarSet() - Xs_1;
|
||||||
|
LogVarSet Y_2 = g2->logVarSet() - Xs_2;
|
||||||
|
Y_1 -= g1->countedLogVars();
|
||||||
|
Y_2 -= g2->countedLogVars();
|
||||||
|
return g1->constr()->isCountNormalized (Y_1) &&
|
||||||
|
g2->constr()->isCountNormalized (Y_2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::simplifyParfactor (size_t fIdx1, size_t fIdx2)
|
||||||
|
{
|
||||||
|
Params backup = params_;
|
||||||
|
params_.clear();
|
||||||
|
Indexer indexer (ranges_);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
if (indexer[fIdx1] == indexer[fIdx2]) {
|
||||||
|
params_.push_back (backup[indexer]);
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < args_[fIdx2].logVars().size(); i++) {
|
||||||
|
if (nrFormulas (args_[fIdx2].logVars()[i]) == 1) {
|
||||||
|
constr_->remove ({ args_[fIdx2].logVars()[i] });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args_.erase (args_.begin() + fIdx2);
|
||||||
|
ranges_.erase (ranges_.begin() + fIdx2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::pair<LogVars, LogVars>
|
||||||
|
Parfactor::getAlignLogVars (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
g1->simplifyGrounds();
|
||||||
|
g2->simplifyGrounds();
|
||||||
|
LogVars Xs_1, Xs_2;
|
||||||
|
TinySet<size_t> matchedI;
|
||||||
|
TinySet<size_t> matchedJ;
|
||||||
|
ProbFormulas& formulas1 = g1->arguments();
|
||||||
|
ProbFormulas& formulas2 = g2->arguments();
|
||||||
|
for (size_t i = 0; i < formulas1.size(); i++) {
|
||||||
|
for (size_t j = 0; j < formulas2.size(); j++) {
|
||||||
|
if (formulas1[i].group() == formulas2[j].group() &&
|
||||||
|
g1->range (i) == g2->range (j) &&
|
||||||
|
matchedI.contains (i) == false &&
|
||||||
|
matchedJ.contains (j) == false) {
|
||||||
|
Util::addToVector (Xs_1, formulas1[i].logVars());
|
||||||
|
Util::addToVector (Xs_2, formulas2[j].logVars());
|
||||||
|
matchedI.insert (i);
|
||||||
|
matchedJ.insert (j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return make_pair (Xs_1, Xs_2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
alignLogicalVars (g1, g2);
|
||||||
|
LogVarSet comm = g1->logVarSet() & g2->logVarSet();
|
||||||
|
LogVarSet Y_1 = g1->logVarSet() - comm;
|
||||||
|
LogVarSet Y_2 = g2->logVarSet() - comm;
|
||||||
|
Y_1 -= g1->countedLogVars();
|
||||||
|
Y_2 -= g2->countedLogVars();
|
||||||
|
assert (g1->constr()->isCountNormalized (Y_1));
|
||||||
|
assert (g2->constr()->isCountNormalized (Y_2));
|
||||||
|
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
||||||
|
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
||||||
|
LogAware::pow (g1->params(), 1.0 / condCount2);
|
||||||
|
LogAware::pow (g2->params(), 1.0 / condCount1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
|
||||||
|
const LogVars& alignLvs1 = res.first;
|
||||||
|
const LogVars& alignLvs2 = res.second;
|
||||||
|
// cout << "ALIGNING :::::::::::::::::" << endl;
|
||||||
|
// g1->print();
|
||||||
|
// cout << "AND" << endl;
|
||||||
|
// g2->print();
|
||||||
|
// cout << "-> align lvs1 = " << alignLvs1 << endl;
|
||||||
|
// cout << "-> align lvs2 = " << alignLvs2 << endl;
|
||||||
|
LogVar freeLogVar (0);
|
||||||
|
Substitution theta1, theta2;
|
||||||
|
for (size_t i = 0; i < alignLvs1.size(); i++) {
|
||||||
|
bool b1 = theta1.containsReplacementFor (alignLvs1[i]);
|
||||||
|
bool b2 = theta2.containsReplacementFor (alignLvs2[i]);
|
||||||
|
if (b1 == false && b2 == false) {
|
||||||
|
theta1.add (alignLvs1[i], freeLogVar);
|
||||||
|
theta2.add (alignLvs2[i], freeLogVar);
|
||||||
|
++ freeLogVar;
|
||||||
|
} else if (b1 == false && b2) {
|
||||||
|
theta1.add (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
|
||||||
|
} else if (b1 && b2 == false) {
|
||||||
|
theta2.add (alignLvs2[i], theta1.newNameFor (alignLvs1[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const LogVarSet& allLvs1 = g1->logVarSet();
|
||||||
|
for (size_t i = 0; i < allLvs1.size(); i++) {
|
||||||
|
if (theta1.containsReplacementFor (allLvs1[i]) == false) {
|
||||||
|
theta1.add (allLvs1[i], freeLogVar);
|
||||||
|
++ freeLogVar;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const LogVarSet& allLvs2 = g2->logVarSet();
|
||||||
|
for (size_t i = 0; i < allLvs2.size(); i++) {
|
||||||
|
if (theta2.containsReplacementFor (allLvs2[i]) == false) {
|
||||||
|
theta2.add (allLvs2[i], freeLogVar);
|
||||||
|
++ freeLogVar;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle this type of situation:
|
||||||
|
// g1 = p(X), q(X) ; X in {(p1),(p2)}
|
||||||
|
// g2 = p(X), q(Y) ; (X,Y) in {(p1,p2),(p2,p1)}
|
||||||
|
LogVars discardedLvs1 = theta1.getDiscardedLogVars();
|
||||||
|
for (size_t i = 0; i < discardedLvs1.size(); i++) {
|
||||||
|
if (g1->constr()->isSingleton (discardedLvs1[i]) &&
|
||||||
|
g1->nrFormulas (discardedLvs1[i]) == 1) {
|
||||||
|
g1->constr()->remove (discardedLvs1[i]);
|
||||||
|
} else {
|
||||||
|
LogVar X_new = ++ g1->constr()->logVarSet().back();
|
||||||
|
theta1.rename (discardedLvs1[i], X_new);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LogVars discardedLvs2 = theta2.getDiscardedLogVars();
|
||||||
|
for (size_t i = 0; i < discardedLvs2.size(); i++) {
|
||||||
|
if (g2->constr()->isSingleton (discardedLvs2[i]) &&
|
||||||
|
g2->nrFormulas (discardedLvs2[i]) == 1) {
|
||||||
|
g2->constr()->remove (discardedLvs2[i]);
|
||||||
|
} else {
|
||||||
|
LogVar X_new = ++ g2->constr()->logVarSet().back();
|
||||||
|
theta2.rename (discardedLvs2[i], X_new);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cout << "theta1: " << theta1 << endl;
|
||||||
|
// cout << "theta2: " << theta2 << endl;
|
||||||
|
g1->applySubstitution (theta1);
|
||||||
|
g2->applySubstitution (theta2);
|
||||||
|
}
|
||||||
|
|
125
packages/CLPBN/horus2/Parfactor.h
Normal file
125
packages/CLPBN/horus2/Parfactor.h
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
#ifndef HORUS_PARFACTOR_H
|
||||||
|
#define HORUS_PARFACTOR_H
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "ProbFormula.h"
|
||||||
|
#include "ConstraintTree.h"
|
||||||
|
#include "LiftedUtils.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
#include "Factor.h"
|
||||||
|
|
||||||
|
class Parfactor : public TFactor<ProbFormula>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Parfactor (
|
||||||
|
const ProbFormulas&,
|
||||||
|
const Params&,
|
||||||
|
const Tuples&,
|
||||||
|
unsigned distId);
|
||||||
|
|
||||||
|
Parfactor (const Parfactor*, const Tuple&);
|
||||||
|
|
||||||
|
Parfactor (const Parfactor*, ConstraintTree*);
|
||||||
|
|
||||||
|
Parfactor (const Parfactor&);
|
||||||
|
|
||||||
|
~Parfactor (void);
|
||||||
|
|
||||||
|
ConstraintTree* constr (void) { return constr_; }
|
||||||
|
|
||||||
|
const ConstraintTree* constr (void) const { return constr_; }
|
||||||
|
|
||||||
|
const LogVars& logVars (void) const { return constr_->logVars(); }
|
||||||
|
|
||||||
|
const LogVarSet& logVarSet (void) const { return constr_->logVarSet(); }
|
||||||
|
|
||||||
|
LogVarSet countedLogVars (void) const;
|
||||||
|
|
||||||
|
LogVarSet uncountedLogVars (void) const;
|
||||||
|
|
||||||
|
LogVarSet elimLogVars (void) const;
|
||||||
|
|
||||||
|
LogVarSet exclusiveLogVars (size_t fIdx) const;
|
||||||
|
|
||||||
|
void sumOutIndex (size_t fIdx);
|
||||||
|
|
||||||
|
void multiply (Parfactor&);
|
||||||
|
|
||||||
|
bool canCountConvert (LogVar X);
|
||||||
|
|
||||||
|
void countConvert (LogVar);
|
||||||
|
|
||||||
|
void expand (LogVar, LogVar, LogVar);
|
||||||
|
|
||||||
|
void fullExpand (LogVar);
|
||||||
|
|
||||||
|
void reorderAccordingGrounds (const Grounds&);
|
||||||
|
|
||||||
|
void absorveEvidence (const ProbFormula&, unsigned);
|
||||||
|
|
||||||
|
void setNewGroups (void);
|
||||||
|
|
||||||
|
void applySubstitution (const Substitution&);
|
||||||
|
|
||||||
|
size_t indexOfGround (const Ground&) const;
|
||||||
|
|
||||||
|
PrvGroup findGroup (const Ground&) const;
|
||||||
|
|
||||||
|
bool containsGround (const Ground&) const;
|
||||||
|
|
||||||
|
bool containsGrounds (const Grounds&) const;
|
||||||
|
|
||||||
|
bool containsGroup (PrvGroup) const;
|
||||||
|
|
||||||
|
bool containsGroups (vector<PrvGroup>) const;
|
||||||
|
|
||||||
|
unsigned nrFormulas (LogVar) const;
|
||||||
|
|
||||||
|
int indexOfLogVar (LogVar) const;
|
||||||
|
|
||||||
|
int indexOfGroup (PrvGroup) const;
|
||||||
|
|
||||||
|
unsigned nrFormulasWithGroup (PrvGroup) const;
|
||||||
|
|
||||||
|
vector<PrvGroup> getAllGroups (void) const;
|
||||||
|
|
||||||
|
void print (bool = false) const;
|
||||||
|
|
||||||
|
void printParameters (void) const;
|
||||||
|
|
||||||
|
void printProjections (void) const;
|
||||||
|
|
||||||
|
string getLabel (void) const;
|
||||||
|
|
||||||
|
void simplifyGrounds (void);
|
||||||
|
|
||||||
|
static bool canMultiply (Parfactor*, Parfactor*);
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
void simplifyCountingFormulas (size_t fIdx);
|
||||||
|
|
||||||
|
void simplifyParfactor (size_t fIdx1, size_t fIdx2);
|
||||||
|
|
||||||
|
static std::pair<LogVars, LogVars> getAlignLogVars (
|
||||||
|
Parfactor* g1, Parfactor* g2);
|
||||||
|
|
||||||
|
void expandPotential (size_t fIdx, unsigned newRange,
|
||||||
|
const vector<unsigned>& sumIndexes);
|
||||||
|
|
||||||
|
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||||
|
|
||||||
|
static void alignLogicalVars (Parfactor*, Parfactor*);
|
||||||
|
|
||||||
|
ConstraintTree* constr_;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
typedef vector<Parfactor*> Parfactors;
|
||||||
|
|
||||||
|
#endif // HORUS_PARFACTOR_H
|
||||||
|
|
638
packages/CLPBN/horus2/ParfactorList.cpp
Normal file
638
packages/CLPBN/horus2/ParfactorList.cpp
Normal file
@ -0,0 +1,638 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
|
||||||
|
ParfactorList::ParfactorList (const ParfactorList& pfList)
|
||||||
|
{
|
||||||
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
addShattered (new Parfactor (**it));
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParfactorList::ParfactorList (const Parfactors& pfs)
|
||||||
|
{
|
||||||
|
add (pfs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParfactorList::~ParfactorList (void)
|
||||||
|
{
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
delete *it;
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ParfactorList::add (Parfactor* pf)
|
||||||
|
{
|
||||||
|
pf->setNewGroups();
|
||||||
|
addToShatteredList (pf);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ParfactorList::add (const Parfactors& pfs)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
|
pfs[i]->setNewGroups();
|
||||||
|
addToShatteredList (pfs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ParfactorList::addShattered (Parfactor* pf)
|
||||||
|
{
|
||||||
|
assert (isAllShattered());
|
||||||
|
pfList_.push_back (pf);
|
||||||
|
assert (isAllShattered());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator
|
||||||
|
ParfactorList::insertShattered (
|
||||||
|
list<Parfactor*>::iterator it,
|
||||||
|
Parfactor* pf)
|
||||||
|
{
|
||||||
|
return pfList_.insert (it, pf);
|
||||||
|
assert (isAllShattered());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator
|
||||||
|
ParfactorList::remove (list<Parfactor*>::iterator it)
|
||||||
|
{
|
||||||
|
return pfList_.erase (it);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator
|
||||||
|
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||||
|
{
|
||||||
|
delete *it;
|
||||||
|
return pfList_.erase (it);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::isAllShattered (void) const
|
||||||
|
{
|
||||||
|
if (pfList_.size() <= 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
|
||||||
|
for (size_t i = 0; i < pfs.size(); i++) {
|
||||||
|
assert (isShattered (pfs[i]));
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < pfs.size() - 1; i++) {
|
||||||
|
for (size_t j = i + 1; j < pfs.size(); j++) {
|
||||||
|
if (isShattered (pfs[i], pfs[j]) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ParfactorList::print (void) const
|
||||||
|
{
|
||||||
|
Parfactors pfVec (pfList_.begin(), pfList_.end());
|
||||||
|
std::sort (pfVec.begin(), pfVec.end(), sortByParams());
|
||||||
|
for (size_t i = 0; i < pfVec.size(); i++) {
|
||||||
|
pfVec[i]->print();
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParfactorList&
|
||||||
|
ParfactorList::operator= (const ParfactorList& pfList)
|
||||||
|
{
|
||||||
|
if (this != &pfList) {
|
||||||
|
ParfactorList::const_iterator it0 = pfList_.begin();
|
||||||
|
while (it0 != pfList_.end()) {
|
||||||
|
delete *it0;
|
||||||
|
++ it0;
|
||||||
|
}
|
||||||
|
pfList_.clear();
|
||||||
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
addShattered (new Parfactor (**it));
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::isShattered (const Parfactor* g) const
|
||||||
|
{
|
||||||
|
const ProbFormulas& formulas = g->arguments();
|
||||||
|
if (formulas.size() < 2) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
ConstraintTree ct (*g->constr());
|
||||||
|
for (size_t i = 0; i < formulas.size() - 1; i++) {
|
||||||
|
for (size_t j = i + 1; j < formulas.size(); j++) {
|
||||||
|
if (formulas[i].group() == formulas[j].group()) {
|
||||||
|
if (identical (
|
||||||
|
formulas[i], *(g->constr()),
|
||||||
|
formulas[j], *(g->constr())) == false) {
|
||||||
|
g->print();
|
||||||
|
cout << "-> not identical on positions " ;
|
||||||
|
cout << i << " and " << j << endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (disjoint (
|
||||||
|
formulas[i], *(g->constr()),
|
||||||
|
formulas[j], *(g->constr())) == false) {
|
||||||
|
g->print();
|
||||||
|
cout << "-> not disjoint on positions " ;
|
||||||
|
cout << i << " and " << j << endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::isShattered (
|
||||||
|
const Parfactor* g1,
|
||||||
|
const Parfactor* g2) const
|
||||||
|
{
|
||||||
|
assert (g1 != g2);
|
||||||
|
const ProbFormulas& fms1 = g1->arguments();
|
||||||
|
const ProbFormulas& fms2 = g2->arguments();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < fms1.size(); i++) {
|
||||||
|
for (size_t j = 0; j < fms2.size(); j++) {
|
||||||
|
if (fms1[i].group() == fms2[j].group()) {
|
||||||
|
if (identical (
|
||||||
|
fms1[i], *(g1->constr()),
|
||||||
|
fms2[j], *(g2->constr())) == false) {
|
||||||
|
g1->print();
|
||||||
|
cout << "^" << endl;
|
||||||
|
g2->print();
|
||||||
|
cout << "-> not identical on group " << fms1[i].group() << endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (disjoint (
|
||||||
|
fms1[i], *(g1->constr()),
|
||||||
|
fms2[j], *(g2->constr())) == false) {
|
||||||
|
g1->print();
|
||||||
|
cout << "^" << endl;
|
||||||
|
g2->print();
|
||||||
|
cout << "-> not disjoint on groups " << fms1[i].group();
|
||||||
|
cout << " and " << fms2[j].group() << endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ParfactorList::addToShatteredList (Parfactor* g)
|
||||||
|
{
|
||||||
|
queue<Parfactor*> residuals;
|
||||||
|
residuals.push (g);
|
||||||
|
while (residuals.empty() == false) {
|
||||||
|
Parfactor* pf = residuals.front();
|
||||||
|
bool pfSplitted = false;
|
||||||
|
list<Parfactor*>::iterator pfIter;
|
||||||
|
pfIter = pfList_.begin();
|
||||||
|
while (pfIter != pfList_.end()) {
|
||||||
|
std::pair<Parfactors, Parfactors> shattRes;
|
||||||
|
shattRes = shatter (*pfIter, pf);
|
||||||
|
if (shattRes.first.empty() == false) {
|
||||||
|
pfIter = removeAndDelete (pfIter);
|
||||||
|
Util::addToQueue (residuals, shattRes.first);
|
||||||
|
} else {
|
||||||
|
++ pfIter;
|
||||||
|
}
|
||||||
|
if (shattRes.second.empty() == false) {
|
||||||
|
delete pf;
|
||||||
|
Util::addToQueue (residuals, shattRes.second);
|
||||||
|
pfSplitted = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
residuals.pop();
|
||||||
|
if (pfSplitted == false) {
|
||||||
|
Parfactors res = shatterAgainstMySelf (pf);
|
||||||
|
if (res.empty()) {
|
||||||
|
addShattered (pf);
|
||||||
|
} else {
|
||||||
|
Util::addToQueue (residuals, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (isAllShattered());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactors
|
||||||
|
ParfactorList::shatterAgainstMySelf (Parfactor* g)
|
||||||
|
{
|
||||||
|
Parfactors pfs;
|
||||||
|
queue<Parfactor*> residuals;
|
||||||
|
residuals.push (g);
|
||||||
|
bool shattered = true;
|
||||||
|
while (residuals.empty() == false) {
|
||||||
|
Parfactor* pf = residuals.front();
|
||||||
|
Parfactors res = shatterAgainstMySelf2 (pf);
|
||||||
|
if (res.empty()) {
|
||||||
|
assert (isShattered (pf));
|
||||||
|
if (shattered) {
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
pfs.push_back (pf);
|
||||||
|
} else {
|
||||||
|
shattered = false;
|
||||||
|
for (size_t i = 0; i < res.size(); i++) {
|
||||||
|
assert (res[i]->constr()->empty() == false);
|
||||||
|
residuals.push (res[i]);
|
||||||
|
}
|
||||||
|
delete pf;
|
||||||
|
}
|
||||||
|
residuals.pop();
|
||||||
|
}
|
||||||
|
return pfs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactors
|
||||||
|
ParfactorList::shatterAgainstMySelf2 (Parfactor* g)
|
||||||
|
{
|
||||||
|
// slip a parfactor with overlapping formulas:
|
||||||
|
// e.g. {s(X),s(Y)}, with (X,Y) in {(p1,p2),(p1,p3),(p4,p1)}
|
||||||
|
const ProbFormulas& formulas = g->arguments();
|
||||||
|
for (size_t i = 0; i < formulas.size() - 1; i++) {
|
||||||
|
for (size_t j = i + 1; j < formulas.size(); j++) {
|
||||||
|
if (formulas[i].sameSkeletonAs (formulas[j])) {
|
||||||
|
Parfactors res = shatterAgainstMySelf (g, i, j);
|
||||||
|
if (res.empty() == false) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Parfactors();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactors
|
||||||
|
ParfactorList::shatterAgainstMySelf (
|
||||||
|
Parfactor* g,
|
||||||
|
size_t fIdx1,
|
||||||
|
size_t fIdx2)
|
||||||
|
{
|
||||||
|
/*
|
||||||
|
Util::printDashedLine();
|
||||||
|
cout << "-> SHATTERING" << endl;
|
||||||
|
g->print();
|
||||||
|
cout << "-> ON: " << g->argument (fIdx1) << "|" ;
|
||||||
|
cout << g->constr()->tupleSet (g->argument (fIdx1).logVars()) << endl;
|
||||||
|
cout << "-> ON: " << g->argument (fIdx2) << "|" ;
|
||||||
|
cout << g->constr()->tupleSet (g->argument (fIdx2).logVars()) << endl;
|
||||||
|
Util::printDashedLine();
|
||||||
|
*/
|
||||||
|
ProbFormula& f1 = g->argument (fIdx1);
|
||||||
|
ProbFormula& f2 = g->argument (fIdx2);
|
||||||
|
if (f1.isAtom()) {
|
||||||
|
cerr << "Error: a ground occurs twice in the same parfactor." << endl;
|
||||||
|
cerr << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
assert (g->constr()->empty() == false);
|
||||||
|
ConstraintTree ctCopy (*g->constr());
|
||||||
|
if (f1.group() == f2.group()) {
|
||||||
|
assert (identical (f1, *(g->constr()), f2, ctCopy));
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
|
||||||
|
g->constr()->moveToTop (f1.logVars());
|
||||||
|
ctCopy.moveToTop (f2.logVars());
|
||||||
|
|
||||||
|
std::pair<ConstraintTree*,ConstraintTree*> split1 =
|
||||||
|
g->constr()->split (f1.logVars(), &ctCopy, f2.logVars());
|
||||||
|
ConstraintTree* commCt1 = split1.first;
|
||||||
|
ConstraintTree* exclCt1 = split1.second;
|
||||||
|
|
||||||
|
if (commCt1->empty()) {
|
||||||
|
// disjoint
|
||||||
|
delete commCt1;
|
||||||
|
delete exclCt1;
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
|
||||||
|
PrvGroup newGroup = ProbFormula::getNewGroup();
|
||||||
|
Parfactors res1 = shatter (g, fIdx1, commCt1, exclCt1, newGroup);
|
||||||
|
if (res1.empty()) {
|
||||||
|
res1.push_back (g);
|
||||||
|
}
|
||||||
|
|
||||||
|
Parfactors res;
|
||||||
|
ctCopy.moveToTop (f1.logVars());
|
||||||
|
for (size_t i = 0; i < res1.size(); i++) {
|
||||||
|
res1[i]->constr()->moveToTop (f2.logVars());
|
||||||
|
std::pair<ConstraintTree*, ConstraintTree*> split2;
|
||||||
|
split2 = res1[i]->constr()->split (f2.logVars(), &ctCopy, f1.logVars());
|
||||||
|
ConstraintTree* commCt2 = split2.first;
|
||||||
|
ConstraintTree* exclCt2 = split2.second;
|
||||||
|
if (commCt2->empty()) {
|
||||||
|
if (res1[i] != g) {
|
||||||
|
res.push_back (res1[i]);
|
||||||
|
}
|
||||||
|
delete commCt2;
|
||||||
|
delete exclCt2;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
newGroup = ProbFormula::getNewGroup();
|
||||||
|
Parfactors res2 = shatter (res1[i], fIdx2, commCt2, exclCt2, newGroup);
|
||||||
|
if (res2.empty()) {
|
||||||
|
if (res1[i] != g) {
|
||||||
|
res.push_back (res1[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Util::addToVector (res, res2);
|
||||||
|
for (size_t j = 0; j < res2.size(); j++) {
|
||||||
|
}
|
||||||
|
if (res1[i] != g) {
|
||||||
|
delete res1[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (res.empty()) {
|
||||||
|
g->argument (fIdx2).setGroup (g->argument (fIdx1).group());
|
||||||
|
updateGroups (f2.group(), f1.group());
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::pair<Parfactors, Parfactors>
|
||||||
|
ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
ProbFormulas& formulas1 = g1->arguments();
|
||||||
|
ProbFormulas& formulas2 = g2->arguments();
|
||||||
|
assert (g1 != 0 && g2 != 0 && g1 != g2);
|
||||||
|
for (size_t i = 0; i < formulas1.size(); i++) {
|
||||||
|
for (size_t j = 0; j < formulas2.size(); j++) {
|
||||||
|
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
|
||||||
|
std::pair<Parfactors, Parfactors> res;
|
||||||
|
res = shatter (i, g1, j, g2);
|
||||||
|
if (res.first.empty() == false ||
|
||||||
|
res.second.empty() == false) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return make_pair (Parfactors(), Parfactors());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::pair<Parfactors, Parfactors>
|
||||||
|
ParfactorList::shatter (
|
||||||
|
size_t fIdx1, Parfactor* g1,
|
||||||
|
size_t fIdx2, Parfactor* g2)
|
||||||
|
{
|
||||||
|
ProbFormula& f1 = g1->argument (fIdx1);
|
||||||
|
ProbFormula& f2 = g2->argument (fIdx2);
|
||||||
|
/*
|
||||||
|
Util::printDashedLine();
|
||||||
|
cout << "-> SHATTERING" << endl;
|
||||||
|
g1->print();
|
||||||
|
cout << "-> WITH" << endl;
|
||||||
|
g2->print();
|
||||||
|
cout << "-> ON: " << f1 << "|" ;
|
||||||
|
cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
||||||
|
cout << "-> ON: " << f2 << "|" ;
|
||||||
|
cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
||||||
|
Util::printDashedLine();
|
||||||
|
*/
|
||||||
|
if (f1.isAtom()) {
|
||||||
|
f2.setGroup (f1.group());
|
||||||
|
updateGroups (f2.group(), f1.group());
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
assert (g1->constr()->empty() == false);
|
||||||
|
assert (g2->constr()->empty() == false);
|
||||||
|
if (f1.group() == f2.group()) {
|
||||||
|
assert (identical (f1, *(g1->constr()), f2, *(g2->constr())));
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
|
||||||
|
g1->constr()->moveToTop (f1.logVars());
|
||||||
|
g2->constr()->moveToTop (f2.logVars());
|
||||||
|
|
||||||
|
std::pair<ConstraintTree*,ConstraintTree*> split1 =
|
||||||
|
g1->constr()->split (f1.logVars(), g2->constr(), f2.logVars());
|
||||||
|
ConstraintTree* commCt1 = split1.first;
|
||||||
|
ConstraintTree* exclCt1 = split1.second;
|
||||||
|
|
||||||
|
if (commCt1->empty()) {
|
||||||
|
// disjoint
|
||||||
|
delete commCt1;
|
||||||
|
delete exclCt1;
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<ConstraintTree*,ConstraintTree*> split2 =
|
||||||
|
g2->constr()->split (f2.logVars(), g1->constr(), f1.logVars());
|
||||||
|
ConstraintTree* commCt2 = split2.first;
|
||||||
|
ConstraintTree* exclCt2 = split2.second;
|
||||||
|
|
||||||
|
assert (commCt1->tupleSet (f1.logVars()) ==
|
||||||
|
commCt2->tupleSet (f2.logVars()));
|
||||||
|
|
||||||
|
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||||
|
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||||
|
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||||
|
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||||
|
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||||
|
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||||
|
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
|
||||||
|
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
|
||||||
|
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
||||||
|
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
|
||||||
|
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
|
||||||
|
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
||||||
|
|
||||||
|
if (exclCt1->empty() && exclCt2->empty()) {
|
||||||
|
// identical
|
||||||
|
f2.setGroup (f1.group());
|
||||||
|
updateGroups (f2.group(), f1.group());
|
||||||
|
delete commCt1;
|
||||||
|
delete exclCt1;
|
||||||
|
delete commCt2;
|
||||||
|
delete exclCt2;
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
|
||||||
|
PrvGroup group;
|
||||||
|
if (exclCt1->empty()) {
|
||||||
|
group = f1.group();
|
||||||
|
} else if (exclCt2->empty()) {
|
||||||
|
group = f2.group();
|
||||||
|
} else {
|
||||||
|
group = ProbFormula::getNewGroup();
|
||||||
|
}
|
||||||
|
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
|
||||||
|
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
|
||||||
|
return make_pair (res1, res2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactors
|
||||||
|
ParfactorList::shatter (
|
||||||
|
Parfactor* g,
|
||||||
|
size_t fIdx,
|
||||||
|
ConstraintTree* commCt,
|
||||||
|
ConstraintTree* exclCt,
|
||||||
|
PrvGroup commGroup)
|
||||||
|
{
|
||||||
|
ProbFormula& f = g->argument (fIdx);
|
||||||
|
if (exclCt->empty()) {
|
||||||
|
delete commCt;
|
||||||
|
delete exclCt;
|
||||||
|
f.setGroup (commGroup);
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
|
||||||
|
Parfactors result;
|
||||||
|
if (f.isCounting()) {
|
||||||
|
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
|
||||||
|
LogVar X_new2 = g->constr()->logVarSet().back() + 2;
|
||||||
|
ConstraintTrees cts = g->constr()->jointCountNormalize (
|
||||||
|
commCt, exclCt, f.countedLogVar(), X_new1, X_new2);
|
||||||
|
for (size_t i = 0; i < cts.size(); i++) {
|
||||||
|
Parfactor* newPf = new Parfactor (g, cts[i]);
|
||||||
|
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
|
||||||
|
newPf->expand (f.countedLogVar(), X_new1, X_new2);
|
||||||
|
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
||||||
|
cts[i]->getConditionalCount (X_new1) +
|
||||||
|
cts[i]->getConditionalCount (X_new2));
|
||||||
|
} else {
|
||||||
|
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
||||||
|
cts[i]->getConditionalCount (f.countedLogVar()));
|
||||||
|
}
|
||||||
|
newPf->setNewGroups();
|
||||||
|
result.push_back (newPf);
|
||||||
|
}
|
||||||
|
delete commCt;
|
||||||
|
delete exclCt;
|
||||||
|
} else {
|
||||||
|
Parfactor* newPf = new Parfactor (g, commCt);
|
||||||
|
newPf->setNewGroups();
|
||||||
|
newPf->argument (fIdx).setGroup (commGroup);
|
||||||
|
result.push_back (newPf);
|
||||||
|
newPf = new Parfactor (g, exclCt);
|
||||||
|
newPf->setNewGroups();
|
||||||
|
result.push_back (newPf);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ParfactorList::updateGroups (PrvGroup oldGroup, PrvGroup newGroup)
|
||||||
|
{
|
||||||
|
for (ParfactorList::iterator it = pfList_.begin();
|
||||||
|
it != pfList_.end(); ++it) {
|
||||||
|
ProbFormulas& formulas = (*it)->arguments();
|
||||||
|
for (size_t i = 0; i < formulas.size(); i++) {
|
||||||
|
if (formulas[i].group() == oldGroup) {
|
||||||
|
formulas[i].setGroup (newGroup);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::proper (
|
||||||
|
const ProbFormula& f1, ConstraintTree ct1,
|
||||||
|
const ProbFormula& f2, ConstraintTree ct2) const
|
||||||
|
{
|
||||||
|
return disjoint (f1, ct1, f2, ct2)
|
||||||
|
|| identical (f1, ct1, f2, ct2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::identical (
|
||||||
|
const ProbFormula& f1, ConstraintTree ct1,
|
||||||
|
const ProbFormula& f2, ConstraintTree ct2) const
|
||||||
|
{
|
||||||
|
if (f1.sameSkeletonAs (f2) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (f1.isAtom()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
TupleSet ts1 = ct1.tupleSet (f1.logVars());
|
||||||
|
TupleSet ts2 = ct2.tupleSet (f2.logVars());
|
||||||
|
return ts1 == ts2;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::disjoint (
|
||||||
|
const ProbFormula& f1, ConstraintTree ct1,
|
||||||
|
const ProbFormula& f2, ConstraintTree ct2) const
|
||||||
|
{
|
||||||
|
if (f1.sameSkeletonAs (f2) == false) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (f1.isAtom()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
TupleSet ts1 = ct1.tupleSet (f1.logVars());
|
||||||
|
TupleSet ts2 = ct2.tupleSet (f2.logVars());
|
||||||
|
return (ts1 & ts2).empty();
|
||||||
|
}
|
||||||
|
|
121
packages/CLPBN/horus2/ParfactorList.h
Normal file
121
packages/CLPBN/horus2/ParfactorList.h
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
#ifndef HORUS_PARFACTORLIST_H
|
||||||
|
#define HORUS_PARFACTORLIST_H
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include "Parfactor.h"
|
||||||
|
#include "ProbFormula.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
class ParfactorList
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ParfactorList (void) { }
|
||||||
|
|
||||||
|
ParfactorList (const ParfactorList&);
|
||||||
|
|
||||||
|
ParfactorList (const Parfactors&);
|
||||||
|
|
||||||
|
~ParfactorList (void);
|
||||||
|
|
||||||
|
const list<Parfactor*>& parfactors (void) const { return pfList_; }
|
||||||
|
|
||||||
|
void clear (void) { pfList_.clear(); }
|
||||||
|
|
||||||
|
size_t size (void) const { return pfList_.size(); }
|
||||||
|
|
||||||
|
typedef std::list<Parfactor*>::iterator iterator;
|
||||||
|
|
||||||
|
iterator begin (void) { return pfList_.begin(); }
|
||||||
|
|
||||||
|
iterator end (void) { return pfList_.end(); }
|
||||||
|
|
||||||
|
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
||||||
|
|
||||||
|
const_iterator begin (void) const { return pfList_.begin(); }
|
||||||
|
|
||||||
|
const_iterator end (void) const { return pfList_.end(); }
|
||||||
|
|
||||||
|
void add (Parfactor* pf);
|
||||||
|
|
||||||
|
void add (const Parfactors& pfs);
|
||||||
|
|
||||||
|
void addShattered (Parfactor* pf);
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator insertShattered (
|
||||||
|
list<Parfactor*>::iterator, Parfactor*);
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
|
||||||
|
|
||||||
|
bool isAllShattered (void) const;
|
||||||
|
|
||||||
|
void print (void) const;
|
||||||
|
|
||||||
|
ParfactorList& operator= (const ParfactorList& pfList);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool isShattered (const Parfactor*) const;
|
||||||
|
|
||||||
|
bool isShattered (const Parfactor*, const Parfactor*) const;
|
||||||
|
|
||||||
|
void addToShatteredList (Parfactor*);
|
||||||
|
|
||||||
|
Parfactors shatterAgainstMySelf (Parfactor* g);
|
||||||
|
|
||||||
|
Parfactors shatterAgainstMySelf2 (Parfactor* g);
|
||||||
|
|
||||||
|
Parfactors shatterAgainstMySelf (
|
||||||
|
Parfactor* g, size_t fIdx1, size_t fIdx2);
|
||||||
|
|
||||||
|
std::pair<Parfactors, Parfactors> shatter (
|
||||||
|
Parfactor*, Parfactor*);
|
||||||
|
|
||||||
|
std::pair<Parfactors, Parfactors> shatter (
|
||||||
|
size_t, Parfactor*, size_t, Parfactor*);
|
||||||
|
|
||||||
|
Parfactors shatter (
|
||||||
|
Parfactor*,
|
||||||
|
size_t,
|
||||||
|
ConstraintTree*,
|
||||||
|
ConstraintTree*,
|
||||||
|
PrvGroup);
|
||||||
|
|
||||||
|
void updateGroups (PrvGroup group1, PrvGroup group2);
|
||||||
|
|
||||||
|
bool proper (
|
||||||
|
const ProbFormula&, ConstraintTree,
|
||||||
|
const ProbFormula&, ConstraintTree) const;
|
||||||
|
|
||||||
|
bool identical (
|
||||||
|
const ProbFormula&, ConstraintTree,
|
||||||
|
const ProbFormula&, ConstraintTree) const;
|
||||||
|
|
||||||
|
bool disjoint (
|
||||||
|
const ProbFormula&, ConstraintTree,
|
||||||
|
const ProbFormula&, ConstraintTree) const;
|
||||||
|
|
||||||
|
struct sortByParams
|
||||||
|
{
|
||||||
|
inline bool operator() (const Parfactor* pf1, const Parfactor* pf2)
|
||||||
|
{
|
||||||
|
if (pf1->params().size() < pf2->params().size()) {
|
||||||
|
return true;
|
||||||
|
} else if (pf1->params().size() == pf2->params().size() &&
|
||||||
|
pf1->params() < pf2->params()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
list<Parfactor*> pfList_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_PARFACTORLIST_H
|
||||||
|
|
140
packages/CLPBN/horus2/ProbFormula.cpp
Normal file
140
packages/CLPBN/horus2/ProbFormula.cpp
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
#include "ProbFormula.h"
|
||||||
|
|
||||||
|
|
||||||
|
PrvGroup ProbFormula::freeGroup_ = 0;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ProbFormula::sameSkeletonAs (const ProbFormula& f) const
|
||||||
|
{
|
||||||
|
return functor_ == f.functor() && logVars_.size() == f.arity();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ProbFormula::contains (LogVar lv) const
|
||||||
|
{
|
||||||
|
return Util::contains (logVars_, lv);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ProbFormula::contains (LogVarSet s) const
|
||||||
|
{
|
||||||
|
return LogVarSet (logVars_).contains (s);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
size_t
|
||||||
|
ProbFormula::indexOf (LogVar X) const
|
||||||
|
{
|
||||||
|
return Util::indexOf (logVars_, X);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ProbFormula::isAtom (void) const
|
||||||
|
{
|
||||||
|
return logVars_.size() == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ProbFormula::isCounting (void) const
|
||||||
|
{
|
||||||
|
return countedLogVar_.valid();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LogVar
|
||||||
|
ProbFormula::countedLogVar (void) const
|
||||||
|
{
|
||||||
|
assert (isCounting());
|
||||||
|
return countedLogVar_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ProbFormula::setCountedLogVar (LogVar lv)
|
||||||
|
{
|
||||||
|
countedLogVar_ = lv;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ProbFormula::clearCountedLogVar (void)
|
||||||
|
{
|
||||||
|
countedLogVar_ = LogVar();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ProbFormula::rename (LogVar oldName, LogVar newName)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < logVars_.size(); i++) {
|
||||||
|
if (logVars_[i] == oldName) {
|
||||||
|
logVars_[i] = newName;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (isCounting() && countedLogVar_ == oldName) {
|
||||||
|
countedLogVar_ = newName;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||||
|
{
|
||||||
|
return f1.group_ == f2.group_ &&
|
||||||
|
f1.logVars_ == f2.logVars_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||||
|
{
|
||||||
|
os << f.functor_;
|
||||||
|
if (f.isAtom() == false) {
|
||||||
|
os << "(" ;
|
||||||
|
for (size_t i = 0; i < f.logVars_.size(); i++) {
|
||||||
|
if (i != 0) os << ",";
|
||||||
|
if (f.isCounting() && f.logVars_[i] == f.countedLogVar_) {
|
||||||
|
os << "#" ;
|
||||||
|
}
|
||||||
|
os << f.logVars_[i];
|
||||||
|
}
|
||||||
|
os << ")" ;
|
||||||
|
}
|
||||||
|
os << "::" << f.range_;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
PrvGroup
|
||||||
|
ProbFormula::getNewGroup (void)
|
||||||
|
{
|
||||||
|
freeGroup_ ++;
|
||||||
|
assert (freeGroup_ != numeric_limits<PrvGroup>::max());
|
||||||
|
return freeGroup_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||||
|
{
|
||||||
|
os << of.functor_ << "/" << of.arity_;
|
||||||
|
os << "|" << of.constr_.tupleSet();
|
||||||
|
os << " [evidence=" << of.evidence_ << "]";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
114
packages/CLPBN/horus2/ProbFormula.h
Normal file
114
packages/CLPBN/horus2/ProbFormula.h
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
#ifndef HORUS_PROBFORMULA_H
|
||||||
|
#define HORUS_PROBFORMULA_H
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "ConstraintTree.h"
|
||||||
|
#include "LiftedUtils.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
typedef unsigned long PrvGroup;
|
||||||
|
|
||||||
|
class ProbFormula
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
||||||
|
: functor_(f), logVars_(lvs), range_(range),
|
||||||
|
countedLogVar_(), group_(numeric_limits<PrvGroup>::max()) { }
|
||||||
|
|
||||||
|
ProbFormula (Symbol f, unsigned r)
|
||||||
|
: functor_(f), range_(r), group_(numeric_limits<PrvGroup>::max()) { }
|
||||||
|
|
||||||
|
Symbol functor (void) const { return functor_; }
|
||||||
|
|
||||||
|
unsigned arity (void) const { return logVars_.size(); }
|
||||||
|
|
||||||
|
unsigned range (void) const { return range_; }
|
||||||
|
|
||||||
|
LogVars& logVars (void) { return logVars_; }
|
||||||
|
|
||||||
|
const LogVars& logVars (void) const { return logVars_; }
|
||||||
|
|
||||||
|
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||||
|
|
||||||
|
PrvGroup group (void) const { return group_; }
|
||||||
|
|
||||||
|
void setGroup (PrvGroup g) { group_ = g; }
|
||||||
|
|
||||||
|
bool sameSkeletonAs (const ProbFormula&) const;
|
||||||
|
|
||||||
|
bool contains (LogVar) const;
|
||||||
|
|
||||||
|
bool contains (LogVarSet) const;
|
||||||
|
|
||||||
|
size_t indexOf (LogVar) const;
|
||||||
|
|
||||||
|
bool isAtom (void) const;
|
||||||
|
|
||||||
|
bool isCounting (void) const;
|
||||||
|
|
||||||
|
LogVar countedLogVar (void) const;
|
||||||
|
|
||||||
|
void setCountedLogVar (LogVar);
|
||||||
|
|
||||||
|
void clearCountedLogVar (void);
|
||||||
|
|
||||||
|
void rename (LogVar, LogVar);
|
||||||
|
|
||||||
|
static PrvGroup getNewGroup (void);
|
||||||
|
|
||||||
|
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
|
||||||
|
|
||||||
|
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Symbol functor_;
|
||||||
|
LogVars logVars_;
|
||||||
|
unsigned range_;
|
||||||
|
LogVar countedLogVar_;
|
||||||
|
PrvGroup group_;
|
||||||
|
static PrvGroup freeGroup_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<ProbFormula> ProbFormulas;
|
||||||
|
|
||||||
|
|
||||||
|
class ObservedFormula
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||||
|
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
|
||||||
|
|
||||||
|
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||||
|
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||||
|
{
|
||||||
|
constr_.addTuple (tuple);
|
||||||
|
}
|
||||||
|
|
||||||
|
Symbol functor (void) const { return functor_; }
|
||||||
|
|
||||||
|
unsigned arity (void) const { return arity_; }
|
||||||
|
|
||||||
|
unsigned evidence (void) const { return evidence_; }
|
||||||
|
|
||||||
|
void setEvidence (unsigned ev) { evidence_ = ev; }
|
||||||
|
|
||||||
|
ConstraintTree& constr (void) { return constr_; }
|
||||||
|
|
||||||
|
bool isAtom (void) const { return arity_ == 0; }
|
||||||
|
|
||||||
|
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
|
||||||
|
|
||||||
|
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Symbol functor_;
|
||||||
|
unsigned arity_;
|
||||||
|
unsigned evidence_;
|
||||||
|
ConstraintTree constr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<ObservedFormula> ObservedFormulas;
|
||||||
|
|
||||||
|
#endif // HORUS_PROBFORMULA_H
|
||||||
|
|
264
packages/CLPBN/horus2/TinySet.h
Normal file
264
packages/CLPBN/horus2/TinySet.h
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
#ifndef HORUS_TINYSET_H
|
||||||
|
#define HORUS_TINYSET_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T, typename Compare = std::less<T>>
|
||||||
|
class TinySet
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
|
||||||
|
typedef typename vector<T>::iterator iterator;
|
||||||
|
typedef typename vector<T>::const_iterator const_iterator;
|
||||||
|
|
||||||
|
TinySet (const TinySet& s)
|
||||||
|
: vec_(s.vec_), cmp_(s.cmp_) { }
|
||||||
|
|
||||||
|
TinySet (const Compare& cmp = Compare())
|
||||||
|
: vec_(), cmp_(cmp) { }
|
||||||
|
|
||||||
|
TinySet (const T& t, const Compare& cmp = Compare())
|
||||||
|
: vec_(1, t), cmp_(cmp) { }
|
||||||
|
|
||||||
|
TinySet (const vector<T>& elements, const Compare& cmp = Compare())
|
||||||
|
: vec_(elements), cmp_(cmp)
|
||||||
|
{
|
||||||
|
std::sort (begin(), end(), cmp_);
|
||||||
|
iterator it = unique_cmp (begin(), end());
|
||||||
|
vec_.resize (it - begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator insert (const T& t)
|
||||||
|
{
|
||||||
|
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||||
|
if (it == end() || cmp_(t, *it)) {
|
||||||
|
vec_.insert (it, t);
|
||||||
|
}
|
||||||
|
return it;
|
||||||
|
}
|
||||||
|
|
||||||
|
void insert_sorted (const T& t)
|
||||||
|
{
|
||||||
|
vec_.push_back (t);
|
||||||
|
assert (consistent());
|
||||||
|
}
|
||||||
|
|
||||||
|
void remove (const T& t)
|
||||||
|
{
|
||||||
|
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||||
|
if (it != end()) {
|
||||||
|
vec_.erase (it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const_iterator find (const T& t) const
|
||||||
|
{
|
||||||
|
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||||
|
return it == end() || cmp_(t, *it) ? end() : it;
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator find (const T& t)
|
||||||
|
{
|
||||||
|
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||||
|
return it == end() || cmp_(t, *it) ? end() : it;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* set union */
|
||||||
|
TinySet operator| (const TinySet& s) const
|
||||||
|
{
|
||||||
|
TinySet res;
|
||||||
|
std::set_union (
|
||||||
|
vec_.begin(), vec_.end(),
|
||||||
|
s.vec_.begin(), s.vec_.end(),
|
||||||
|
std::back_inserter (res.vec_),
|
||||||
|
cmp_);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* set intersection */
|
||||||
|
TinySet operator& (const TinySet& s) const
|
||||||
|
{
|
||||||
|
TinySet res;
|
||||||
|
std::set_intersection (
|
||||||
|
vec_.begin(), vec_.end(),
|
||||||
|
s.vec_.begin(), s.vec_.end(),
|
||||||
|
std::back_inserter (res.vec_),
|
||||||
|
cmp_);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* set difference */
|
||||||
|
TinySet operator- (const TinySet& s) const
|
||||||
|
{
|
||||||
|
TinySet res;
|
||||||
|
std::set_difference (
|
||||||
|
vec_.begin(), vec_.end(),
|
||||||
|
s.vec_.begin(), s.vec_.end(),
|
||||||
|
std::back_inserter (res.vec_),
|
||||||
|
cmp_);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
TinySet& operator|= (const TinySet& s)
|
||||||
|
{
|
||||||
|
return *this = (*this | s);
|
||||||
|
}
|
||||||
|
|
||||||
|
TinySet& operator&= (const TinySet& s)
|
||||||
|
{
|
||||||
|
return *this = (*this & s);
|
||||||
|
}
|
||||||
|
|
||||||
|
TinySet& operator-= (const TinySet& s)
|
||||||
|
{
|
||||||
|
return *this = (*this - s);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contains (const T& t) const
|
||||||
|
{
|
||||||
|
return std::binary_search (
|
||||||
|
vec_.begin(), vec_.end(), t, cmp_);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contains (const TinySet& s) const
|
||||||
|
{
|
||||||
|
return std::includes (
|
||||||
|
vec_.begin(),
|
||||||
|
vec_.end(),
|
||||||
|
s.vec_.begin(),
|
||||||
|
s.vec_.end(),
|
||||||
|
cmp_);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool in (const TinySet& s) const
|
||||||
|
{
|
||||||
|
return std::includes (
|
||||||
|
s.vec_.begin(),
|
||||||
|
s.vec_.end(),
|
||||||
|
vec_.begin(),
|
||||||
|
vec_.end(),
|
||||||
|
cmp_);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool intersects (const TinySet& s) const
|
||||||
|
{
|
||||||
|
return (*this & s).size() > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const T& operator[] (typename vector<T>::size_type i) const
|
||||||
|
{
|
||||||
|
return vec_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
T& operator[] (typename vector<T>::size_type i)
|
||||||
|
{
|
||||||
|
return vec_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
T front (void) const
|
||||||
|
{
|
||||||
|
return vec_.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
T& front (void)
|
||||||
|
{
|
||||||
|
return vec_.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
T back (void) const
|
||||||
|
{
|
||||||
|
return vec_.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
T& back (void)
|
||||||
|
{
|
||||||
|
return vec_.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
const vector<T>& elements (void) const
|
||||||
|
{
|
||||||
|
return vec_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool empty (void) const
|
||||||
|
{
|
||||||
|
return size() == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
typename vector<T>::size_type size (void) const
|
||||||
|
{
|
||||||
|
return vec_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear (void)
|
||||||
|
{
|
||||||
|
vec_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
void reserve (typename vector<T>::size_type size)
|
||||||
|
{
|
||||||
|
vec_.reserve (size);
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator begin (void) { return vec_.begin(); }
|
||||||
|
iterator end (void) { return vec_.end(); }
|
||||||
|
const_iterator begin (void) const { return vec_.begin(); }
|
||||||
|
const_iterator end (void) const { return vec_.end(); }
|
||||||
|
|
||||||
|
friend bool operator== (const TinySet& s1, const TinySet& s2)
|
||||||
|
{
|
||||||
|
return s1.vec_ == s2.vec_;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend bool operator!= (const TinySet& s1, const TinySet& s2)
|
||||||
|
{
|
||||||
|
return ! (s1.vec_ == s2.vec_);
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::ostream& operator << (std::ostream& out, const TinySet& s)
|
||||||
|
{
|
||||||
|
out << "{" ;
|
||||||
|
typename vector<T>::size_type i;
|
||||||
|
for (i = 0; i < s.size(); i++) {
|
||||||
|
out << ((i != 0) ? "," : "") << s.vec_[i];
|
||||||
|
}
|
||||||
|
out << "}" ;
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
iterator unique_cmp (iterator first, iterator last)
|
||||||
|
{
|
||||||
|
if (first == last) {
|
||||||
|
return last;
|
||||||
|
}
|
||||||
|
iterator result = first;
|
||||||
|
while (++first != last) {
|
||||||
|
if (cmp_(*result, *first)) {
|
||||||
|
*(++result) = *first;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ++result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool consistent (void) const
|
||||||
|
{
|
||||||
|
typename vector<T>::size_type i;
|
||||||
|
for (i = 0; i < vec_.size() - 1; i++) {
|
||||||
|
if ( ! cmp_(vec_[i], vec_[i + 1])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<T> vec_;
|
||||||
|
Compare cmp_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_TINYSET_H
|
||||||
|
|
429
packages/CLPBN/horus2/Util.cpp
Normal file
429
packages/CLPBN/horus2/Util.cpp
Normal file
@ -0,0 +1,429 @@
|
|||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "Util.h"
|
||||||
|
#include "Indexer.h"
|
||||||
|
#include "ElimGraph.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace Globals {
|
||||||
|
bool logDomain = false;
|
||||||
|
|
||||||
|
unsigned verbosity = 0;
|
||||||
|
|
||||||
|
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
|
||||||
|
|
||||||
|
GroundSolverType groundSolver = GroundSolverType::VE;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace BpOptions {
|
||||||
|
Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||||
|
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||||
|
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
||||||
|
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||||
|
double accuracy = 0.0001;
|
||||||
|
unsigned maxIter = 1000;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace Util {
|
||||||
|
|
||||||
|
|
||||||
|
template <> std::string
|
||||||
|
toString (const bool& b)
|
||||||
|
{
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << std::boolalpha << b;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
stringToUnsigned (string str)
|
||||||
|
{
|
||||||
|
int val;
|
||||||
|
stringstream ss;
|
||||||
|
ss << str;
|
||||||
|
ss >> val;
|
||||||
|
if (val < 0) {
|
||||||
|
cerr << "Error: the number readed is negative." << endl;
|
||||||
|
exit (EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
return static_cast<unsigned> (val);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
stringToDouble (string str)
|
||||||
|
{
|
||||||
|
double val;
|
||||||
|
stringstream ss;
|
||||||
|
ss << str;
|
||||||
|
ss >> val;
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
factorial (unsigned num)
|
||||||
|
{
|
||||||
|
double result = 1.0;
|
||||||
|
for (unsigned i = 1; i <= num; i++) {
|
||||||
|
result *= i;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
logFactorial (unsigned num)
|
||||||
|
{
|
||||||
|
double result = 0.0;
|
||||||
|
if (num < 150) {
|
||||||
|
result = std::log (factorial (num));
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 1; i <= num; i++) {
|
||||||
|
result += std::log (i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
nrCombinations (unsigned n, unsigned k)
|
||||||
|
{
|
||||||
|
assert (n >= k);
|
||||||
|
int diff = n - k;
|
||||||
|
unsigned result = 0;
|
||||||
|
if (n < 150) {
|
||||||
|
unsigned prod = 1;
|
||||||
|
for (int i = n; i > diff; i--) {
|
||||||
|
prod *= i;
|
||||||
|
}
|
||||||
|
result = prod / factorial (k);
|
||||||
|
} else {
|
||||||
|
double prod = 0.0;
|
||||||
|
for (int i = n; i > diff; i--) {
|
||||||
|
prod += std::log (i);
|
||||||
|
}
|
||||||
|
prod -= logFactorial (k);
|
||||||
|
result = static_cast<unsigned> (std::exp (prod));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
size_t
|
||||||
|
sizeExpected (const Ranges& ranges)
|
||||||
|
{
|
||||||
|
return std::accumulate (ranges.begin(),
|
||||||
|
ranges.end(), 1, multiplies<unsigned>());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
nrDigits (int num)
|
||||||
|
{
|
||||||
|
unsigned count = 1;
|
||||||
|
while (num >= 10) {
|
||||||
|
num /= 10;
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
isInteger (const string& s)
|
||||||
|
{
|
||||||
|
stringstream ss1 (s);
|
||||||
|
stringstream ss2;
|
||||||
|
int integer;
|
||||||
|
ss1 >> integer;
|
||||||
|
ss2 << integer;
|
||||||
|
return (ss1.str() == ss2.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
parametersToString (const Params& v, unsigned precision)
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss.precision (precision);
|
||||||
|
ss << "[" ;
|
||||||
|
for (size_t i = 0; i < v.size(); i++) {
|
||||||
|
if (i != 0) ss << ", " ;
|
||||||
|
ss << v[i];
|
||||||
|
}
|
||||||
|
ss << "]" ;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<string>
|
||||||
|
getStateLines (const Vars& vars)
|
||||||
|
{
|
||||||
|
Ranges ranges;
|
||||||
|
for (size_t i = 0; i < vars.size(); i++) {
|
||||||
|
ranges.push_back (vars[i]->range());
|
||||||
|
}
|
||||||
|
Indexer indexer (ranges);
|
||||||
|
vector<string> jointStrings;
|
||||||
|
while (indexer.valid()) {
|
||||||
|
stringstream ss;
|
||||||
|
for (size_t i = 0; i < vars.size(); i++) {
|
||||||
|
if (i != 0) ss << ", " ;
|
||||||
|
ss << vars[i]->label() << "=" ;
|
||||||
|
ss << vars[i]->states()[(indexer[i])];
|
||||||
|
}
|
||||||
|
jointStrings.push_back (ss.str());
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
return jointStrings;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
setHorusFlag (string key, string value)
|
||||||
|
{
|
||||||
|
bool returnVal = true;
|
||||||
|
if (key == "verbosity") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << value;
|
||||||
|
ss >> Globals::verbosity;
|
||||||
|
} else if (key == "lifted_solver") {
|
||||||
|
if ( value == "lve") {
|
||||||
|
Globals::liftedSolver = LiftedSolverType::LVE;
|
||||||
|
} else if (value == "lbp") {
|
||||||
|
Globals::liftedSolver = LiftedSolverType::LBP;
|
||||||
|
} else if (value == "lkc") {
|
||||||
|
Globals::liftedSolver = LiftedSolverType::LKC;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else if (key == "ground_solver") {
|
||||||
|
if ( value == "ve") {
|
||||||
|
Globals::groundSolver = GroundSolverType::VE;
|
||||||
|
} else if (value == "bp") {
|
||||||
|
Globals::groundSolver = GroundSolverType::BP;
|
||||||
|
} else if (value == "cbp") {
|
||||||
|
Globals::groundSolver = GroundSolverType::CBP;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else if (key == "elim_heuristic") {
|
||||||
|
if ( value == "sequential") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::SEQUENTIAL;
|
||||||
|
} else if (value == "min_neighbors") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
|
||||||
|
} else if (value == "min_weight") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT;
|
||||||
|
} else if (value == "min_fill") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::MIN_FILL;
|
||||||
|
} else if (value == "weighted_min_fill") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::WEIGHTED_MIN_FILL;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else if (key == "schedule") {
|
||||||
|
if ( value == "seq_fixed") {
|
||||||
|
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||||
|
} else if (value == "seq_random") {
|
||||||
|
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||||
|
} else if (value == "parallel") {
|
||||||
|
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
|
||||||
|
} else if (value == "max_residual") {
|
||||||
|
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else if (key == "accuracy") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << value;
|
||||||
|
ss >> BpOptions::accuracy;
|
||||||
|
} else if (key == "max_iter") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << value;
|
||||||
|
ss >> BpOptions::maxIter;
|
||||||
|
} else if (key == "use_logarithms") {
|
||||||
|
if ( value == "true") {
|
||||||
|
Globals::logDomain = true;
|
||||||
|
} else if (value == "false") {
|
||||||
|
Globals::logDomain = false;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid key `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
return returnVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
printHeader (string header, std::ostream& os)
|
||||||
|
{
|
||||||
|
printAsteriskLine (os);
|
||||||
|
os << header << endl;
|
||||||
|
printAsteriskLine (os);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
printSubHeader (string header, std::ostream& os)
|
||||||
|
{
|
||||||
|
printDashedLine (os);
|
||||||
|
os << header << endl;
|
||||||
|
printDashedLine (os);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
printAsteriskLine (std::ostream& os)
|
||||||
|
{
|
||||||
|
os << "********************************" ;
|
||||||
|
os << "********************************" ;
|
||||||
|
os << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
printDashedLine (std::ostream& os)
|
||||||
|
{
|
||||||
|
os << "--------------------------------" ;
|
||||||
|
os << "--------------------------------" ;
|
||||||
|
os << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace LogAware {
|
||||||
|
|
||||||
|
void
|
||||||
|
normalize (Params& v)
|
||||||
|
{
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
double sum = std::accumulate (v.begin(), v.end(),
|
||||||
|
LogAware::addIdenty(), Util::logSum);
|
||||||
|
assert (sum != -numeric_limits<double>::infinity());
|
||||||
|
v -= sum;
|
||||||
|
} else {
|
||||||
|
double sum = std::accumulate (v.begin(), v.end(), 0.0);
|
||||||
|
assert (sum != 0.0);
|
||||||
|
v /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
getL1Distance (const Params& v1, const Params& v2)
|
||||||
|
{
|
||||||
|
assert (v1.size() == v2.size());
|
||||||
|
double dist = 0.0;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
|
std::plus<double>(), FuncObject::abs_diff_exp<double>());
|
||||||
|
} else {
|
||||||
|
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
|
std::plus<double>(), FuncObject::abs_diff<double>());
|
||||||
|
}
|
||||||
|
return dist;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
getMaxNorm (const Params& v1, const Params& v2)
|
||||||
|
{
|
||||||
|
assert (v1.size() == v2.size());
|
||||||
|
double max = 0.0;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
|
FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
|
||||||
|
} else {
|
||||||
|
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||||
|
FuncObject::max<double>(), FuncObject::abs_diff<double>());
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
pow (double base, unsigned iexp)
|
||||||
|
{
|
||||||
|
return Globals::logDomain
|
||||||
|
? base * iexp
|
||||||
|
: std::pow (base, iexp);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
pow (double base, double exp)
|
||||||
|
{
|
||||||
|
// `expoent' should not be in log domain
|
||||||
|
return Globals::logDomain
|
||||||
|
? base * exp
|
||||||
|
: std::pow (base, exp);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
pow (Params& v, unsigned iexp)
|
||||||
|
{
|
||||||
|
if (iexp == 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Globals::logDomain ? v *= iexp : v ^= (int)iexp;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
pow (Params& v, double exp)
|
||||||
|
{
|
||||||
|
// `expoent' should not be in log domain
|
||||||
|
Globals::logDomain ? v *= exp : v ^= exp;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
422
packages/CLPBN/horus2/Util.h
Normal file
422
packages/CLPBN/horus2/Util.h
Normal file
@ -0,0 +1,422 @@
|
|||||||
|
#ifndef HORUS_UTIL_H
|
||||||
|
#define HORUS_UTIL_H
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cassert>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
#include <queue>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
const double NEG_INF = -numeric_limits<double>::infinity();
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace Util {
|
||||||
|
|
||||||
|
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||||
|
|
||||||
|
template <typename T> void addToSet (set<T>&, const vector<T>&);
|
||||||
|
|
||||||
|
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||||
|
|
||||||
|
template <typename T> bool contains (const vector<T>&, const T&);
|
||||||
|
|
||||||
|
template <typename T> bool contains (const set<T>&, const T&);
|
||||||
|
|
||||||
|
template <typename K, typename V> bool contains (
|
||||||
|
const unordered_map<K, V>&, const K&);
|
||||||
|
|
||||||
|
template <typename T> size_t indexOf (const vector<T>&, const T&);
|
||||||
|
|
||||||
|
template <class Operation>
|
||||||
|
void apply_n_times (Params& v1, const Params& v2, unsigned repetitions, Operation);
|
||||||
|
|
||||||
|
template <typename T> void log (vector<T>&);
|
||||||
|
|
||||||
|
template <typename T> void exp (vector<T>&);
|
||||||
|
|
||||||
|
template <typename T> string elementsToString (
|
||||||
|
const vector<T>& v, string sep = " ");
|
||||||
|
|
||||||
|
template <typename T> std::string toString (const T&);
|
||||||
|
|
||||||
|
template <> std::string toString (const bool&);
|
||||||
|
|
||||||
|
double logSum (double, double);
|
||||||
|
|
||||||
|
unsigned maxUnsigned (void);
|
||||||
|
|
||||||
|
unsigned stringToUnsigned (string);
|
||||||
|
|
||||||
|
double stringToDouble (string);
|
||||||
|
|
||||||
|
double factorial (unsigned);
|
||||||
|
|
||||||
|
double logFactorial (unsigned);
|
||||||
|
|
||||||
|
unsigned nrCombinations (unsigned, unsigned);
|
||||||
|
|
||||||
|
size_t sizeExpected (const Ranges&);
|
||||||
|
|
||||||
|
unsigned nrDigits (int);
|
||||||
|
|
||||||
|
bool isInteger (const string&);
|
||||||
|
|
||||||
|
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||||
|
|
||||||
|
vector<string> getStateLines (const Vars&);
|
||||||
|
|
||||||
|
bool setHorusFlag (string key, string value);
|
||||||
|
|
||||||
|
void printHeader (string, std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
void printSubHeader (string, std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
void printAsteriskLine (std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
void printDashedLine (std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void
|
||||||
|
Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||||
|
{
|
||||||
|
v.insert (v.end(), elements.begin(), elements.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void
|
||||||
|
Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||||
|
{
|
||||||
|
s.insert (elements.begin(), elements.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void
|
||||||
|
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < elements.size(); i++) {
|
||||||
|
q.push (elements[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> bool
|
||||||
|
Util::contains (const vector<T>& v, const T& e)
|
||||||
|
{
|
||||||
|
return std::find (v.begin(), v.end(), e) != v.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> bool
|
||||||
|
Util::contains (const set<T>& s, const T& e)
|
||||||
|
{
|
||||||
|
return s.find (e) != s.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename K, typename V> bool
|
||||||
|
Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||||
|
{
|
||||||
|
return m.find (k) != m.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> size_t
|
||||||
|
Util::indexOf (const vector<T>& v, const T& e)
|
||||||
|
{
|
||||||
|
return std::distance (v.begin(),
|
||||||
|
std::find (v.begin(), v.end(), e));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <class Operation> void
|
||||||
|
Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
|
||||||
|
Operation unary_op)
|
||||||
|
{
|
||||||
|
Params::iterator first = v1.begin();
|
||||||
|
Params::const_iterator last = v1.end();
|
||||||
|
Params::const_iterator first2 = v2.begin();
|
||||||
|
Params::const_iterator last2 = v2.end();
|
||||||
|
while (first != last) {
|
||||||
|
for (first2 = v2.begin(); first2 != last2; ++first2) {
|
||||||
|
std::transform (first, first + repetitions, first,
|
||||||
|
std::bind1st (unary_op, *first2));
|
||||||
|
first += repetitions;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void
|
||||||
|
Util::log (vector<T>& v)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(), ::log);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void
|
||||||
|
Util::exp (vector<T>& v)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> string
|
||||||
|
Util::elementsToString (const vector<T>& v, string sep)
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
for (size_t i = 0; i < v.size(); i++) {
|
||||||
|
ss << ((i != 0) ? sep : "") << v[i];
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> std::string
|
||||||
|
Util::toString (const T& t)
|
||||||
|
{
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << t;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
Util::logSum (double x, double y)
|
||||||
|
{
|
||||||
|
// std::log (std::exp (x) + std::exp (y)) can overflow!
|
||||||
|
assert (std::isnan (x) == false);
|
||||||
|
assert (std::isnan (y) == false);
|
||||||
|
if (x == NEG_INF) {
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
if (y == NEG_INF) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
// if one value is much smaller than the other,
|
||||||
|
// keep the larger value
|
||||||
|
const double tol = 460.517; // log (1e200)
|
||||||
|
if (x < y - tol) {
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
if (y < x - tol) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
assert (std::isnan (x - y) == false);
|
||||||
|
const double exp_diff = std::exp (x - y);
|
||||||
|
if (std::isfinite (exp_diff) == false) {
|
||||||
|
// difference is too large
|
||||||
|
return x > y ? x : y;
|
||||||
|
}
|
||||||
|
// otherwise return the sum
|
||||||
|
return y + std::log (static_cast<double>(1.0) + exp_diff);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline unsigned
|
||||||
|
Util::maxUnsigned (void)
|
||||||
|
{
|
||||||
|
return numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace LogAware {
|
||||||
|
|
||||||
|
inline double one() { return Globals::logDomain ? 0.0 : 1.0; }
|
||||||
|
inline double zero() { return Globals::logDomain ? NEG_INF : 0.0; }
|
||||||
|
inline double addIdenty() { return Globals::logDomain ? NEG_INF : 0.0; }
|
||||||
|
inline double multIdenty() { return Globals::logDomain ? 0.0 : 1.0; }
|
||||||
|
inline double withEvidence() { return Globals::logDomain ? 0.0 : 1.0; }
|
||||||
|
inline double noEvidence() { return Globals::logDomain ? NEG_INF : 0.0; }
|
||||||
|
inline double log (double v) { return Globals::logDomain ? ::log (v) : v; }
|
||||||
|
inline double exp (double v) { return Globals::logDomain ? ::exp (v) : v; }
|
||||||
|
|
||||||
|
void normalize (Params&);
|
||||||
|
|
||||||
|
double getL1Distance (const Params&, const Params&);
|
||||||
|
|
||||||
|
double getMaxNorm (const Params&, const Params&);
|
||||||
|
|
||||||
|
double pow (double, unsigned);
|
||||||
|
|
||||||
|
double pow (double, double);
|
||||||
|
|
||||||
|
void pow (Params&, unsigned);
|
||||||
|
|
||||||
|
void pow (Params&, double);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator+=(std::vector<T>& v, double val)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind2nd (plus<double>(), val));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator-=(std::vector<T>& v, double val)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind2nd (minus<double>(), val));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator*=(std::vector<T>& v, double val)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind2nd (multiplies<double>(), val));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator/=(std::vector<T>& v, double val)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind2nd (divides<double>(), val));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator+=(std::vector<T>& a, const std::vector<T>& b)
|
||||||
|
{
|
||||||
|
assert (a.size() == b.size());
|
||||||
|
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||||
|
plus<double>());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator-=(std::vector<T>& a, const std::vector<T>& b)
|
||||||
|
{
|
||||||
|
assert (a.size() == b.size());
|
||||||
|
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||||
|
minus<double>());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator*=(std::vector<T>& a, const std::vector<T>& b)
|
||||||
|
{
|
||||||
|
assert (a.size() == b.size());
|
||||||
|
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||||
|
multiplies<double>());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator/=(std::vector<T>& a, const std::vector<T>& b)
|
||||||
|
{
|
||||||
|
assert (a.size() == b.size());
|
||||||
|
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||||
|
divides<double>());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator^=(std::vector<T>& v, double exp)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator^=(std::vector<T>& v, int iexp)
|
||||||
|
{
|
||||||
|
std::transform (v.begin(), v.end(), v.begin(),
|
||||||
|
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::ostream& operator << (std::ostream& os, const vector<T>& v)
|
||||||
|
{
|
||||||
|
os << "[" ;
|
||||||
|
os << Util::elementsToString (v, ", ");
|
||||||
|
os << "]" ;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace FuncObject {
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct max : public std::binary_function<T, T, T>
|
||||||
|
{
|
||||||
|
T operator() (const T& x, const T& y) const
|
||||||
|
{
|
||||||
|
return x < y ? y : x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct abs_diff : public std::binary_function<T, T, T>
|
||||||
|
{
|
||||||
|
T operator() (const T& x, const T& y) const
|
||||||
|
{
|
||||||
|
return std::abs (x - y);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct abs_diff_exp : public std::binary_function<T, T, T>
|
||||||
|
{
|
||||||
|
T operator() (const T& x, const T& y) const
|
||||||
|
{
|
||||||
|
return std::abs (std::exp (x) - std::exp (y));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // HORUS_UTIL_H
|
||||||
|
|
102
packages/CLPBN/horus2/Var.cpp
Normal file
102
packages/CLPBN/horus2/Var.cpp
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "Var.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
unordered_map<VarId, VarInfo> Var::varsInfo_;
|
||||||
|
|
||||||
|
|
||||||
|
Var::Var (const Var* v)
|
||||||
|
{
|
||||||
|
varId_ = v->varId();
|
||||||
|
range_ = v->range();
|
||||||
|
evidence_ = v->getEvidence();
|
||||||
|
index_ = std::numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Var::Var (VarId varId, unsigned range, int evidence)
|
||||||
|
{
|
||||||
|
assert (range != 0);
|
||||||
|
assert (evidence < (int) range);
|
||||||
|
varId_ = varId;
|
||||||
|
range_ = range;
|
||||||
|
evidence_ = evidence;
|
||||||
|
index_ = std::numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Var::isValidState (int stateIndex)
|
||||||
|
{
|
||||||
|
return stateIndex >= 0 && stateIndex < (int) range_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Var::isValidState (const string& stateName)
|
||||||
|
{
|
||||||
|
States states = Var::getVarInfo (varId_).states;
|
||||||
|
return Util::contains (states, stateName);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Var::setEvidence (int ev)
|
||||||
|
{
|
||||||
|
assert (ev < (int) range_);
|
||||||
|
evidence_ = ev;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Var::setEvidence (const string& ev)
|
||||||
|
{
|
||||||
|
States states = Var::getVarInfo (varId_).states;
|
||||||
|
for (size_t i = 0; i < states.size(); i++) {
|
||||||
|
if (states[i] == ev) {
|
||||||
|
evidence_ = i;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
Var::label (void) const
|
||||||
|
{
|
||||||
|
if (Var::varsHaveInfo()) {
|
||||||
|
return Var::getVarInfo (varId_).label;
|
||||||
|
}
|
||||||
|
stringstream ss;
|
||||||
|
ss << "x" << varId_;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
States
|
||||||
|
Var::states (void) const
|
||||||
|
{
|
||||||
|
if (Var::varsHaveInfo()) {
|
||||||
|
return Var::getVarInfo (varId_).states;
|
||||||
|
}
|
||||||
|
States states;
|
||||||
|
for (unsigned i = 0; i < range_; i++) {
|
||||||
|
stringstream ss;
|
||||||
|
ss << i ;
|
||||||
|
states.push_back (ss.str());
|
||||||
|
}
|
||||||
|
return states;
|
||||||
|
}
|
||||||
|
|
108
packages/CLPBN/horus2/Var.h
Normal file
108
packages/CLPBN/horus2/Var.h
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
#ifndef HORUS_VAR_H
|
||||||
|
#define HORUS_VAR_H
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "Util.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
struct VarInfo
|
||||||
|
{
|
||||||
|
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
||||||
|
string label;
|
||||||
|
States states;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Var
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Var (const Var*);
|
||||||
|
|
||||||
|
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||||
|
|
||||||
|
virtual ~Var (void) { };
|
||||||
|
|
||||||
|
VarId varId (void) const { return varId_; }
|
||||||
|
|
||||||
|
unsigned range (void) const { return range_; }
|
||||||
|
|
||||||
|
int getEvidence (void) const { return evidence_; }
|
||||||
|
|
||||||
|
size_t getIndex (void) const { return index_; }
|
||||||
|
|
||||||
|
void setIndex (size_t idx) { index_ = idx; }
|
||||||
|
|
||||||
|
bool hasEvidence (void) const
|
||||||
|
{
|
||||||
|
return evidence_ != Constants::NO_EVIDENCE;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator size_t (void) const { return index_; }
|
||||||
|
|
||||||
|
bool operator== (const Var& var) const
|
||||||
|
{
|
||||||
|
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||||
|
return varId_ == var.varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!= (const Var& var) const
|
||||||
|
{
|
||||||
|
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||||
|
return varId_ != var.varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isValidState (int);
|
||||||
|
|
||||||
|
bool isValidState (const string&);
|
||||||
|
|
||||||
|
void setEvidence (int);
|
||||||
|
|
||||||
|
void setEvidence (const string&);
|
||||||
|
|
||||||
|
string label (void) const;
|
||||||
|
|
||||||
|
States states (void) const;
|
||||||
|
|
||||||
|
static void addVarInfo (
|
||||||
|
VarId vid, string label, const States& states)
|
||||||
|
{
|
||||||
|
assert (Util::contains (varsInfo_, vid) == false);
|
||||||
|
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||||
|
}
|
||||||
|
|
||||||
|
static VarInfo getVarInfo (VarId vid)
|
||||||
|
{
|
||||||
|
assert (Util::contains (varsInfo_, vid));
|
||||||
|
return varsInfo_.find (vid)->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool varsHaveInfo (void)
|
||||||
|
{
|
||||||
|
return varsInfo_.size() != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void clearVarsInfo (void)
|
||||||
|
{
|
||||||
|
varsInfo_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
VarId varId_;
|
||||||
|
unsigned range_;
|
||||||
|
int evidence_;
|
||||||
|
size_t index_;
|
||||||
|
|
||||||
|
static unordered_map<VarId, VarInfo> varsInfo_;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_VAR_H
|
||||||
|
|
217
packages/CLPBN/horus2/VarElim.cpp
Normal file
217
packages/CLPBN/horus2/VarElim.cpp
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "VarElim.h"
|
||||||
|
#include "ElimGraph.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
VarElim::~VarElim (void)
|
||||||
|
{
|
||||||
|
delete factorList_.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
VarElim::solveQuery (VarIds queryVids)
|
||||||
|
{
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "Solving query on " ;
|
||||||
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
|
if (i != 0) cout << ", " ;
|
||||||
|
cout << fg.getVarNode (queryVids[i])->label();
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
factorList_.clear();
|
||||||
|
varFactors_.clear();
|
||||||
|
elimOrder_.clear();
|
||||||
|
createFactorList();
|
||||||
|
absorveEvidence();
|
||||||
|
findEliminationOrder (queryVids);
|
||||||
|
processFactorList (queryVids);
|
||||||
|
Params params = factorList_.back()->params();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::exp (params);
|
||||||
|
}
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
VarElim::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "variable elimination [" ;
|
||||||
|
ss << "elim_heuristic=" ;
|
||||||
|
ElimHeuristic eh = ElimGraph::elimHeuristic;
|
||||||
|
switch (eh) {
|
||||||
|
case SEQUENTIAL: ss << "sequential"; break;
|
||||||
|
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
||||||
|
case MIN_WEIGHT: ss << "min_weight"; break;
|
||||||
|
case MIN_FILL: ss << "min_fill"; break;
|
||||||
|
case WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
||||||
|
}
|
||||||
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
VarElim::createFactorList (void)
|
||||||
|
{
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
factorList_.reserve (facNodes.size() * 2);
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||||
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
unordered_map<VarId, vector<size_t>>::iterator it
|
||||||
|
= varFactors_.find (neighs[j]->varId());
|
||||||
|
if (it == varFactors_.end()) {
|
||||||
|
it = varFactors_.insert (make_pair (
|
||||||
|
neighs[j]->varId(), vector<size_t>())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
VarElim::absorveEvidence (void)
|
||||||
|
{
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
Util::printDashedLine();
|
||||||
|
cout << "(initial factor list)" << endl;
|
||||||
|
printActiveFactors();
|
||||||
|
}
|
||||||
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
if (varNodes[i]->hasEvidence()) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "-> aborving evidence on ";
|
||||||
|
cout << varNodes[i]->label() << " = " ;
|
||||||
|
cout << varNodes[i]->getEvidence() << endl;
|
||||||
|
}
|
||||||
|
const vector<size_t>& idxs =
|
||||||
|
varFactors_.find (varNodes[i]->varId())->second;
|
||||||
|
for (size_t j = 0; j < idxs.size(); j++) {
|
||||||
|
Factor* factor = factorList_[idxs[j]];
|
||||||
|
if (factor->nrArguments() == 1) {
|
||||||
|
factorList_[idxs[j]] = 0;
|
||||||
|
} else {
|
||||||
|
factorList_[idxs[j]]->absorveEvidence (
|
||||||
|
varNodes[i]->varId(), varNodes[i]->getEvidence());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
VarElim::findEliminationOrder (const VarIds& vids)
|
||||||
|
{
|
||||||
|
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
VarElim::processFactorList (const VarIds& vids)
|
||||||
|
{
|
||||||
|
totalFactorSize_ = 0;
|
||||||
|
largestFactorSize_ = 0;
|
||||||
|
for (size_t i = 0; i < elimOrder_.size(); i++) {
|
||||||
|
if (Globals::verbosity >= 2) {
|
||||||
|
if (Globals::verbosity >= 3) {
|
||||||
|
Util::printDashedLine();
|
||||||
|
printActiveFactors();
|
||||||
|
}
|
||||||
|
cout << "-> summing out " ;
|
||||||
|
cout << fg.getVarNode (elimOrder_[i])->label() << endl;
|
||||||
|
}
|
||||||
|
eliminate (elimOrder_[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Factor* finalFactor = new Factor();
|
||||||
|
for (size_t i = 0; i < factorList_.size(); i++) {
|
||||||
|
if (factorList_[i]) {
|
||||||
|
finalFactor->multiply (*factorList_[i]);
|
||||||
|
delete factorList_[i];
|
||||||
|
factorList_[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VarIds unobservedVids;
|
||||||
|
for (size_t i = 0; i < vids.size(); i++) {
|
||||||
|
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
|
||||||
|
unobservedVids.push_back (vids[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
finalFactor->reorderArguments (unobservedVids);
|
||||||
|
finalFactor->normalize();
|
||||||
|
factorList_.push_back (finalFactor);
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
cout << "total factor size: " << totalFactorSize_ << endl;
|
||||||
|
cout << "largest factor size: " << largestFactorSize_ << endl;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
VarElim::eliminate (VarId elimVar)
|
||||||
|
{
|
||||||
|
Factor* result = 0;
|
||||||
|
vector<size_t>& idxs = varFactors_.find (elimVar)->second;
|
||||||
|
for (size_t i = 0; i < idxs.size(); i++) {
|
||||||
|
size_t idx = idxs[i];
|
||||||
|
if (factorList_[idx]) {
|
||||||
|
if (result == 0) {
|
||||||
|
result = new Factor (*factorList_[idx]);
|
||||||
|
} else {
|
||||||
|
result->multiply (*factorList_[idx]);
|
||||||
|
}
|
||||||
|
delete factorList_[idx];
|
||||||
|
factorList_[idx] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
totalFactorSize_ += result->size();
|
||||||
|
if (result->size() > largestFactorSize_) {
|
||||||
|
largestFactorSize_ = result->size();
|
||||||
|
}
|
||||||
|
if (result != 0 && result->nrArguments() != 1) {
|
||||||
|
result->sumOut (elimVar);
|
||||||
|
factorList_.push_back (result);
|
||||||
|
const VarIds& resultVarIds = result->arguments();
|
||||||
|
for (size_t i = 0; i < resultVarIds.size(); i++) {
|
||||||
|
vector<size_t>& idxs =
|
||||||
|
varFactors_.find (resultVarIds[i])->second;
|
||||||
|
idxs.push_back (factorList_.size() - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
VarElim::printActiveFactors (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < factorList_.size(); i++) {
|
||||||
|
if (factorList_[i] != 0) {
|
||||||
|
cout << factorList_[i]->getLabel() << " " ;
|
||||||
|
cout << factorList_[i]->params() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
46
packages/CLPBN/horus2/VarElim.h
Normal file
46
packages/CLPBN/horus2/VarElim.h
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
#ifndef HORUS_VARELIM_H
|
||||||
|
#define HORUS_VARELIM_H
|
||||||
|
|
||||||
|
#include "unordered_map"
|
||||||
|
|
||||||
|
#include "GroundSolver.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
class VarElim : public GroundSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
|
||||||
|
|
||||||
|
~VarElim (void);
|
||||||
|
|
||||||
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void createFactorList (void);
|
||||||
|
|
||||||
|
void absorveEvidence (void);
|
||||||
|
|
||||||
|
void findEliminationOrder (const VarIds&);
|
||||||
|
|
||||||
|
void processFactorList (const VarIds&);
|
||||||
|
|
||||||
|
void eliminate (VarId);
|
||||||
|
|
||||||
|
void printActiveFactors (void);
|
||||||
|
|
||||||
|
Factors factorList_;
|
||||||
|
VarIds elimOrder_;
|
||||||
|
unsigned largestFactorSize_;
|
||||||
|
unsigned totalFactorSize_;
|
||||||
|
unordered_map<VarId, vector<size_t>> varFactors_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_VARELIM_H
|
||||||
|
|
288
packages/CLPBN/horus2/WeightedBp.cpp
Normal file
288
packages/CLPBN/horus2/WeightedBp.cpp
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
#include "WeightedBp.h"
|
||||||
|
|
||||||
|
|
||||||
|
WeightedBp::~WeightedBp (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
links_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
WeightedBp::getPosterioriOf (VarId vid)
|
||||||
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
|
VarNode* var = fg.getVarNode (vid);
|
||||||
|
assert (var != 0);
|
||||||
|
Params probs;
|
||||||
|
if (var->hasEvidence()) {
|
||||||
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
|
} else {
|
||||||
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
|
const BpLinks& links = ninf(var)->getLinks();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
|
probs += l->powMessage();
|
||||||
|
}
|
||||||
|
LogAware::normalize (probs);
|
||||||
|
Util::exp (probs);
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
|
probs *= l->powMessage();
|
||||||
|
}
|
||||||
|
LogAware::normalize (probs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
WeightedBp::createLinks (void)
|
||||||
|
{
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
cout << "compressed factor graph contains " ;
|
||||||
|
cout << fg.nrVarNodes() << " variables and " ;
|
||||||
|
cout << fg.nrFacNodes() << " factors " << endl;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "creating link " ;
|
||||||
|
cout << facNodes[i]->getLabel();
|
||||||
|
cout << " -- " ;
|
||||||
|
cout << neighs[j]->label();
|
||||||
|
cout << " idx=" << j << ", weight=" << weights_[i][j] << endl;
|
||||||
|
}
|
||||||
|
links_.push_back (new WeightedLink (
|
||||||
|
facNodes[i], neighs[j], j, weights_[i][j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
WeightedBp::maxResidualSchedule (void)
|
||||||
|
{
|
||||||
|
if (nIters_ == 1) {
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
calculateMessage (links_[i]);
|
||||||
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
|
if (Globals::verbosity >= 1) {
|
||||||
|
cout << "calculating " << links_[i]->toString() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t c = 0; c < links_.size(); c++) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << endl << "current residuals:" << endl;
|
||||||
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
it != sortedOrder_.end(); ++it) {
|
||||||
|
cout << " " << setw (30) << left << (*it)->toString();
|
||||||
|
cout << "residual = " << (*it)->residual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
BpLink* link = *it;
|
||||||
|
if (Globals::verbosity >= 1) {
|
||||||
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
|
}
|
||||||
|
if (link->residual() < BpOptions::accuracy) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
link->updateMessage();
|
||||||
|
link->clearResidual();
|
||||||
|
sortedOrder_.erase (it);
|
||||||
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
|
// update the messages that depend on message source --> destin
|
||||||
|
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||||
|
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||||
|
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
|
for (size_t j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->varNode() != link->varNode()) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << " calculating " << links[j]->toString() << endl;
|
||||||
|
}
|
||||||
|
calculateMessage (links[j]);
|
||||||
|
BpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// in counting bp, the message that a variable X sends to
|
||||||
|
// to a factor F depends on the message that F sent to the X
|
||||||
|
const BpLinks& links = ninf(link->facNode())->getLinks();
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
if (links[i]->varNode() != link->varNode()) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << " calculating " << links[i]->toString() << endl;
|
||||||
|
}
|
||||||
|
calculateMessage (links[i]);
|
||||||
|
BpLinkMap::iterator iter = linkMap_.find (links[i]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
WeightedBp::calcFactorToVarMsg (BpLink* _link)
|
||||||
|
{
|
||||||
|
WeightedLink* link = static_cast<WeightedLink*> (_link);
|
||||||
|
FacNode* src = link->facNode();
|
||||||
|
const VarNode* dst = link->varNode();
|
||||||
|
const BpLinks& links = ninf(src)->getLinks();
|
||||||
|
// calculate the product of messages that were sent
|
||||||
|
// to factor `src', except from var `dst'
|
||||||
|
unsigned reps = 1;
|
||||||
|
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
|
||||||
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::plus<double>());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::multiplies<double>());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Factor result (src->factor().arguments(),
|
||||||
|
src->factor().ranges(), msgProduct);
|
||||||
|
assert (msgProduct.size() == src->factor().size());
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
result.params() += src->factor().params();
|
||||||
|
} else {
|
||||||
|
result.params() *= src->factor().params();
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message product: " << msgProduct << endl;
|
||||||
|
cout << " original factor: " << src->factor().params() << endl;
|
||||||
|
cout << " factor product: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
result.sumOutAllExceptIndex (link->index());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " marginalized: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
link->nextMessage() = result.params();
|
||||||
|
LogAware::normalize (link->nextMessage());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " curr msg: " << link->message() << endl;
|
||||||
|
cout << " next msg: " << link->nextMessage() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
WeightedBp::getVarToFactorMsg (const BpLink* _link) const
|
||||||
|
{
|
||||||
|
const WeightedLink* link = static_cast<const WeightedLink*> (_link);
|
||||||
|
const VarNode* src = link->varNode();
|
||||||
|
const FacNode* dst = link->facNode();
|
||||||
|
Params msg;
|
||||||
|
if (src->hasEvidence()) {
|
||||||
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
|
double value = link->message()[src->getEvidence()];
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
msg[src->getEvidence()] = value;
|
||||||
|
cout << msg << "^" << link->weight() << "-1" ;
|
||||||
|
}
|
||||||
|
msg[src->getEvidence()] = LogAware::pow (value, link->weight() - 1);
|
||||||
|
} else {
|
||||||
|
msg = link->message();
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << msg << "^" << link->weight() << "-1" ;
|
||||||
|
}
|
||||||
|
LogAware::pow (msg, link->weight() - 1);
|
||||||
|
}
|
||||||
|
const BpLinks& links = ninf(src)->getLinks();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
|
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||||
|
msg += l->powMessage();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
|
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||||
|
msg *= l->powMessage();
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " x " << l->nextMessage() << "^" << link->weight();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " = " << msg;
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
WeightedBp::printLinkInformation (void) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
|
||||||
|
cout << l->toString() << ":" << endl;
|
||||||
|
cout << " curr msg = " << l->message() << endl;
|
||||||
|
cout << " next msg = " << l->nextMessage() << endl;
|
||||||
|
cout << " pow msg = " << l->powMessage() << endl;
|
||||||
|
cout << " index = " << l->index() << endl;
|
||||||
|
cout << " weight = " << l->weight() << endl;
|
||||||
|
cout << " residual = " << l->residual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
61
packages/CLPBN/horus2/WeightedBp.h
Normal file
61
packages/CLPBN/horus2/WeightedBp.h
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
#ifndef HORUS_WEIGHTEDBP_H
|
||||||
|
#define HORUS_WEIGHTEDBP_H
|
||||||
|
|
||||||
|
#include "BeliefProp.h"
|
||||||
|
|
||||||
|
class WeightedLink : public BpLink
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
WeightedLink (FacNode* fn, VarNode* vn, size_t idx, unsigned weight)
|
||||||
|
: BpLink (fn, vn), index_(idx), weight_(weight),
|
||||||
|
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||||
|
|
||||||
|
size_t index (void) const { return index_; }
|
||||||
|
|
||||||
|
unsigned weight (void) const { return weight_; }
|
||||||
|
|
||||||
|
const Params& powMessage (void) const { return pwdMsg_; }
|
||||||
|
|
||||||
|
void updateMessage (void)
|
||||||
|
{
|
||||||
|
pwdMsg_ = *nextMsg_;
|
||||||
|
swap (currMsg_, nextMsg_);
|
||||||
|
LogAware::pow (pwdMsg_, weight_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t index_;
|
||||||
|
unsigned weight_;
|
||||||
|
Params pwdMsg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedBp : public BeliefProp
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
WeightedBp (const FactorGraph& fg,
|
||||||
|
const vector<vector<unsigned>>& weights)
|
||||||
|
: BeliefProp (fg), weights_(weights) { }
|
||||||
|
|
||||||
|
~WeightedBp (void);
|
||||||
|
|
||||||
|
Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
void createLinks (void);
|
||||||
|
|
||||||
|
void maxResidualSchedule (void);
|
||||||
|
|
||||||
|
void calcFactorToVarMsg (BpLink*);
|
||||||
|
|
||||||
|
Params getVarToFactorMsg (const BpLink*) const;
|
||||||
|
|
||||||
|
void printLinkInformation (void) const;
|
||||||
|
|
||||||
|
vector<vector<unsigned>> weights_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_WEIGHTEDBP_H
|
||||||
|
|
Reference in New Issue
Block a user