Merge branch 'master' of https://github.com/tacgomes/yap6.3
Conflicts: packages/CLPBN/clpbn/horus.yap
This commit is contained in:
@@ -1,77 +0,0 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
{
|
||||
assert (fg_.isFromBayesNetwork());
|
||||
|
||||
Scheduling scheduling;
|
||||
for (unsigned i = 0; i < queryIds.size(); i++) {
|
||||
assert (dag_.getNode (queryIds[i]));
|
||||
DAGraphNode* n = dag_.getNode (queryIds[i]);
|
||||
scheduling.push (ScheduleInfo (n, false, true));
|
||||
}
|
||||
|
||||
while (!scheduling.empty()) {
|
||||
ScheduleInfo& sch = scheduling.front();
|
||||
DAGraphNode* n = sch.node;
|
||||
n->setAsVisited();
|
||||
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||
if (n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
if (sch.visitedFromParent) {
|
||||
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
scheduling.pop();
|
||||
}
|
||||
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
constructGraph (fg);
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesBall::constructGraph (FactorGraph* fg) const
|
||||
{
|
||||
const FacNodes& facNodes = fg_.facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const DAGraphNode* n = dag_.getNode (
|
||||
facNodes[i]->factor().argument (0));
|
||||
if (n->isMarkedOnTop()) {
|
||||
fg->addFactor (Factor (facNodes[i]->factor()));
|
||||
} else if (n->hasEvidence() && n->isVisited()) {
|
||||
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||
Ranges ranges = { facNodes[i]->factor().range (0) };
|
||||
Params params (ranges[0], LogAware::noEvidence());
|
||||
params[n->getEvidence()] = LogAware::withEvidence();
|
||||
fg->addFactor (Factor (varIds, ranges, params));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,85 +0,0 @@
|
||||
#ifndef HORUS_BAYESBALL_H
|
||||
#define HORUS_BAYESBALL_H
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct ScheduleInfo
|
||||
{
|
||||
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
|
||||
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
|
||||
DAGraphNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
};
|
||||
|
||||
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
|
||||
|
||||
|
||||
class BayesBall
|
||||
{
|
||||
public:
|
||||
BayesBall (FactorGraph& fg)
|
||||
: fg_(fg) , dag_(fg.getStructure())
|
||||
{
|
||||
dag_.clear();
|
||||
}
|
||||
|
||||
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||
|
||||
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||
{
|
||||
BayesBall bb (fg);
|
||||
return bb.getMinimalFactorGraph (vids);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void constructGraph (FactorGraph* fg) const;
|
||||
|
||||
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
|
||||
|
||||
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
|
||||
|
||||
FactorGraph& fg_;
|
||||
|
||||
DAGraph& dag_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<DAGraphNode*>& ps = n->parents();
|
||||
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
|
||||
it != ps.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<DAGraphNode*>& cs = n->childs();
|
||||
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
|
||||
it != cs.end(); it++) {
|
||||
sch.push (ScheduleInfo (*it, true, false));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HORUS_BAYESBALL_H
|
||||
|
@@ -1,107 +0,0 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "BayesNet.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
void
|
||||
DAGraph::addNode (DAGraphNode* n)
|
||||
{
|
||||
assert (Util::contains (varMap_, n->varId()) == false);
|
||||
nodes_.push_back (n);
|
||||
varMap_[n->varId()] = n;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::addEdge (VarId vid1, VarId vid2)
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it1;
|
||||
unordered_map<VarId, DAGraphNode*>::iterator it2;
|
||||
it1 = varMap_.find (vid1);
|
||||
it2 = varMap_.find (vid2);
|
||||
assert (it1 != varMap_.end());
|
||||
assert (it2 != varMap_.end());
|
||||
it1->second->addChild (it2->second);
|
||||
it2->second->addParent (it1->second);
|
||||
}
|
||||
|
||||
|
||||
|
||||
const DAGraphNode*
|
||||
DAGraph::getNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
DAGraphNode*
|
||||
DAGraph::getNode (VarId vid)
|
||||
{
|
||||
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::setIndexes (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::clear (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->clear();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
DAGraph::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "DAGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
out << "ranksep=1" << endl;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
out << nodes_[i]->varId() ;
|
||||
out << " [" ;
|
||||
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << ",style=filled, fillcolor=yellow" ;
|
||||
}
|
||||
out << "]" << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||
out << " [style=bold]" << endl ;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
@@ -1,88 +0,0 @@
|
||||
#ifndef HORUS_BAYESNET_H
|
||||
#define HORUS_BAYESNET_H
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class Var;
|
||||
|
||||
class DAGraphNode : public Var
|
||||
{
|
||||
public:
|
||||
DAGraphNode (Var* v) : Var (v) , visited_(false),
|
||||
markedOnTop_(false), markedOnBottom_(false) { }
|
||||
|
||||
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
||||
|
||||
vector<DAGraphNode*>& childs (void) { return childs_; }
|
||||
|
||||
const vector<DAGraphNode*>& parents (void) const { return parents_; }
|
||||
|
||||
vector<DAGraphNode*>& parents (void) { return parents_; }
|
||||
|
||||
void addParent (DAGraphNode* p) { parents_.push_back (p); }
|
||||
|
||||
void addChild (DAGraphNode* c) { childs_.push_back (c); }
|
||||
|
||||
bool isVisited (void) const { return visited_; }
|
||||
|
||||
void setAsVisited (void) { visited_ = true; }
|
||||
|
||||
bool isMarkedOnTop (void) const { return markedOnTop_; }
|
||||
|
||||
void markOnTop (void) { markedOnTop_ = true; }
|
||||
|
||||
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
|
||||
|
||||
void markOnBottom (void) { markedOnBottom_ = true; }
|
||||
|
||||
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
|
||||
|
||||
private:
|
||||
bool visited_;
|
||||
bool markedOnTop_;
|
||||
bool markedOnBottom_;
|
||||
|
||||
vector<DAGraphNode*> childs_;
|
||||
vector<DAGraphNode*> parents_;
|
||||
};
|
||||
|
||||
|
||||
class DAGraph
|
||||
{
|
||||
public:
|
||||
DAGraph (void) { }
|
||||
|
||||
void addNode (DAGraphNode* n);
|
||||
|
||||
void addEdge (VarId vid1, VarId vid2);
|
||||
|
||||
const DAGraphNode* getNode (VarId vid) const;
|
||||
|
||||
DAGraphNode* getNode (VarId vid);
|
||||
|
||||
bool empty (void) const { return nodes_.empty(); }
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void clear (void);
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
vector<DAGraphNode*> nodes_;
|
||||
|
||||
unordered_map<VarId, DAGraphNode*> varMap_;
|
||||
};
|
||||
|
||||
#endif // HORUS_BAYESNET_H
|
||||
|
@@ -1,493 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "BpSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||
{
|
||||
fg_ = &fg;
|
||||
runned_ = false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BpSolver::~BpSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::solveQuery (VarIds queryVids)
|
||||
{
|
||||
assert (queryVids.empty() == false);
|
||||
if (queryVids.size() == 1) {
|
||||
return getPosterioriOf (queryVids[0]);
|
||||
} else {
|
||||
return getJointDistributionOf (queryVids);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
assert (fg_->getVarNode (vid));
|
||||
VarNode* var = fg_->getVarNode (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::add (probs, links[i]->getMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::multiply (probs, links[i]->getMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
int idx = -1;
|
||||
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
|
||||
const FacNodes& facNodes = vn->neighbors();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
if (facNodes[i]->factor().contains (jointVarIds)) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx == -1) {
|
||||
return getJointByConditioning (jointVarIds);
|
||||
} else {
|
||||
Factor res (facNodes[idx]->factor());
|
||||
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Factor msg ({links[i]->getVariable()->varId()},
|
||||
{links[i]->getVariable()->range()},
|
||||
getVar2FactorMsg (links[i]));
|
||||
res.multiply (msg);
|
||||
}
|
||||
res.sumOutAllExcept (jointVarIds);
|
||||
res.reorderArguments (jointVarIds);
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (jointDist);
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
nIters_ ++;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||
// cout << endl;
|
||||
}
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
unsigned size = fg_->varNodes().size();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = fg_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
runned_ = true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::createLinks (void)
|
||||
{
|
||||
const FacNodes& facNodes = fg_->facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
updateMessage (link);
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->getFactor()) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
calculateMessage (links[j]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printDashedLine();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||
{
|
||||
FacNode* src = link->getFactor();
|
||||
const VarNode* dst = link->getVariable();
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned msgSize = 1;
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
msgSize *= links[i]->getVariable()->range();
|
||||
}
|
||||
unsigned repetitions = 1;
|
||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||
if (Globals::logDomain) {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " << endl;
|
||||
}
|
||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->range();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->range();
|
||||
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
result.multiply (src->factor());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " marginalized: " ;
|
||||
cout << result.params() << endl;
|
||||
}
|
||||
const Params& resultParams = result.params();
|
||||
Params& message = link->getNextMessage();
|
||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
||||
message[i] = resultParams[i];
|
||||
}
|
||||
LogAware::normalize (message);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " curr msg: " << link->getMessage() << endl;
|
||||
cout << " next msg: " << message << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
{
|
||||
const VarNode* src = link->getVariable();
|
||||
const FacNode* dst = link->getFactor();
|
||||
Params msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->range(), LogAware::one());
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
const SpLinkSet& links = ninf (src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::add (msg, links[i]->getMessage());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::multiply (msg, links[i]->getMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " x " << links[i]->getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
VarNodes jointVars;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
assert (fg_->getVarNode (jointVarIds[i]));
|
||||
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
FactorGraph* fg = new FactorGraph (*fg_);
|
||||
BpSolver solver (*fg);
|
||||
solver.runSolver();
|
||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
|
||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
Params newBeliefs;
|
||||
Vars observedVars;
|
||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
||||
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
||||
}
|
||||
StatesIndexer idx (observedVars, false);
|
||||
while (idx.valid()) {
|
||||
for (unsigned j = 0; j < observedVars.size(); j++) {
|
||||
observedVars[j]->setEvidence (idx[j]);
|
||||
}
|
||||
++ idx;
|
||||
BpSolver solver (*fg);
|
||||
solver.runSolver();
|
||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->range() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVids.push_back (jointVars[i]->varId());
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::initializeSolver (void)
|
||||
{
|
||||
const VarNodes& varNodes = fg_->varNodes();
|
||||
varsI_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
const FacNodes& facNodes = fg_->facNodes();
|
||||
facsI_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
createLinks();
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
FacNode* src = links_[i]->getFactor();
|
||||
VarNode* dst = links_[i]->getVariable();
|
||||
ninf (dst)->addSpLink (links_[i]);
|
||||
ninf (src)->addSpLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpSolver::converged (void)
|
||||
{
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ <= 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
if (Constants::DEBUG == 0) break;
|
||||
}
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
SpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << l->getMessage() << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << l->getNextMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
@@ -1,188 +0,0 @@
|
||||
#ifndef HORUS_BPSOLVER_H
|
||||
#define HORUS_BPSOLVER_H
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "Factor.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Util.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class SpLink
|
||||
{
|
||||
public:
|
||||
SpLink (FacNode* fn, VarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
virtual ~SpLink (void) { };
|
||||
|
||||
FacNode* getFactor (void) const { return fac_; }
|
||||
|
||||
VarNode* getVariable (void) const { return var_; }
|
||||
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
|
||||
Params& getNextMessage (void) { return *nextMsg_; }
|
||||
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
|
||||
double getResidual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_,v2_);
|
||||
}
|
||||
|
||||
virtual void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
FacNode* fac_;
|
||||
VarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
typedef vector<SpLink*> SpLinkSet;
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||
const SpLinkSet& getLinks (void) { return links_; }
|
||||
private:
|
||||
SpLinkSet links_;
|
||||
};
|
||||
|
||||
|
||||
class BpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
BpSolver (const FactorGraph&);
|
||||
|
||||
virtual ~BpSolver (void);
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
|
||||
virtual Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
protected:
|
||||
void runSolver (void);
|
||||
|
||||
virtual void createLinks (void);
|
||||
|
||||
virtual void maxResidualSchedule (void);
|
||||
|
||||
virtual void calculateFactor2VariableMsg (SpLink*);
|
||||
|
||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
}
|
||||
|
||||
SPNodeInfo* ninf (const FacNode* fac) const
|
||||
{
|
||||
return facsI_[fac->getIndex()];
|
||||
}
|
||||
|
||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
void updateMessage (SpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
struct CompareResidual
|
||||
{
|
||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
||||
{
|
||||
return link1->getResidual() > link2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
SpLinkSet links_;
|
||||
unsigned nIters_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
bool runned_;
|
||||
const FactorGraph* fg_;
|
||||
|
||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
||||
SpLinkMap linkMap_;
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
|
||||
bool converged (void);
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
};
|
||||
|
||||
#endif // HORUS_BPSOLVER_H
|
||||
|
@@ -1,339 +0,0 @@
|
||||
|
||||
#include "CFactorGraph.h"
|
||||
#include "Factor.h"
|
||||
|
||||
|
||||
bool CFactorGraph::checkForIdenticalFactors = true;
|
||||
|
||||
CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
groundFg_ = &fg;
|
||||
freeColor_ = 0;
|
||||
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
varSignatures_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
||||
varSignatures_.push_back (Signature (c));
|
||||
}
|
||||
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
facSignatures_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||
facSignatures_.push_back (Signature (c));
|
||||
}
|
||||
|
||||
varColors_.resize (varNodes.size());
|
||||
facColors_.resize (facNodes.size());
|
||||
setInitialColors();
|
||||
createGroups();
|
||||
}
|
||||
|
||||
|
||||
|
||||
CFactorGraph::~CFactorGraph (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
delete varClusters_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||
delete facClusters_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::setInitialColors (void)
|
||||
{
|
||||
// create the initial variable colors
|
||||
VarColorMap colorMap;
|
||||
const VarNodes& varNodes = groundFg_->varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned dsize = varNodes[i]->range();
|
||||
VarColorMap::iterator it = colorMap.find (dsize);
|
||||
if (it == colorMap.end()) {
|
||||
it = colorMap.insert (make_pair (
|
||||
dsize, vector<Color> (dsize+1,-1))).first;
|
||||
}
|
||||
unsigned idx;
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
idx = varNodes[i]->getEvidence();
|
||||
} else {
|
||||
idx = dsize;
|
||||
}
|
||||
vector<Color>& stateColors = it->second;
|
||||
if (stateColors[idx] == -1) {
|
||||
stateColors[idx] = getFreeColor();
|
||||
}
|
||||
setColor (varNodes[i], stateColors[idx]);
|
||||
}
|
||||
|
||||
const FacNodes& facNodes = groundFg_->facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
||||
}
|
||||
// FIXME FIXME FIXME : pfl should give correct dist ids.
|
||||
if (checkForIdenticalFactors || true) {
|
||||
unsigned groupCount = 1;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
Factor& f1 = facNodes[i]->factor();
|
||||
if (f1.distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
f1.setDistId (groupCount);
|
||||
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
||||
Factor& f2 = facNodes[j]->factor();
|
||||
if (f2.distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
if (f1.size() == f2.size() &&
|
||||
f1.ranges() == f2.ranges() &&
|
||||
f1.params() == f2.params()) {
|
||||
f2.setDistId (groupCount);
|
||||
}
|
||||
}
|
||||
groupCount ++;
|
||||
}
|
||||
}
|
||||
// create the initial factor colors
|
||||
DistColorMap distColors;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned distId = facNodes[i]->factor().distId();
|
||||
DistColorMap::iterator it = distColors.find (distId);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
||||
}
|
||||
setColor (facNodes[i], it->second);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::createGroups (void)
|
||||
{
|
||||
VarSignMap varGroups;
|
||||
FacSignMap facGroups;
|
||||
unsigned nIters = 0;
|
||||
bool groupsHaveChanged = true;
|
||||
const VarNodes& varNodes = groundFg_->varNodes();
|
||||
const FacNodes& facNodes = groundFg_->facNodes();
|
||||
|
||||
while (groupsHaveChanged || nIters == 1) {
|
||||
nIters ++;
|
||||
|
||||
unsigned prevFactorGroupsSize = facGroups.size();
|
||||
facGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const Signature& signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = facGroups.find (signature);
|
||||
if (it == facGroups.end()) {
|
||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
for (FacSignMap::iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FacNodes& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
// set a new color to the variables with the same signature
|
||||
unsigned prevVarGroupsSize = varGroups.size();
|
||||
varGroups.clear();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
const Signature& signature = getSignature (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signature);
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||
}
|
||||
it->second.push_back (varNodes[i]);
|
||||
}
|
||||
for (VarSignMap::iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
VarNodes& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||
|| prevFactorGroupsSize != facGroups.size();
|
||||
}
|
||||
printGroups (varGroups, facGroups);
|
||||
createClusters (varGroups, facGroups);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::createClusters (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& facGroups)
|
||||
{
|
||||
varClusters_.reserve (varGroups.size());
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const VarNodes& groupVars = it->second;
|
||||
VarCluster* vc = new VarCluster (groupVars);
|
||||
for (unsigned i = 0; i < groupVars.size(); i++) {
|
||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||
}
|
||||
varClusters_.push_back (vc);
|
||||
}
|
||||
|
||||
facClusters_.reserve (facGroups.size());
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
FacNode* groupFactor = it->second[0];
|
||||
const VarNodes& neighs = groupFactor->neighbors();
|
||||
VarClusters varClusters;
|
||||
varClusters.reserve (neighs.size());
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
VarId vid = neighs[i]->varId();
|
||||
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
||||
}
|
||||
facClusters_.push_back (new FacCluster (it->second, varClusters));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
const Signature&
|
||||
CFactorGraph::getSignature (const VarNode* varNode)
|
||||
{
|
||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
||||
vector<Color>::iterator it = sign.colors.begin();
|
||||
const FacNodes& neighs = varNode->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
*it = neighs[i]->factor().indexOf (varNode->varId());
|
||||
it ++;
|
||||
}
|
||||
*it = getColor (varNode);
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const Signature&
|
||||
CFactorGraph::getSignature (const FacNode* facNode)
|
||||
{
|
||||
Signature& sign = facSignatures_[facNode->getIndex()];
|
||||
vector<Color>::iterator it = sign.colors.begin();
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
}
|
||||
*it = getColor (facNode);
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CFactorGraph::getGroundFactorGraph (void) const
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
||||
VarNode* newVar = new VarNode (var);
|
||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||
fg->addVarNode (newVar);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
|
||||
Vars myGroundVars;
|
||||
myGroundVars.reserve (myVarClusters.size());
|
||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||
myGroundVars.push_back (v);
|
||||
}
|
||||
FacNode* fn = new FacNode (Factor (myGroundVars,
|
||||
facClusters_[i]->getGroundFactors()[0]->factor().params()));
|
||||
facClusters_[i]->setRepresentativeFactor (fn);
|
||||
fg->addFacNode (fn);
|
||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
|
||||
}
|
||||
}
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
CFactorGraph::getEdgeCount (
|
||||
const FacCluster* fc,
|
||||
const VarCluster* vc) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
VarId vid = vc->getGroundVarNodes().front()->varId();
|
||||
const FacNodes& clusterGroundFactors = fc->getGroundFactors();
|
||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
if (clusterGroundFactors[i]->factor().contains (vid)) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
// CVarNodes vars = vc->getGroundVarNodes();
|
||||
// for (unsigned i = 1; i < vars.size(); i++) {
|
||||
// VarNode* var = vc->getGroundVarNodes()[i];
|
||||
// unsigned count2 = 0;
|
||||
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
||||
// count2 ++;
|
||||
// }
|
||||
// }
|
||||
// if (count != count2) { cout << "oops!" << endl; abort(); }
|
||||
// }
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::printGroups (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& facGroups) const
|
||||
{
|
||||
unsigned count = 1;
|
||||
cout << "variable groups:" << endl;
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const VarNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << count << ": " ;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->label() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
count = 1;
|
||||
cout << endl << "factor groups:" << endl;
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); it++) {
|
||||
const FacNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->getLabel() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,247 +0,0 @@
|
||||
#ifndef HORUS_CFACTORGRAPH_H
|
||||
#define HORUS_CFACTORGRAPH_H
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Horus.h"
|
||||
|
||||
class VarCluster;
|
||||
class FacCluster;
|
||||
class Distribution;
|
||||
class Signature;
|
||||
|
||||
class SignatureHash;
|
||||
|
||||
|
||||
typedef long Color;
|
||||
|
||||
typedef unordered_map<unsigned, vector<Color>> VarColorMap;
|
||||
|
||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||
|
||||
typedef vector<VarCluster*> VarClusters;
|
||||
typedef vector<FacCluster*> FacClusters;
|
||||
|
||||
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
||||
|
||||
|
||||
|
||||
struct Signature
|
||||
{
|
||||
Signature (unsigned size) : colors(size) { }
|
||||
|
||||
bool operator< (const Signature& sig) const
|
||||
{
|
||||
if (colors.size() < sig.colors.size()) {
|
||||
return true;
|
||||
} else if (colors.size() > sig.colors.size()) {
|
||||
return false;
|
||||
} else {
|
||||
for (unsigned i = 0; i < colors.size(); i++) {
|
||||
if (colors[i] < sig.colors[i]) {
|
||||
return true;
|
||||
} else if (colors[i] > sig.colors[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator== (const Signature& sig) const
|
||||
{
|
||||
if (colors.size() != sig.colors.size()) {
|
||||
return false;
|
||||
}
|
||||
for (unsigned i = 0; i < colors.size(); i++) {
|
||||
if (colors[i] != sig.colors[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
vector<Color> colors;
|
||||
};
|
||||
|
||||
|
||||
|
||||
struct SignatureHash
|
||||
{
|
||||
size_t operator() (const Signature &sig) const
|
||||
{
|
||||
size_t val = hash<size_t>()(sig.colors.size());
|
||||
for (unsigned i = 0; i < sig.colors.size(); i++) {
|
||||
val ^= hash<size_t>()(sig.colors[i]);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
class VarCluster
|
||||
{
|
||||
public:
|
||||
VarCluster (const VarNodes& vs)
|
||||
{
|
||||
for (unsigned i = 0; i < vs.size(); i++) {
|
||||
groundVars_.push_back (vs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void addFacCluster (FacCluster* fc)
|
||||
{
|
||||
facClusters_.push_back (fc);
|
||||
}
|
||||
|
||||
const FacClusters& getFacClusters (void) const
|
||||
{
|
||||
return facClusters_;
|
||||
}
|
||||
|
||||
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||
|
||||
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
||||
|
||||
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
||||
|
||||
private:
|
||||
VarNodes groundVars_;
|
||||
FacClusters facClusters_;
|
||||
VarNode* representVar_;
|
||||
};
|
||||
|
||||
|
||||
class FacCluster
|
||||
{
|
||||
public:
|
||||
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
|
||||
{
|
||||
groundFactors_ = groundFactors;
|
||||
varClusters_ = vcs;
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
varClusters_[i]->addFacCluster (this);
|
||||
}
|
||||
}
|
||||
|
||||
const VarClusters& getVarClusters (void) const
|
||||
{
|
||||
return varClusters_;
|
||||
}
|
||||
|
||||
bool containsGround (const FacNode* fn)
|
||||
{
|
||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||
if (groundFactors_[i] == fn) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
FacNode* getRepresentativeFactor (void) const
|
||||
{
|
||||
return representFactor_;
|
||||
}
|
||||
|
||||
void setRepresentativeFactor (FacNode* fn)
|
||||
{
|
||||
representFactor_ = fn;
|
||||
}
|
||||
|
||||
const FacNodes& getGroundFactors (void) const
|
||||
{
|
||||
return groundFactors_;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
FacNodes groundFactors_;
|
||||
VarClusters varClusters_;
|
||||
FacNode* representFactor_;
|
||||
};
|
||||
|
||||
|
||||
class CFactorGraph
|
||||
{
|
||||
public:
|
||||
CFactorGraph (const FactorGraph&);
|
||||
|
||||
~CFactorGraph (void);
|
||||
|
||||
const VarClusters& getVarClusters (void) { return varClusters_; }
|
||||
|
||||
const FacClusters& getFacClusters (void) { return facClusters_; }
|
||||
|
||||
VarNode* getEquivalentVariable (VarId vid)
|
||||
{
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->getRepresentativeVariable();
|
||||
}
|
||||
|
||||
FactorGraph* getGroundFactorGraph (void) const;
|
||||
|
||||
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||
|
||||
static bool checkForIdenticalFactors;
|
||||
|
||||
private:
|
||||
Color getFreeColor (void)
|
||||
{
|
||||
++ freeColor_;
|
||||
return freeColor_ - 1;
|
||||
}
|
||||
|
||||
Color getColor (const VarNode* vn) const
|
||||
{
|
||||
return varColors_[vn->getIndex()];
|
||||
}
|
||||
Color getColor (const FacNode* fn) const {
|
||||
return facColors_[fn->getIndex()];
|
||||
}
|
||||
|
||||
void setColor (const VarNode* vn, Color c)
|
||||
{
|
||||
varColors_[vn->getIndex()] = c;
|
||||
}
|
||||
|
||||
void setColor (const FacNode* fn, Color c)
|
||||
{
|
||||
facColors_[fn->getIndex()] = c;
|
||||
}
|
||||
|
||||
VarCluster* getVariableCluster (VarId vid) const
|
||||
{
|
||||
return vid2VarCluster_.find (vid)->second;
|
||||
}
|
||||
|
||||
void setInitialColors (void);
|
||||
|
||||
void createGroups (void);
|
||||
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
|
||||
const Signature& getSignature (const VarNode*);
|
||||
|
||||
const Signature& getSignature (const FacNode*);
|
||||
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
Color freeColor_;
|
||||
vector<Color> varColors_;
|
||||
vector<Color> facColors_;
|
||||
vector<Signature> varSignatures_;
|
||||
vector<Signature> facSignatures_;
|
||||
VarClusters varClusters_;
|
||||
FacClusters facClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* groundFg_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CFACTORGRAPH_H
|
||||
|
@@ -1,244 +0,0 @@
|
||||
#include "CbpSolver.h"
|
||||
|
||||
|
||||
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||
{
|
||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||
if (Constants::COLLECT_STATS) {
|
||||
nGroundVars = fg_->varNodes().size();
|
||||
nGroundFacs = fg_->facNodes().size();
|
||||
const VarNodes& vars = fg_->varNodes();
|
||||
nWithoutNeighs = 0;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
const FacNodes& factors = vars[i]->neighbors();
|
||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||
nWithoutNeighs ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
cfg_ = new CFactorGraph (fg);
|
||||
fg_ = cfg_->getGroundFactorGraph();
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nClusterVars = fg_->varNodes().size();
|
||||
unsigned nClusterFacs = fg_->facNodes().size();
|
||||
Statistics::updateCompressingStatistics (nGroundVars,
|
||||
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
||||
}
|
||||
Util::printHeader ("Uncompressed Factor Graph");
|
||||
fg.print();
|
||||
Util::printHeader ("Compressed Factor Graph");
|
||||
fg_->print();
|
||||
}
|
||||
|
||||
|
||||
|
||||
CbpSolver::~CbpSolver (void)
|
||||
{
|
||||
delete cfg_;
|
||||
delete fg_;
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
links_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
assert (cfg_->getEquivalentVariable (vid));
|
||||
VarNode* var = cfg_->getEquivalentVariable (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->range(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (probs, l->poweredMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (probs, l->poweredMessage());
|
||||
}
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||
{
|
||||
VarIds eqVarIds;
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
|
||||
eqVarIds.push_back (vn->varId());
|
||||
}
|
||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::createLinks (void)
|
||||
{
|
||||
const FacClusters& fcs = cfg_->getFacClusters();
|
||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||
const VarClusters& vcs = fcs[i]->getVarClusters();
|
||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
|
||||
links_.push_back (new CbpSolverLink (
|
||||
fcs[i]->getRepresentativeFactor(),
|
||||
vcs[j]->getRepresentativeVariable(), c));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[j]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// in counting bp, the message that a variable X sends to
|
||||
// to a factor F depends on the message that F sent to the X
|
||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getVariable() != link->getVariable()) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
cout << " calculating " << links[i]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[i]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[i]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
{
|
||||
Params msg;
|
||||
const VarNode* src = link->getVariable();
|
||||
const FacNode* dst = link->getFactor();
|
||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->range(), LogAware::noEvidence());
|
||||
double value = link->getMessage()[src->getEvidence()];
|
||||
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
|
||||
} else {
|
||||
msg = link->getMessage();
|
||||
LogAware::pow (msg, l->nrEdges() - 1);
|
||||
}
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " " << "init: " << msg << endl;
|
||||
}
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (msg, l->poweredMessage());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (msg, l->poweredMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||
cout << l->poweredMessage() << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " result = " << msg << endl;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " << l->getMessage() << endl;
|
||||
cout << " next msg = " << l->getNextMessage() << endl;
|
||||
cout << " powered = " << l->poweredMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
@@ -1,60 +0,0 @@
|
||||
#ifndef HORUS_CBP_H
|
||||
#define HORUS_CBP_H
|
||||
|
||||
#include "BpSolver.h"
|
||||
#include "CFactorGraph.h"
|
||||
|
||||
class Factor;
|
||||
|
||||
class CbpSolverLink : public SpLink
|
||||
{
|
||||
public:
|
||||
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
|
||||
: SpLink (fn, vn), nrEdges_(c),
|
||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||
|
||||
unsigned nrEdges (void) const { return nrEdges_; }
|
||||
|
||||
const Params& poweredMessage (void) const { return pwdMsg_; }
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
pwdMsg_ = *nextMsg_;
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
LogAware::pow (pwdMsg_, nrEdges_);
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned nrEdges_;
|
||||
Params pwdMsg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CbpSolver : public BpSolver
|
||||
{
|
||||
public:
|
||||
CbpSolver (const FactorGraph& fg);
|
||||
|
||||
~CbpSolver (void);
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
private:
|
||||
|
||||
void createLinks (void);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
|
||||
Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
void printLinkInformation (void) const;
|
||||
|
||||
CFactorGraph* cfg_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CBP_H
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -1,204 +0,0 @@
|
||||
#ifndef HORUS_CONSTRAINTTREE_H
|
||||
#define HORUS_CONSTRAINTTREE_H
|
||||
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "TinySet.h"
|
||||
#include "LiftedUtils.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class CTNode;
|
||||
typedef vector<CTNode*> CTNodes;
|
||||
|
||||
class ConstraintTree;
|
||||
typedef vector<ConstraintTree*> ConstraintTrees;
|
||||
|
||||
|
||||
|
||||
class CTNode
|
||||
{
|
||||
public:
|
||||
CTNode (const CTNode& n) : symbol_(n.symbol()), level_(n.level()) { }
|
||||
|
||||
CTNode (Symbol s, unsigned l) : symbol_(s) , level_(l) { }
|
||||
|
||||
unsigned level (void) const { return level_; }
|
||||
|
||||
void setLevel (unsigned level) { level_ = level; }
|
||||
|
||||
Symbol symbol (void) const { return symbol_; }
|
||||
|
||||
void setSymbol (const Symbol s) { symbol_ = s; }
|
||||
|
||||
CTNodes& childs (void) { return childs_; }
|
||||
|
||||
const CTNodes& childs (void) const { return childs_; }
|
||||
|
||||
unsigned nrChilds (void) const { return childs_.size(); }
|
||||
|
||||
bool isRoot (void) const { return level_ == 0; }
|
||||
|
||||
bool isLeaf (void) const { return childs_.empty(); }
|
||||
|
||||
void addChild (CTNode*, bool = true);
|
||||
|
||||
void removeChild (CTNode*);
|
||||
|
||||
void removeChilds (void);
|
||||
|
||||
void removeAndDeleteChild (CTNode*);
|
||||
|
||||
void removeAndDeleteAllChilds (void);
|
||||
|
||||
SymbolSet childSymbols (void) const;
|
||||
|
||||
static CTNode* copySubtree (const CTNode*);
|
||||
|
||||
static void deleteSubtree (CTNode*);
|
||||
|
||||
private:
|
||||
void updateChildLevels (CTNode*, unsigned);
|
||||
|
||||
Symbol symbol_;
|
||||
CTNodes childs_;
|
||||
unsigned level_;
|
||||
};
|
||||
|
||||
ostream& operator<< (ostream &out, const CTNode&);
|
||||
|
||||
|
||||
class ConstraintTree
|
||||
{
|
||||
public:
|
||||
ConstraintTree (unsigned);
|
||||
|
||||
ConstraintTree (const LogVars&);
|
||||
|
||||
ConstraintTree (const LogVars&, const Tuples&);
|
||||
|
||||
ConstraintTree (const ConstraintTree&);
|
||||
|
||||
~ConstraintTree (void);
|
||||
|
||||
CTNode* root (void) const { return root_; }
|
||||
|
||||
bool empty (void) const { return root_->childs().empty(); }
|
||||
|
||||
const LogVars& logVars (void) const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVars_;
|
||||
}
|
||||
|
||||
const LogVarSet& logVarSet (void) const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVarSet_;
|
||||
}
|
||||
|
||||
unsigned nrLogVars (void) const
|
||||
{
|
||||
return logVars_.size();
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
}
|
||||
|
||||
void addTuple (const Tuple&);
|
||||
|
||||
bool containsTuple (const Tuple&);
|
||||
|
||||
void moveToTop (const LogVars&);
|
||||
|
||||
void moveToBottom (const LogVars&);
|
||||
|
||||
void join (ConstraintTree*, bool = false);
|
||||
|
||||
unsigned getLevel (LogVar) const;
|
||||
|
||||
void rename (LogVar, LogVar);
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
void project (const LogVarSet&);
|
||||
|
||||
void remove (const LogVarSet&);
|
||||
|
||||
bool isSingleton (LogVar);
|
||||
|
||||
LogVarSet singletons (void);
|
||||
|
||||
TupleSet tupleSet (unsigned = 0) const;
|
||||
|
||||
TupleSet tupleSet (const LogVars&);
|
||||
|
||||
unsigned size (void) const;
|
||||
|
||||
unsigned nrSymbols (LogVar);
|
||||
|
||||
void exportToGraphViz (const char*, bool = false) const;
|
||||
|
||||
bool isCountNormalized (const LogVarSet&);
|
||||
|
||||
unsigned getConditionalCount (const LogVarSet&);
|
||||
|
||||
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
||||
|
||||
bool isCarteesianProduct (const LogVarSet&) const;
|
||||
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||
const Tuple&, unsigned);
|
||||
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||
const ConstraintTree*, unsigned) const;
|
||||
|
||||
ConstraintTrees countNormalize (const LogVarSet&);
|
||||
|
||||
ConstraintTrees jointCountNormalize (
|
||||
ConstraintTree*, ConstraintTree*, LogVar, LogVar, LogVar);
|
||||
|
||||
static bool identical (
|
||||
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||
|
||||
static bool overlap (
|
||||
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||
|
||||
LogVars expand (LogVar);
|
||||
ConstraintTrees ground (LogVar);
|
||||
|
||||
private:
|
||||
unsigned countTuples (const CTNode*) const;
|
||||
|
||||
CTNodes getNodesBelow (CTNode*) const;
|
||||
|
||||
CTNodes getNodesAtLevel (unsigned) const;
|
||||
|
||||
void swapLogVar (LogVar);
|
||||
|
||||
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
|
||||
|
||||
bool indenticalSubtrees (
|
||||
const CTNode*, const CTNode*, bool) const;
|
||||
|
||||
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
|
||||
|
||||
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||
const CTNode*, unsigned);
|
||||
|
||||
static void split (
|
||||
CTNode*, CTNode*, CTNodes&, unsigned);
|
||||
|
||||
static bool overlap (const CTNode*, const CTNode*, unsigned);
|
||||
|
||||
CTNode* root_;
|
||||
LogVars logVars_;
|
||||
LogVarSet logVarSet_;
|
||||
};
|
||||
|
||||
|
||||
#endif // HORUS_CONSTRAINTTREE_H
|
||||
|
@@ -1,303 +0,0 @@
|
||||
#include <limits>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "ElimGraph.h"
|
||||
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||
|
||||
|
||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
{
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
const VarIds& vids = factors[i]->arguments();
|
||||
for (unsigned j = 0; j < vids.size() - 1; j++) {
|
||||
EgNode* n1 = getEgNode (vids[j]);
|
||||
if (n1 == 0) {
|
||||
n1 = new EgNode (vids[j], factors[i]->range (j));
|
||||
addNode (n1);
|
||||
}
|
||||
for (unsigned k = j + 1; k < vids.size(); k++) {
|
||||
EgNode* n2 = getEgNode (vids[k]);
|
||||
if (n2 == 0) {
|
||||
n2 = new EgNode (vids[k], factors[i]->range (k));
|
||||
addNode (n2);
|
||||
}
|
||||
if (neighbors (n1, n2) == false) {
|
||||
addEdge (n1, n2);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (vids.size() == 1) {
|
||||
if (getEgNode (vids[0]) == 0) {
|
||||
addNode (new EgNode (vids[0], factors[i]->range (0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ElimGraph::~ElimGraph (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
delete nodes_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||
{
|
||||
VarIds elimOrder;
|
||||
marked_.resize (nodes_.size(), false);
|
||||
for (unsigned i = 0; i < exclude.size(); i++) {
|
||||
assert (getEgNode (exclude[i]));
|
||||
EgNode* node = getEgNode (exclude[i]);
|
||||
marked_[*node] = true;
|
||||
}
|
||||
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
||||
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
||||
EgNode* node = getLowestCostNode();
|
||||
marked_[*node] = true;
|
||||
elimOrder.push_back (node->varId());
|
||||
connectAllNeighbors (node);
|
||||
}
|
||||
return elimOrder;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::print (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::exportToGraphViz (
|
||||
const char* fileName,
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Markov::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "strict graph {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||
out << '"' << nodes_[i]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
ElimGraph::getEliminationOrder (
|
||||
const vector<Factor*> factors,
|
||||
VarIds excludedVids)
|
||||
{
|
||||
ElimGraph graph (factors);
|
||||
// graph.print();
|
||||
// graph.exportToGraphViz ("_egg.dot");
|
||||
return graph.getEliminatingOrder (excludedVids);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
n->setIndex (nodes_.size() - 1);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, EgNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return (it != varMap_.end()) ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getLowestCostNode (void) const
|
||||
{
|
||||
EgNode* bestNode = 0;
|
||||
unsigned minCost = std::numeric_limits<unsigned>::max();
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (marked_[i]) continue;
|
||||
unsigned cost = 0;
|
||||
switch (elimHeuristic_) {
|
||||
case MIN_NEIGHBORS:
|
||||
cost = getNeighborsCost (nodes_[i]);
|
||||
break;
|
||||
case MIN_WEIGHT:
|
||||
cost = getWeightCost (nodes_[i]);
|
||||
break;
|
||||
case MIN_FILL:
|
||||
cost = getFillCost (nodes_[i]);
|
||||
break;
|
||||
case WEIGHTED_MIN_FILL:
|
||||
cost = getWeightedFillCost (nodes_[i]);
|
||||
break;
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
if (cost < minCost) {
|
||||
bestNode = nodes_[i];
|
||||
minCost = cost;
|
||||
}
|
||||
}
|
||||
assert (bestNode);
|
||||
return bestNode;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ElimGraph::getNeighborsCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
if (marked_[*neighs[i]] == false) {
|
||||
cost ++;
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ElimGraph::getWeightCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 1;
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
if (marked_[*neighs[i]] == false) {
|
||||
cost *= neighs[i]->range();
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ElimGraph::getFillCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
if (marked_[*neighs[i]] == true) continue;
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (marked_[*neighs[j]] == true) continue;
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
cost ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ElimGraph::getWeightedFillCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
if (marked_[*neighs[i]] == true) continue;
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (marked_[*neighs[j]] == true) continue;
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
cost += neighs[i]->range() * neighs[j]->range();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::connectAllNeighbors (const EgNode* n)
|
||||
{
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
if (marked_[*neighs[i]] == true) continue;
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (marked_[*neighs[j]] == true) continue;
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
addEdge (neighs[i], neighs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
|
||||
{
|
||||
const vector<EgNode*>& neighs = n1->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
if (neighs[i] == n2) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@@ -1,88 +0,0 @@
|
||||
#ifndef HORUS_ELIMGRAPH_H
|
||||
#define HORUS_ELIMGRAPH_H
|
||||
|
||||
#include "unordered_map"
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
enum ElimHeuristic
|
||||
{
|
||||
MIN_NEIGHBORS,
|
||||
MIN_WEIGHT,
|
||||
MIN_FILL,
|
||||
WEIGHTED_MIN_FILL
|
||||
};
|
||||
|
||||
|
||||
class EgNode : public Var
|
||||
{
|
||||
public:
|
||||
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||
|
||||
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
||||
|
||||
const vector<EgNode*>& neighbors (void) const { return neighs_; }
|
||||
|
||||
private:
|
||||
vector<EgNode*> neighs_;
|
||||
};
|
||||
|
||||
|
||||
class ElimGraph
|
||||
{
|
||||
public:
|
||||
ElimGraph (const vector<Factor*>&); // TODO
|
||||
|
||||
~ElimGraph (void);
|
||||
|
||||
VarIds getEliminatingOrder (const VarIds&);
|
||||
|
||||
void print (void) const;
|
||||
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
|
||||
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
|
||||
|
||||
static void setEliminationHeuristic (ElimHeuristic h)
|
||||
{
|
||||
elimHeuristic_ = h;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void addEdge (EgNode* n1, EgNode* n2)
|
||||
{
|
||||
assert (n1 != n2);
|
||||
n1->addNeighbor (n2);
|
||||
n2->addNeighbor (n1);
|
||||
}
|
||||
|
||||
void addNode (EgNode*);
|
||||
|
||||
EgNode* getEgNode (VarId) const;
|
||||
EgNode* getLowestCostNode (void) const;
|
||||
|
||||
unsigned getNeighborsCost (const EgNode*) const;
|
||||
|
||||
unsigned getWeightCost (const EgNode*) const;
|
||||
|
||||
unsigned getFillCost (const EgNode*) const;
|
||||
|
||||
unsigned getWeightedFillCost (const EgNode*) const;
|
||||
|
||||
void connectAllNeighbors (const EgNode*);
|
||||
|
||||
bool neighbors (const EgNode*, const EgNode*) const;
|
||||
|
||||
vector<EgNode*> nodes_;
|
||||
vector<bool> marked_;
|
||||
unordered_map<VarId, EgNode*> varMap_;
|
||||
static ElimHeuristic elimHeuristic_;
|
||||
};
|
||||
|
||||
#endif // HORUS_ELIMGRAPH_H
|
||||
|
@@ -1,265 +0,0 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const Factor& g)
|
||||
{
|
||||
copyFromFactor (g);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (
|
||||
const VarIds& vids,
|
||||
const Ranges& ranges,
|
||||
const Params& params,
|
||||
unsigned distId)
|
||||
{
|
||||
args_ = vids;
|
||||
ranges_ = ranges;
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (
|
||||
const Vars& vars,
|
||||
const Params& params,
|
||||
unsigned distId)
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
args_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->range());
|
||||
}
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutAllExcept (VarId vid)
|
||||
{
|
||||
assert (indexOf (vid) != -1);
|
||||
while (args_.back() != vid) {
|
||||
sumOutLastVariable();
|
||||
}
|
||||
while (args_.front() != vid) {
|
||||
sumOutFirstVariable();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutAllExcept (const VarIds& vids)
|
||||
{
|
||||
for (int i = 0; i < (int)args_.size(); i++) {
|
||||
if (Util::contains (vids, args_[i]) == false) {
|
||||
sumOut (args_[i]);
|
||||
i --;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOut (VarId vid)
|
||||
{
|
||||
int idx = indexOf (vid);
|
||||
assert (idx != -1);
|
||||
|
||||
if (vid == args_.back()) {
|
||||
sumOutLastVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
if (vid == args_.front()) {
|
||||
sumOutFirstVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
|
||||
// number of parameters separating a different state of `var',
|
||||
// with the states of the remaining variables fixed
|
||||
unsigned varOffset = 1;
|
||||
|
||||
// number of parameters separating a different state of the variable
|
||||
// on the left of `var', with the states of the remaining vars fixed
|
||||
unsigned leftVarOffset = 1;
|
||||
|
||||
for (int i = args_.size() - 1; i > idx; i--) {
|
||||
varOffset *= ranges_[i];
|
||||
leftVarOffset *= ranges_[i];
|
||||
}
|
||||
leftVarOffset *= ranges_[idx];
|
||||
|
||||
unsigned offset = 0;
|
||||
unsigned count1 = 0;
|
||||
unsigned count2 = 0;
|
||||
unsigned newpsSize = params_.size() / ranges_[idx];
|
||||
|
||||
Params newps;
|
||||
newps.reserve (newpsSize);
|
||||
|
||||
while (newps.size() < newpsSize) {
|
||||
double sum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
||||
if (Globals::logDomain) {
|
||||
sum = Util::logSum (sum, params_[offset]);
|
||||
} else {
|
||||
sum += params_[offset];
|
||||
}
|
||||
offset += varOffset;
|
||||
}
|
||||
newps.push_back (sum);
|
||||
count1 ++;
|
||||
if (idx == (int)args_.size() - 1) {
|
||||
offset = count1 * ranges_[idx];
|
||||
} else {
|
||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||
count1 = 0;
|
||||
count2 ++;
|
||||
}
|
||||
offset = (leftVarOffset * count2) + count1;
|
||||
}
|
||||
}
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
params_ = newps;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
{
|
||||
unsigned range = ranges_.front();
|
||||
unsigned sep = params_.size() / range;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] += params_[i];
|
||||
}
|
||||
}
|
||||
params_.resize (sep);
|
||||
args_.erase (args_.begin());
|
||||
ranges_.erase (ranges_.begin());
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutLastVariable (void)
|
||||
{
|
||||
unsigned range = ranges_.back();
|
||||
unsigned idx1 = 0;
|
||||
unsigned idx2 = 0;
|
||||
if (Globals::logDomain) {
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < range; j++) {
|
||||
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
} else {
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < range; j++) {
|
||||
params_[idx2] += params_[idx1];
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
}
|
||||
params_.resize (idx2);
|
||||
args_.pop_back();
|
||||
ranges_.pop_back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::multiply (Factor& g)
|
||||
{
|
||||
if (args_.size() == 0) {
|
||||
copyFromFactor (g);
|
||||
return;
|
||||
}
|
||||
TFactor<VarId>::multiply (g);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::reorderAccordingVarIds (void)
|
||||
{
|
||||
VarIds sortedVarIds = args_;
|
||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||
reorderArguments (sortedVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Factor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "f(" ;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << Var (args_[i], ranges_[i]).label();
|
||||
}
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::print (void) const
|
||||
{
|
||||
Vars vars;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||
}
|
||||
vector<string> jointStrings = Util::getStateLines (vars);
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << params_[i] << endl;
|
||||
}
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
delete vars[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::copyFromFactor (const Factor& g)
|
||||
{
|
||||
args_ = g.arguments();
|
||||
ranges_ = g.ranges();
|
||||
params_ = g.params();
|
||||
distId_ = g.distId();
|
||||
}
|
||||
|
@@ -1,288 +0,0 @@
|
||||
#ifndef HORUS_FACTOR_H
|
||||
#define HORUS_FACTOR_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Var.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
template <typename T>
|
||||
class TFactor
|
||||
{
|
||||
public:
|
||||
const vector<T>& arguments (void) const { return args_; }
|
||||
|
||||
vector<T>& arguments (void) { return args_; }
|
||||
|
||||
const Ranges& ranges (void) const { return ranges_; }
|
||||
|
||||
const Params& params (void) const { return params_; }
|
||||
|
||||
Params& params (void) { return params_; }
|
||||
|
||||
unsigned nrArguments (void) const { return args_.size(); }
|
||||
|
||||
unsigned size (void) const { return params_.size(); }
|
||||
|
||||
unsigned distId (void) const { return distId_; }
|
||||
|
||||
void setDistId (unsigned id) { distId_ = id; }
|
||||
|
||||
void normalize (void) { LogAware::normalize (params_); }
|
||||
|
||||
void setParams (const Params& newParams)
|
||||
{
|
||||
params_ = newParams;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
int indexOf (const T& t) const
|
||||
{
|
||||
int idx = -1;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i] == t) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
const T& argument (unsigned idx) const
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
T& argument (unsigned idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
unsigned range (unsigned idx) const
|
||||
{
|
||||
assert (idx < ranges_.size());
|
||||
return ranges_[idx];
|
||||
}
|
||||
|
||||
void multiply (TFactor<T>& g)
|
||||
{
|
||||
const vector<T>& g_args = g.arguments();
|
||||
const Ranges& g_ranges = g.ranges();
|
||||
const Params& g_params = g.params();
|
||||
if (args_ == g_args) {
|
||||
// optimization: if the factors contain the same set of args,
|
||||
// we can do a 1 to 1 operation on the parameters
|
||||
if (Globals::logDomain) {
|
||||
Util::add (params_, g_params);
|
||||
} else {
|
||||
Util::multiply (params_, g_params);
|
||||
}
|
||||
} else {
|
||||
bool sharedArgs = false;
|
||||
vector<unsigned> gvarpos;
|
||||
for (unsigned i = 0; i < g_args.size(); i++) {
|
||||
int idx = indexOf (g_args[i]);
|
||||
if (idx == -1) {
|
||||
insertArgument (g_args[i], g_ranges[i]);
|
||||
gvarpos.push_back (args_.size() - 1);
|
||||
} else {
|
||||
sharedArgs = true;
|
||||
gvarpos.push_back (idx);
|
||||
}
|
||||
}
|
||||
if (sharedArgs == false) {
|
||||
// optimization: if the original factors doesn't have common args,
|
||||
// we don't need to marry the states of the common args
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
if (Globals::logDomain) {
|
||||
params_[i] += g_params[count];
|
||||
} else {
|
||||
params_[i] *= g_params[count];
|
||||
}
|
||||
count ++;
|
||||
if (count >= g_params.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StatesIndexer indexer (ranges_, false);
|
||||
while (indexer.valid()) {
|
||||
unsigned g_li = 0;
|
||||
unsigned prod = 1;
|
||||
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
||||
g_li += indexer[gvarpos[j]] * prod;
|
||||
prod *= g_ranges[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
params_[indexer] += g_params[g_li];
|
||||
} else {
|
||||
params_[indexer] *= g_params[g_li];
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void absorveEvidence (const T& arg, unsigned evidence)
|
||||
{
|
||||
int idx = indexOf (arg);
|
||||
assert (idx != -1);
|
||||
assert (evidence < ranges_[idx]);
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() / ranges_[idx]);
|
||||
StatesIndexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < evidence; i++) {
|
||||
indexer.increment (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
params_.push_back (copy[indexer]);
|
||||
indexer.incrementExcluding (idx);
|
||||
}
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
void reorderArguments (const vector<T> newArgs)
|
||||
{
|
||||
assert (newArgs.size() == args_.size());
|
||||
if (newArgs == args_) {
|
||||
return; // already in the wanted order
|
||||
}
|
||||
Ranges newRanges;
|
||||
vector<unsigned> positions;
|
||||
for (unsigned i = 0; i < newArgs.size(); i++) {
|
||||
unsigned idx = indexOf (newArgs[i]);
|
||||
newRanges.push_back (ranges_[idx]);
|
||||
positions.push_back (idx);
|
||||
}
|
||||
unsigned N = ranges_.size();
|
||||
Params newParams (params_.size());
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
unsigned li = i;
|
||||
// calculate vector index corresponding to linear index
|
||||
vector<unsigned> vi (N);
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
vi[k] = li % ranges_[k];
|
||||
li /= ranges_[k];
|
||||
}
|
||||
// convert permuted vector index to corresponding linear index
|
||||
unsigned prod = 1;
|
||||
unsigned new_li = 0;
|
||||
for (int k = N - 1; k >= 0; k--) {
|
||||
new_li += vi[positions[k]] * prod;
|
||||
prod *= ranges_[positions[k]];
|
||||
}
|
||||
newParams[new_li] = params_[i];
|
||||
}
|
||||
args_ = newArgs;
|
||||
ranges_ = newRanges;
|
||||
params_ = newParams;
|
||||
}
|
||||
|
||||
bool contains (const T& arg) const
|
||||
{
|
||||
return Util::contains (args_, arg);
|
||||
}
|
||||
|
||||
bool contains (const vector<T>& args) const
|
||||
{
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (contains (args[i]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
vector<T> args_;
|
||||
Ranges ranges_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
|
||||
private:
|
||||
void insertArgument (const T& arg, unsigned range)
|
||||
{
|
||||
assert (indexOf (arg) == -1);
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() * range);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
for (unsigned reps = 0; reps < range; reps++) {
|
||||
params_.push_back (copy[i]);
|
||||
}
|
||||
}
|
||||
args_.push_back (arg);
|
||||
ranges_.push_back (range);
|
||||
}
|
||||
|
||||
void insertArguments (const vector<T>& args, const Ranges& ranges)
|
||||
{
|
||||
Params copy = params_;
|
||||
unsigned nrStates = 1;
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
assert (indexOf (args[i]) == -1);
|
||||
args_.push_back (args[i]);
|
||||
ranges_.push_back (ranges[i]);
|
||||
nrStates *= ranges[i];
|
||||
}
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() * nrStates);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||
params_.push_back (copy[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
class Factor : public TFactor<VarId>
|
||||
{
|
||||
public:
|
||||
Factor (void) { }
|
||||
|
||||
Factor (const Factor&);
|
||||
|
||||
Factor (const VarIds&, const Ranges&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
|
||||
Factor (const Vars&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
|
||||
void sumOutAllExcept (VarId);
|
||||
|
||||
void sumOutAllExcept (const VarIds&);
|
||||
|
||||
void sumOut (VarId);
|
||||
|
||||
void sumOutFirstVariable (void);
|
||||
|
||||
void sumOutLastVariable (void);
|
||||
|
||||
void multiply (Factor&);
|
||||
|
||||
void reorderAccordingVarIds (void);
|
||||
|
||||
string getLabel (void) const;
|
||||
|
||||
void print (void) const;
|
||||
|
||||
private:
|
||||
void copyFromFactor (const Factor& f);
|
||||
|
||||
};
|
||||
|
||||
#endif // HORUS_FACTOR_H
|
||||
|
@@ -1,468 +0,0 @@
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
bool FactorGraph::orderFactorVariables = false;
|
||||
|
||||
|
||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||
addFacNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "error: cannot read from file " << fileName << endl;
|
||||
abort();
|
||||
}
|
||||
ignoreLines (is);
|
||||
string line;
|
||||
getline (is, line);
|
||||
if (line != "MARKOV") {
|
||||
cerr << "error: the network must be a MARKOV network " << endl;
|
||||
abort();
|
||||
}
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
vector<VarIds> factorVarIds;
|
||||
vector<Ranges> factorRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
factorVarIds.push_back ({ });
|
||||
factorRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
if (vid >= ranges.size()) {
|
||||
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
|
||||
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||
cerr << endl;
|
||||
abort();
|
||||
}
|
||||
factorVarIds.back().push_back (vid);
|
||||
factorRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::expectedSize (factorRanges[i])) {
|
||||
cerr << "error: invalid number of parameters for factor nº " << i ;
|
||||
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
|
||||
cerr << ", given: " << nrParams << endl;
|
||||
abort();
|
||||
}
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||
}
|
||||
is.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "error: cannot read from file " << fileName << endl;
|
||||
abort();
|
||||
}
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = getVarNode (vids[j]);
|
||||
if (var != 0 && ranges[j] != var->range()) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with a different range" << endl;
|
||||
}
|
||||
}
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
is >> nNonzeros;
|
||||
Params params (Util::expectedSize (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
is >> index;
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
reverse (vids.begin(), vids.end());
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
addFactor (Factor (vids, ranges, params));
|
||||
}
|
||||
is.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph::~FactorGraph (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
delete varNodes_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
delete facNodes_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
FacNode* fn = new FacNode (factor);
|
||||
addFacNode (fn);
|
||||
const VarIds& vids = factor.arguments();
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
bool found = false;
|
||||
for (unsigned j = 0; j < varNodes_.size(); j++) {
|
||||
if (varNodes_[j]->varId() == vids[i]) {
|
||||
addEdge (varNodes_[j], fn);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
||||
addVarNode (vn);
|
||||
addEdge (vn, fn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addVarNode (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), vn));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFacNode (FacNode* fn)
|
||||
{
|
||||
facNodes_.push_back (fn);
|
||||
fn->setIndex (facNodes_.size() - 1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (VarNode* vn, FacNode* fn)
|
||||
{
|
||||
vn->addNeighbor (fn);
|
||||
fn->addNeighbor (vn);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::isTree (void) const
|
||||
{
|
||||
return !containsCycle();
|
||||
}
|
||||
|
||||
|
||||
|
||||
DAGraph&
|
||||
FactorGraph::getStructure (void)
|
||||
{
|
||||
assert (fromBayesNet_);
|
||||
if (structure_.empty()) {
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||
for (unsigned j = 1; j < vids.size(); j++) {
|
||||
structure_.addEdge (vids[j], vids[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return structure_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::print (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||
cout << "label = " << varNodes_[i]->label() << endl;
|
||||
cout << "range = " << varNodes_[i]->range() << endl;
|
||||
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "factors = " ;
|
||||
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
||||
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
||||
}
|
||||
cout << endl << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->factor().print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "FactorGraph::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << "graph \"" << fileName << "\" {" << endl;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i]->hasEvidence()) {
|
||||
out << '"' << varNodes_[i]->label() << '"' ;
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||
out << "\"" << ", shape=box]" << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const VarNodes& myVars = facNodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " -- " ;
|
||||
out << '"' << myVars[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file " << fileName << endl;
|
||||
abort();
|
||||
}
|
||||
out << "MARKOV" << endl;
|
||||
out << varNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
out << varNodes_[i]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
out << facNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||
out << factorVars.size();
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << " " << factorVars[j]->getIndex();
|
||||
}
|
||||
out << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
Params params = facNodes_[i]->factor().params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
out << endl << params.size() << endl << " " ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << params[j] << " " ;
|
||||
}
|
||||
out << endl;
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file " << fileName << endl;
|
||||
abort();
|
||||
}
|
||||
out << facNodes_.size() << endl << endl;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||
out << factorVars.size() << endl;
|
||||
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
||||
out << factorVars[j]->varId() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << factorVars[j]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
Params params = facNodes_[i]->factor().params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
out << params.size() << endl;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << j << " " << params[j] << endl;
|
||||
}
|
||||
out << endl;
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||
{
|
||||
string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (void) const
|
||||
{
|
||||
vector<bool> visitedVars (varNodes_.size(), false);
|
||||
vector<bool> visitedFactors (facNodes_.size(), false);
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
int v = varNodes_[i]->getIndex();
|
||||
if (!visitedVars[v]) {
|
||||
if (containsCycle (varNodes_[i], 0, visitedVars, visitedFactors)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (
|
||||
const VarNode* v,
|
||||
const FacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedVars[v->getIndex()] = true;
|
||||
const FacNodes& adjacencies = v->neighbors();
|
||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||
int w = adjacencies[i]->getIndex();
|
||||
if (!visitedFactors[w]) {
|
||||
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else if (visitedFactors[w] && adjacencies[i] != p) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false; // no cycle detected in this component
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (
|
||||
const FacNode* v,
|
||||
const VarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedFactors[v->getIndex()] = true;
|
||||
const VarNodes& adjacencies = v->neighbors();
|
||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||
int w = adjacencies[i]->getIndex();
|
||||
if (!visitedVars[w]) {
|
||||
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else if (visitedVars[w] && adjacencies[i] != p) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false; // no cycle detected in this component
|
||||
}
|
||||
|
@@ -1,142 +0,0 @@
|
||||
#ifndef HORUS_FACTORGRAPH_H
|
||||
#define HORUS_FACTORGRAPH_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class FacNode;
|
||||
|
||||
class VarNode : public Var
|
||||
{
|
||||
public:
|
||||
VarNode (VarId varId, unsigned nrStates)
|
||||
: Var (varId, nrStates) { }
|
||||
|
||||
VarNode (const Var* v) : Var (v) { }
|
||||
|
||||
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
|
||||
|
||||
const FacNodes& neighbors (void) const { return neighs_; }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (VarNode);
|
||||
|
||||
FacNodes neighs_;
|
||||
};
|
||||
|
||||
|
||||
class FacNode
|
||||
{
|
||||
public:
|
||||
FacNode (const Factor& f) : factor_(f), index_(-1) { }
|
||||
|
||||
const Factor& factor (void) const { return factor_; }
|
||||
|
||||
Factor& factor (void) { return factor_; }
|
||||
|
||||
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||
|
||||
const VarNodes& neighbors (void) const { return neighs_; }
|
||||
|
||||
int getIndex (void) const { return index_; }
|
||||
|
||||
void setIndex (int index) { index_ = index; }
|
||||
|
||||
string getLabel (void) { return factor_.getLabel(); }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FacNode);
|
||||
|
||||
VarNodes neighs_;
|
||||
Factor factor_;
|
||||
int index_;
|
||||
};
|
||||
|
||||
|
||||
struct CompVarId
|
||||
{
|
||||
bool operator() (const Var* v1, const Var* v2) const
|
||||
{
|
||||
return v1->varId() < v2->varId();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class FactorGraph
|
||||
{
|
||||
public:
|
||||
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
|
||||
|
||||
FactorGraph (const FactorGraph&);
|
||||
|
||||
~FactorGraph (void);
|
||||
|
||||
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||
|
||||
const FacNodes& facNodes (void) const { return facNodes_; }
|
||||
|
||||
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
||||
|
||||
VarNode* getVarNode (VarId vid) const
|
||||
{
|
||||
VarMap::const_iterator it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
|
||||
void readFromLibDaiFormat (const char*);
|
||||
|
||||
void addFactor (const Factor& factor);
|
||||
|
||||
void addVarNode (VarNode*);
|
||||
|
||||
void addFacNode (FacNode*);
|
||||
|
||||
void addEdge (VarNode*, FacNode*);
|
||||
|
||||
bool isTree (void) const;
|
||||
|
||||
DAGraph& getStructure (void);
|
||||
|
||||
void print (void) const;
|
||||
|
||||
void exportToGraphViz (const char*) const;
|
||||
|
||||
void exportToUaiFormat (const char*) const;
|
||||
|
||||
void exportToLibDaiFormat (const char*) const;
|
||||
|
||||
static bool orderFactorVariables;
|
||||
|
||||
private:
|
||||
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||
|
||||
void ignoreLines (std::ifstream&) const;
|
||||
|
||||
bool containsCycle (void) const;
|
||||
|
||||
bool containsCycle (const VarNode*, const FacNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
bool containsCycle (const FacNode*, const VarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
VarNodes varNodes_;
|
||||
FacNodes facNodes_;
|
||||
|
||||
DAGraph structure_;
|
||||
bool fromBayesNet_;
|
||||
|
||||
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||
VarMap varMap_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FACTORGRAPH_H
|
||||
|
@@ -1,711 +0,0 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
|
||||
#include "FoveSolver.h"
|
||||
#include "Histogram.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
vector<LiftedOperator*>
|
||||
LiftedOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<SumOutOperator*> sumOutOps;
|
||||
vector<CountingOperator*> countOps;
|
||||
vector<GroundOperator*> groundOps;
|
||||
|
||||
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
||||
countOps = CountingOperator::getValidOps (pfList);
|
||||
groundOps = GroundOperator::getValidOps (pfList);
|
||||
|
||||
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
|
||||
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
|
||||
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
|
||||
return validOps;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedOperator::printValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList, query);
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
cout << "-> " << validOps[i]->toString() << endl;
|
||||
delete validOps[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
SumOutOperator::getCost (void)
|
||||
{
|
||||
TinySet<unsigned> groupSet;
|
||||
ParfactorList::const_iterator pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
if ((*pfIter)->containsGroup (group_)) {
|
||||
vector<unsigned> groups = (*pfIter)->getAllGroups();
|
||||
groupSet |= TinySet<unsigned> (groups);
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
unsigned cost = 1;
|
||||
for (unsigned i = 0; i < groupSet.size(); i++) {
|
||||
pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
if ((*pfIter)->containsGroup (groupSet[i])) {
|
||||
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
|
||||
cost *= (*pfIter)->range (idx);
|
||||
break;
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SumOutOperator::apply (void)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters
|
||||
= parfactorsWithGroup (pfList_, group_);
|
||||
Parfactor* product = *(iters[0]);
|
||||
pfList_.remove (iters[0]);
|
||||
for (unsigned i = 1; i < iters.size(); i++) {
|
||||
product->multiply (**(iters[i]));
|
||||
pfList_.removeAndDelete (iters[i]);
|
||||
}
|
||||
if (product->nrArguments() == 1) {
|
||||
delete product;
|
||||
return;
|
||||
}
|
||||
int fIdx = product->indexOfGroup (group_);
|
||||
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
||||
if (product->constr()->isCountNormalized (excl)) {
|
||||
product->sumOut (fIdx);
|
||||
pfList_.addShattered (product);
|
||||
} else {
|
||||
Parfactors pfs = FoveSolver::countNormalize (product, excl);
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
pfs[i]->sumOut (fIdx);
|
||||
pfList_.add (pfs[i]);
|
||||
}
|
||||
delete product;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<SumOutOperator*>
|
||||
SumOutOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<SumOutOperator*> validOps;
|
||||
set<unsigned> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
allGroups.insert (formulas[i].group());
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
set<unsigned>::const_iterator groupIt = allGroups.begin();
|
||||
while (groupIt != allGroups.end()) {
|
||||
if (validOp (*groupIt, pfList, query)) {
|
||||
validOps.push_back (new SumOutOperator (*groupIt, pfList));
|
||||
}
|
||||
++ groupIt;
|
||||
}
|
||||
return validOps;
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
SumOutOperator::toString (void)
|
||||
{
|
||||
stringstream ss;
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = parfactorsWithGroup (pfList_, group_);
|
||||
int idx = (*pfIters[0])->indexOfGroup (group_);
|
||||
ProbFormula f = (*pfIters[0])->argument (idx);
|
||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
SumOutOperator::validOp (
|
||||
unsigned group,
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = parfactorsWithGroup (pfList, group);
|
||||
if (isToEliminate (*pfIters[0], group, query) == false) {
|
||||
return false;
|
||||
}
|
||||
unordered_map<unsigned, unsigned> groupToRange;
|
||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
||||
int fIdx = (*pfIters[i])->indexOfGroup (group);
|
||||
if ((*pfIters[i])->argument (fIdx).contains (
|
||||
(*pfIters[i])->elimLogVars()) == false) {
|
||||
return false;
|
||||
}
|
||||
vector<unsigned> ranges = (*pfIters[i])->ranges();
|
||||
vector<unsigned> groups = (*pfIters[i])->getAllGroups();
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
unordered_map<unsigned, unsigned>::iterator it;
|
||||
it = groupToRange.find (groups[i]);
|
||||
if (it == groupToRange.end()) {
|
||||
groupToRange.insert (make_pair (groups[i], ranges[i]));
|
||||
} else {
|
||||
if (it->second != ranges[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<ParfactorList::iterator>
|
||||
SumOutOperator::parfactorsWithGroup (
|
||||
ParfactorList& pfList,
|
||||
unsigned group)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
ParfactorList::iterator pflIt = pfList.begin();
|
||||
while (pflIt != pfList.end()) {
|
||||
if ((*pflIt)->containsGroup (group)) {
|
||||
iters.push_back (pflIt);
|
||||
}
|
||||
++ pflIt;
|
||||
}
|
||||
return iters;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
SumOutOperator::isToEliminate (
|
||||
Parfactor* g,
|
||||
unsigned group,
|
||||
const Grounds& query)
|
||||
{
|
||||
int fIdx = g->indexOfGroup (group);
|
||||
const ProbFormula& formula = g->argument (fIdx);
|
||||
bool toElim = true;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
if (formula.functor() == query[i].functor() &&
|
||||
formula.arity() == query[i].arity()) {
|
||||
g->constr()->moveToTop (formula.logVars());
|
||||
if (g->constr()->containsTuple (query[i].args())) {
|
||||
toElim = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return toElim;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
CountingOperator::getCost (void)
|
||||
{
|
||||
unsigned cost = 0;
|
||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
unsigned range = (*pfIter_)->range (fIdx);
|
||||
unsigned size = (*pfIter_)->size() / range;
|
||||
TinySet<unsigned> counts;
|
||||
counts = (*pfIter_)->constr()->getConditionalCounts (X_);
|
||||
for (unsigned i = 0; i < counts.size(); i++) {
|
||||
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingOperator::apply (void)
|
||||
{
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
||||
(*pfIter_)->countConvert (X_);
|
||||
} else {
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
|
||||
pfs[i]->countedLogVars() | X_);
|
||||
if (condCount > 1 && cartProduct) {
|
||||
pfs[i]->countConvert (X_);
|
||||
}
|
||||
pfList_.add (pfs[i]);
|
||||
}
|
||||
delete pf;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<CountingOperator*>
|
||||
CountingOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<CountingOperator*> validOps;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
LogVarSet candidates = (*it)->uncountedLogVars();
|
||||
for (unsigned i = 0; i < candidates.size(); i++) {
|
||||
if (validOp (*it, candidates[i])) {
|
||||
validOps.push_back (new CountingOperator (
|
||||
it, candidates[i], pfList));
|
||||
}
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
return validOps;
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
CountingOperator::toString (void)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
ss << " º " << pfs[i]->getLabel() << endl;
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
delete pfs[i];
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
{
|
||||
if (g->nrFormulas (X) != 1) {
|
||||
return false;
|
||||
}
|
||||
int fIdx = g->indexOfLogVar (X);
|
||||
if (g->argument (fIdx).isCounting()) {
|
||||
return false;
|
||||
}
|
||||
bool countNormalized = g->constr()->isCountNormalized (X);
|
||||
if (countNormalized) {
|
||||
unsigned condCount = g->constr()->getConditionalCount (X);
|
||||
bool cartProduct = g->constr()->isCarteesianProduct (
|
||||
g->countedLogVars() | X);
|
||||
if (condCount == 1 || cartProduct == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
GroundOperator::getCost (void)
|
||||
{
|
||||
unsigned cost = 0;
|
||||
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||
if (isCountingLv) {
|
||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
unsigned currSize = (*pfIter_)->size();
|
||||
unsigned nrHists = (*pfIter_)->range (fIdx);
|
||||
unsigned range = (*pfIter_)->argument (fIdx).range();
|
||||
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
||||
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
||||
} else {
|
||||
cost = (*pfIter_)->constr()->nrSymbols (X_) * (*pfIter_)->size();
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
GroundOperator::apply (void)
|
||||
{
|
||||
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
if (countedLv) {
|
||||
pf->fullExpand (X_);
|
||||
pfList_.add (pf);
|
||||
} else {
|
||||
ConstraintTrees cts = pf->constr()->ground (X_);
|
||||
for (unsigned i = 0; i < cts.size(); i++) {
|
||||
pfList_.add (new Parfactor (pf, cts[i]));
|
||||
}
|
||||
delete pf;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<GroundOperator*>
|
||||
GroundOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<GroundOperator*> validOps;
|
||||
ParfactorList::iterator pfIter = pfList.begin();
|
||||
while (pfIter != pfList.end()) {
|
||||
LogVarSet set = (*pfIter)->logVarSet();
|
||||
for (unsigned i = 0; i < set.size(); i++) {
|
||||
if ((*pfIter)->constr()->isSingleton (set[i]) == false) {
|
||||
validOps.push_back (new GroundOperator (pfIter, set[i], pfList));
|
||||
}
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
return validOps;
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
GroundOperator::toString (void)
|
||||
{
|
||||
stringstream ss;
|
||||
((*pfIter_)->countedLogVars().contains (X_))
|
||||
? ss << "full expanding "
|
||||
: ss << "grounding " ;
|
||||
ss << X_ << " in " << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FoveSolver::getPosterioriOf (const Ground& query)
|
||||
{
|
||||
return getJointDistributionOf ({query});
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FoveSolver::getJointDistributionOf (const Grounds& query)
|
||||
{
|
||||
runSolver (query);
|
||||
(*pfList_.begin())->normalize();
|
||||
Params params = (*pfList_.begin())->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
Parfactor* pf = *it;
|
||||
it = pfList.remove (it);
|
||||
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||
if (absorvedPfs.empty() == false) {
|
||||
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||
// just remove pf;
|
||||
} else {
|
||||
Util::addToVector (newPfs, absorvedPfs);
|
||||
}
|
||||
delete pf;
|
||||
} else {
|
||||
it = pfList.insertShattered (it, pf);
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
cout << " -> " << obsFormulas[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactors
|
||||
FoveSolver::countNormalize (
|
||||
Parfactor* g,
|
||||
const LogVarSet& set)
|
||||
{
|
||||
Parfactors normPfs;
|
||||
if (set.empty()) {
|
||||
normPfs.push_back (new Parfactor (*g));
|
||||
} else {
|
||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||
for (unsigned i = 0; i < normCts.size(); i++) {
|
||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||
}
|
||||
}
|
||||
return normPfs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::runSolver (const Grounds& query)
|
||||
{
|
||||
shatterAgainstQuery (query);
|
||||
runWeakBayesBall (query);
|
||||
while (true) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printDashedLine();
|
||||
pfList_.print();
|
||||
LiftedOperator::printValidOps (pfList_, query);
|
||||
}
|
||||
LiftedOperator* op = getBestOperation (query);
|
||||
if (op == 0) {
|
||||
break;
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "best operation: " << op->toString() << endl;
|
||||
}
|
||||
op->apply();
|
||||
delete op;
|
||||
}
|
||||
assert (pfList_.size() > 0);
|
||||
if (pfList_.size() > 1) {
|
||||
ParfactorList::iterator pfIter = pfList_.begin();
|
||||
pfIter ++;
|
||||
while (pfIter != pfList_.end()) {
|
||||
(*pfList_.begin())->multiply (**pfIter);
|
||||
++ pfIter;
|
||||
}
|
||||
}
|
||||
(*pfList_.begin())->reorderAccordingGrounds (query);
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedOperator*
|
||||
FoveSolver::getBestOperation (const Grounds& query)
|
||||
{
|
||||
unsigned bestCost;
|
||||
LiftedOperator* bestOp = 0;
|
||||
vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList_, query);
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
unsigned cost = validOps[i]->getCost();
|
||||
if ((bestOp == 0) || (cost < bestCost)) {
|
||||
bestOp = validOps[i];
|
||||
bestCost = cost;
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
if (validOps[i] != bestOp) {
|
||||
delete validOps[i];
|
||||
}
|
||||
}
|
||||
return bestOp;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
{
|
||||
queue<unsigned> todo; // groups to process
|
||||
set<unsigned> done; // processed or in queue
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
int group = (*it)->findGroup (query[i]);
|
||||
if (group != -1) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
unsigned group = todo.front();
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<unsigned> groups = (*it)->getAllGroups();
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
done.insert (groups[i]);
|
||||
}
|
||||
}
|
||||
requiredPfs.insert (*it);
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
todo.pop();
|
||||
}
|
||||
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false) {
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader ("REQUIRED PARFACTORS");
|
||||
pfList_.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||
{
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
if (query[i].isAtom()) {
|
||||
continue;
|
||||
}
|
||||
bool found = false;
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGround (query[i])) {
|
||||
found = true;
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split =
|
||||
(*it)->constr()->split (query[i].args(), query[i].arity());
|
||||
ConstraintTree* commCt = split.first;
|
||||
ConstraintTree* exclCt = split.second;
|
||||
newPfs.push_back (new Parfactor (*it, commCt));
|
||||
if (exclCt->empty() == false) {
|
||||
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
cerr << "error: could not find a parfactor with ground " ;
|
||||
cerr << "`" << query[i] << "'" << endl;
|
||||
exit (0);
|
||||
}
|
||||
pfList_.add (newPfs);
|
||||
}
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
Util::printAsteriskLine();
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList_.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactors
|
||||
FoveSolver::absorve (
|
||||
ObservedFormula& obsFormula,
|
||||
Parfactor* g)
|
||||
{
|
||||
Parfactors absorvedPfs;
|
||||
const ProbFormulas& formulas = g->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
if (obsFormula.functor() == formulas[i].functor() &&
|
||||
obsFormula.arity() == formulas[i].arity()) {
|
||||
|
||||
if (obsFormula.isAtom()) {
|
||||
if (formulas.size() > 1) {
|
||||
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||
} else {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
g->constr()->moveToTop (formulas[i].logVars());
|
||||
std::pair<ConstraintTree*, ConstraintTree*> res
|
||||
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
|
||||
ConstraintTree* commCt = res.first;
|
||||
ConstraintTree* exclCt = res.second;
|
||||
|
||||
if (commCt->empty() == false) {
|
||||
if (formulas.size() > 1) {
|
||||
LogVarSet excl = g->exclusiveLogVars (i);
|
||||
Parfactors countNormPfs = countNormalize (g, excl);
|
||||
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
||||
countNormPfs[j]->absorveEvidence (
|
||||
formulas[i], obsFormula.evidence());
|
||||
absorvedPfs.push_back (countNormPfs[j]);
|
||||
}
|
||||
} else {
|
||||
delete commCt;
|
||||
}
|
||||
if (exclCt->empty() == false) {
|
||||
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
if (absorvedPfs.empty()) {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
}
|
||||
}
|
||||
}
|
||||
return absorvedPfs;
|
||||
}
|
||||
|
@@ -1,135 +0,0 @@
|
||||
#ifndef HORUS_FOVESOLVER_H
|
||||
#define HORUS_FOVESOLVER_H
|
||||
|
||||
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
class LiftedOperator
|
||||
{
|
||||
public:
|
||||
virtual unsigned getCost (void) = 0;
|
||||
|
||||
virtual void apply (void) = 0;
|
||||
|
||||
virtual string toString (void) = 0;
|
||||
|
||||
static vector<LiftedOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SumOutOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
SumOutOperator (unsigned group, ParfactorList& pfList)
|
||||
: group_(group), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<SumOutOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (unsigned, ParfactorList&, const Grounds&);
|
||||
|
||||
static vector<ParfactorList::iterator> parfactorsWithGroup (
|
||||
ParfactorList& pfList, unsigned group);
|
||||
|
||||
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
|
||||
|
||||
unsigned group_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CountingOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
CountingOperator (
|
||||
ParfactorList::iterator pfIter,
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, LogVar);
|
||||
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class GroundOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
GroundOperator (
|
||||
ParfactorList::iterator pfIter,
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class FoveSolver
|
||||
{
|
||||
public:
|
||||
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||
|
||||
Params getPosterioriOf (const Ground&);
|
||||
|
||||
Params getJointDistributionOf (const Grounds&);
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
|
||||
private:
|
||||
void runSolver (const Grounds&);
|
||||
|
||||
LiftedOperator* getBestOperation (const Grounds&);
|
||||
|
||||
void runWeakBayesBall (const Grounds&);
|
||||
|
||||
void shatterAgainstQuery (const Grounds&);
|
||||
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
|
||||
ParfactorList pfList_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FOVESOLVER_H
|
||||
|
@@ -1,144 +0,0 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
#include "Histogram.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
HistogramSet::HistogramSet (unsigned size, unsigned range)
|
||||
{
|
||||
size_ = size;
|
||||
hist_.resize (range, 0);
|
||||
hist_[0] = size;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
HistogramSet::nextHistogram (void)
|
||||
{
|
||||
for (int i = hist_.size() - 2; i >= 0; i--) {
|
||||
if (hist_[i] > 0) {
|
||||
hist_[i] --;
|
||||
hist_[i + 1] = maxCount (i + 1);
|
||||
clearAfter (i + 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert (std::accumulate (hist_.begin(), hist_.end(), 0) == (int)size_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
HistogramSet::operator[] (unsigned idx) const
|
||||
{
|
||||
assert (idx < hist_.size());
|
||||
return hist_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
HistogramSet::nrHistograms (void) const
|
||||
{
|
||||
return Util::nrCombinations (size_ + hist_.size() - 1, hist_.size() - 1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
HistogramSet::reset (void)
|
||||
{
|
||||
std::fill (hist_.begin() + 1, hist_.end(), 0);
|
||||
hist_[0] = size_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<Histogram>
|
||||
HistogramSet::getHistograms (unsigned N, unsigned R)
|
||||
{
|
||||
HistogramSet hs (N, R);
|
||||
unsigned H = hs.nrHistograms();
|
||||
vector<Histogram> histograms;
|
||||
histograms.reserve (H);
|
||||
for (unsigned i = 0; i < H; i++) {
|
||||
histograms.push_back (hs.hist_);
|
||||
hs.nextHistogram();
|
||||
}
|
||||
return histograms;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
HistogramSet::nrHistograms (unsigned N, unsigned R)
|
||||
{
|
||||
return Util::nrCombinations (N + R - 1, R - 1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
HistogramSet::findIndex (
|
||||
const Histogram& h,
|
||||
const vector<Histogram>& hists)
|
||||
{
|
||||
vector<Histogram>::const_iterator it = std::lower_bound (
|
||||
hists.begin(), hists.end(), h, std::greater<Histogram>());
|
||||
assert (it != hists.end() && *it == h);
|
||||
return std::distance (hists.begin(), it);
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<double>
|
||||
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||
{
|
||||
HistogramSet hs (N, R);
|
||||
unsigned N_factorial = Util::factorial (N);
|
||||
unsigned H = hs.nrHistograms();
|
||||
vector<double> numAssigns;
|
||||
numAssigns.reserve (H);
|
||||
for (unsigned h = 0; h < H; h++) {
|
||||
unsigned prod = 1;
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
prod *= Util::factorial (hs[r]);
|
||||
}
|
||||
numAssigns.push_back (LogAware::tl (N_factorial / prod));
|
||||
hs.nextHistogram();
|
||||
}
|
||||
return numAssigns;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const HistogramSet& hs)
|
||||
{
|
||||
os << "#" << hs.hist_;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
HistogramSet::maxCount (unsigned idx) const
|
||||
{
|
||||
unsigned sum = 0;
|
||||
for (unsigned i = 0; i < idx; i++) {
|
||||
sum += hist_[i];
|
||||
}
|
||||
return size_ - sum;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
HistogramSet::clearAfter (unsigned idx)
|
||||
{
|
||||
std::fill (hist_.begin() + idx + 1, hist_.end(), 0);
|
||||
}
|
||||
|
@@ -1,45 +0,0 @@
|
||||
#ifndef HORUS_HISTOGRAM_H
|
||||
#define HORUS_HISTOGRAM_H
|
||||
|
||||
#include <vector>
|
||||
#include <ostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
typedef vector<unsigned> Histogram;
|
||||
|
||||
class HistogramSet
|
||||
{
|
||||
public:
|
||||
HistogramSet (unsigned, unsigned);
|
||||
|
||||
void nextHistogram (void);
|
||||
|
||||
unsigned operator[] (unsigned idx) const;
|
||||
|
||||
unsigned nrHistograms (void) const;
|
||||
|
||||
void reset (void);
|
||||
|
||||
static vector<Histogram> getHistograms (unsigned ,unsigned);
|
||||
|
||||
static unsigned nrHistograms (unsigned, unsigned);
|
||||
|
||||
static unsigned findIndex (
|
||||
const Histogram&, const vector<Histogram>&);
|
||||
|
||||
static vector<double> getNumAssigns (unsigned, unsigned);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
||||
|
||||
private:
|
||||
unsigned maxCount (unsigned) const;
|
||||
|
||||
void clearAfter (unsigned);
|
||||
|
||||
unsigned size_;
|
||||
Histogram hist_;
|
||||
};
|
||||
|
||||
#endif // HORUS_HISTOGRAM_H
|
||||
|
@@ -1,76 +0,0 @@
|
||||
#ifndef HORUS_HORUS_H
|
||||
#define HORUS_HORUS_H
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||
TypeName(const TypeName&); \
|
||||
void operator=(const TypeName&)
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Var;
|
||||
class Factor;
|
||||
class VarNode;
|
||||
class FacNode;
|
||||
|
||||
typedef vector<double> Params;
|
||||
typedef unsigned VarId;
|
||||
typedef vector<VarId> VarIds;
|
||||
typedef vector<Var*> Vars;
|
||||
typedef vector<VarNode*> VarNodes;
|
||||
typedef vector<FacNode*> FacNodes;
|
||||
typedef vector<Factor*> Factors;
|
||||
typedef vector<string> States;
|
||||
typedef vector<unsigned> Ranges;
|
||||
|
||||
|
||||
enum InfAlgorithms
|
||||
{
|
||||
VE, // variable elimination
|
||||
BP, // belief propagation
|
||||
CBP // counting belief propagation
|
||||
};
|
||||
|
||||
|
||||
namespace Globals {
|
||||
|
||||
extern bool logDomain;
|
||||
|
||||
extern InfAlgorithms infAlgorithm;
|
||||
|
||||
};
|
||||
|
||||
|
||||
namespace Constants {
|
||||
|
||||
// level of debug information
|
||||
const unsigned DEBUG = 0;
|
||||
|
||||
const int NO_EVIDENCE = -1;
|
||||
|
||||
// number of digits to show when printing a parameter
|
||||
const unsigned PRECISION = 5;
|
||||
|
||||
const bool COLLECT_STATS = false;
|
||||
|
||||
};
|
||||
|
||||
|
||||
namespace BpOptions
|
||||
{
|
||||
enum Schedule {
|
||||
SEQ_FIXED,
|
||||
SEQ_RANDOM,
|
||||
PARALLEL,
|
||||
MAX_RESIDUAL
|
||||
};
|
||||
extern Schedule schedule;
|
||||
extern double accuracy;
|
||||
extern unsigned maxIter;
|
||||
}
|
||||
|
||||
#endif // HORUS_HORUS_H
|
||||
|
@@ -1,171 +0,0 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void processArguments (FactorGraph&, int, const char* []);
|
||||
void runSolver (const FactorGraph&, const VarIds&);
|
||||
|
||||
const string USAGE = "usage: \
|
||||
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
|
||||
|
||||
int
|
||||
main (int argc, const char* argv[])
|
||||
{
|
||||
if (argc <= 1) {
|
||||
cerr << "error: no solver specified" << endl;
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
if (argc <= 2) {
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
string solver (argv[1]);
|
||||
if (solver == "ve") {
|
||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (solver == "bp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||
} else if (solver == "cbp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||
} else {
|
||||
cerr << "error: unknow solver `" << solver << "'" << endl ;
|
||||
cerr << USAGE << endl;
|
||||
exit(0);
|
||||
}
|
||||
string fileName (argv[2]);
|
||||
string extension = fileName.substr (
|
||||
fileName.find_last_of ('.') + 1);
|
||||
FactorGraph fg;
|
||||
if (extension == "uai") {
|
||||
fg.readFromUaiFormat (fileName.c_str());
|
||||
} else if (extension == "fg") {
|
||||
fg.readFromLibDaiFormat (fileName.c_str());
|
||||
} else {
|
||||
cerr << "error: the graphical model must be defined either " ;
|
||||
cerr << "in a UAI or libDAI file" << endl;
|
||||
exit (0);
|
||||
}
|
||||
processArguments (fg, argc, argv);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
{
|
||||
VarIds queryIds;
|
||||
for (int i = 3; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
if (!Util::isInteger (arg)) {
|
||||
cerr << "error: `" << arg << "' " ;
|
||||
cerr << "is not a valid variable id" ;
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
VarId vid;
|
||||
stringstream ss;
|
||||
ss << arg;
|
||||
ss >> vid;
|
||||
VarNode* queryVar = fg.getVarNode (vid);
|
||||
if (queryVar) {
|
||||
queryIds.push_back (vid);
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
size_t pos = arg.find ('=');
|
||||
if (arg.substr (0, pos).empty()) {
|
||||
cerr << "error: missing left argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
if (arg.substr (pos + 1).empty()) {
|
||||
cerr << "error: missing right argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
exit (0);
|
||||
}
|
||||
if (!Util::isInteger (arg.substr (0, pos))) {
|
||||
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
||||
cerr << "is not a variable id" ;
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
VarId vid;
|
||||
stringstream ss;
|
||||
ss << arg.substr (0, pos);
|
||||
ss >> vid;
|
||||
VarNode* var = fg.getVarNode (vid);
|
||||
if (var) {
|
||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||
cerr << "is not a state index" ;
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
int stateIndex;
|
||||
stringstream ss;
|
||||
ss << arg.substr (pos + 1);
|
||||
ss >> stateIndex;
|
||||
if (var->isValidState (stateIndex)) {
|
||||
var->setEvidence (stateIndex);
|
||||
} else {
|
||||
cerr << "error: `" << stateIndex << "' " ;
|
||||
cerr << "is not a valid state index for variable " ;
|
||||
cerr << "`" << var->varId() << "'" ;
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
exit (0);
|
||||
}
|
||||
}
|
||||
}
|
||||
runSolver (fg, queryIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
{
|
||||
Solver* solver = 0;
|
||||
switch (Globals::infAlgorithm) {
|
||||
case InfAlgorithms::VE:
|
||||
solver = new VarElimSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::BP:
|
||||
solver = new BpSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::CBP:
|
||||
solver = new CbpSolver (fg);
|
||||
break;
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
if (queryIds.size() == 0) {
|
||||
solver->printAllPosterioris();
|
||||
} else {
|
||||
solver->printAnswer (queryIds);
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
@@ -1,622 +0,0 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include <YapInterface.h>
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "FoveSolver.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "BpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesBall.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
|
||||
|
||||
Params readParameters (YAP_Term);
|
||||
|
||||
vector<unsigned> readUnsignedList (YAP_Term);
|
||||
|
||||
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
|
||||
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
readUnsignedList (YAP_Term list)
|
||||
{
|
||||
vector<unsigned> vec;
|
||||
while (list != YAP_TermNil()) {
|
||||
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
|
||||
list = YAP_TailOfTerm (list);
|
||||
}
|
||||
return vec;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int createLiftedNetwork (void)
|
||||
{
|
||||
Parfactors parfactors;
|
||||
YAP_Term parfactorList = YAP_ARG1;
|
||||
while (parfactorList != YAP_TermNil()) {
|
||||
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
|
||||
parfactors.push_back (readParfactor (pfTerm));
|
||||
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||
}
|
||||
|
||||
// LiftedUtils::printSymbolDictionary();
|
||||
if (Constants::DEBUG > 2) {
|
||||
// Util::printHeader ("INITIAL PARFACTORS");
|
||||
// for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||
// parfactors[i]->print();
|
||||
// }
|
||||
}
|
||||
|
||||
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader ("SHATTERED PARFACTORS");
|
||||
pfList->print();
|
||||
}
|
||||
|
||||
// read evidence
|
||||
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
|
||||
|
||||
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
|
||||
YAP_Int p = (YAP_Int) (net);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor* readParfactor (YAP_Term pfTerm)
|
||||
{
|
||||
// read dist id
|
||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||
|
||||
// read the ranges
|
||||
Ranges ranges;
|
||||
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
|
||||
while (rangeList != YAP_TermNil()) {
|
||||
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||
ranges.push_back (range);
|
||||
rangeList = YAP_TailOfTerm (rangeList);
|
||||
}
|
||||
|
||||
// read parametric random vars
|
||||
ProbFormulas formulas;
|
||||
unsigned count = 0;
|
||||
unordered_map<YAP_Term, LogVar> lvMap;
|
||||
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||
while (pvList != YAP_TermNil()) {
|
||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||
} else {
|
||||
LogVars logVars;
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
if (it != lvMap.end()) {
|
||||
logVars.push_back (it->second);
|
||||
} else {
|
||||
unsigned newLv = lvMap.size();
|
||||
lvMap[ti] = newLv;
|
||||
logVars.push_back (newLv);
|
||||
}
|
||||
}
|
||||
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
||||
}
|
||||
count ++;
|
||||
pvList = YAP_TailOfTerm (pvList);
|
||||
}
|
||||
|
||||
// read the parameters
|
||||
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||
|
||||
// read the constraint
|
||||
Tuples tuples;
|
||||
if (lvMap.size() >= 1) {
|
||||
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
|
||||
while (tupleList != YAP_TermNil()) {
|
||||
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||
assert (YAP_IsApplTerm (term));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
assert (lvMap.size() == arity);
|
||||
Tuple tuple (arity);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||
if (YAP_IsAtomTerm (ti) == false) {
|
||||
cerr << "error: constraint has free variables" << endl;
|
||||
abort();
|
||||
}
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||
}
|
||||
tuples.push_back (tuple);
|
||||
tupleList = YAP_TailOfTerm (tupleList);
|
||||
}
|
||||
}
|
||||
return new Parfactor (formulas, params, tuples, distId);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void readLiftedEvidence (
|
||||
YAP_Term observedList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
while (observedList != YAP_TermNil()) {
|
||||
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||
Symbol functor;
|
||||
Symbols args;
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
} else {
|
||||
assert (YAP_IsApplTerm (ground));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||
assert (YAP_IsAtomTerm (ti));
|
||||
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
args.push_back (LiftedUtils::getSymbol (arg));
|
||||
}
|
||||
}
|
||||
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||
bool found = false;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
if (obsFormulas[i].functor() == functor &&
|
||||
obsFormulas[i].arity() == args.size() &&
|
||||
obsFormulas[i].evidence() == evidence) {
|
||||
obsFormulas[i].addTuple (args);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||
}
|
||||
observedList = YAP_TailOfTerm (observedList);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
createGroundNetwork (void)
|
||||
{
|
||||
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
bool fromBayesNet = factorsType == "bayes";
|
||||
FactorGraph* fg = new FactorGraph (fromBayesNet);
|
||||
YAP_Term factorList = YAP_ARG2;
|
||||
while (factorList != YAP_TermNil()) {
|
||||
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||
// read the var ids
|
||||
VarIds varIds = readUnsignedList (YAP_ArgOfTerm (1, factor));
|
||||
// read the ranges
|
||||
Ranges ranges = readUnsignedList (YAP_ArgOfTerm (2, factor));
|
||||
// read the parameters
|
||||
Params params = readParameters (YAP_ArgOfTerm (3, factor));
|
||||
// read dist id
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (4, factor));
|
||||
fg->addFactor (Factor (varIds, ranges, params, distId));
|
||||
factorList = YAP_TailOfTerm (factorList);
|
||||
}
|
||||
|
||||
YAP_Term evidenceList = YAP_ARG3;
|
||||
while (evidenceList != YAP_TermNil()) {
|
||||
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
|
||||
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||
assert (fg->getVarNode (vid));
|
||||
fg->getVarNode (vid)->setEvidence (ev);
|
||||
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||
}
|
||||
|
||||
YAP_Int p = (YAP_Int) (fg);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
readParameters (YAP_Term paramL)
|
||||
{
|
||||
Params params;
|
||||
assert (YAP_IsPairTerm (paramL));
|
||||
while (paramL != YAP_TermNil()) {
|
||||
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||
paramL = YAP_TailOfTerm (paramL);
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
runLiftedSolver (void)
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<Params> results;
|
||||
ParfactorList pfListCopy (*network->first);
|
||||
FoveSolver::absorveEvidence (pfListCopy, *network->second);
|
||||
while (taskList != YAP_TermNil()) {
|
||||
Grounds queryVars;
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
queryVars.push_back (Ground (LiftedUtils::getSymbol (name)));
|
||||
} else {
|
||||
assert (YAP_IsApplTerm (ground));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
Symbols args;
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||
assert (YAP_IsAtomTerm (ti));
|
||||
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
args.push_back (LiftedUtils::getSymbol (arg));
|
||||
}
|
||||
queryVars.push_back (Ground (functor, args));
|
||||
}
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
FoveSolver solver (pfListCopy);
|
||||
if (queryVars.size() == 1) {
|
||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
||||
} else {
|
||||
results.push_back (solver.getJointDistributionOf (queryVars));
|
||||
}
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (int i = results.size() - 1; i >= 0; i--) {
|
||||
const Params& beliefs = results[i];
|
||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||
for (int j = beliefs.size() - 1; j >= 0; j--) {
|
||||
YAP_Int sl1 = YAP_InitSlot (list);
|
||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||
list = YAP_GetFromSlot (sl1);
|
||||
YAP_RecoverSlots (1);
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
}
|
||||
|
||||
return YAP_Unify (list, YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
runGroundSolver (void)
|
||||
{
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
|
||||
vector<VarIds> tasks;
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
vector<Params> results;
|
||||
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||
runVeSolver (fg, tasks, results);
|
||||
} else {
|
||||
runBpSolver (fg, tasks, results);
|
||||
}
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (int i = results.size() - 1; i >= 0; i--) {
|
||||
const Params& beliefs = results[i];
|
||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||
for (int j = beliefs.size() - 1; j >= 0; j--) {
|
||||
YAP_Int sl1 = YAP_InitSlot (list);
|
||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||
list = YAP_GetFromSlot (sl1);
|
||||
YAP_RecoverSlots (1);
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
}
|
||||
return YAP_Unify (list, YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void runVeSolver (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||
}
|
||||
VarElimSolver solver (*mfg);
|
||||
results.push_back (solver.solveQuery (tasks[i]));
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
delete mfg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void runBpSolver (
|
||||
FactorGraph* fg,
|
||||
const vector<VarIds>& tasks,
|
||||
vector<Params>& results)
|
||||
{
|
||||
std::set<VarId> vids;
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
Util::addToSet (vids, tasks[i]);
|
||||
}
|
||||
Solver* solver = 0;
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (
|
||||
*fg, VarIds (vids.begin(),vids.end()));
|
||||
}
|
||||
if (Globals::infAlgorithm == InfAlgorithms::BP) {
|
||||
solver = new BpSolver (*mfg);
|
||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||
CFactorGraph::checkForIdenticalFactors = false;
|
||||
solver = new CbpSolver (*mfg);
|
||||
} else {
|
||||
cerr << "error: unknow solver" << endl;
|
||||
abort();
|
||||
}
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
}
|
||||
if (fg->isFromBayesNetwork()) {
|
||||
delete mfg;
|
||||
}
|
||||
delete solver;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorsParams (void)
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
ParfactorList* pfList = network->first;
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
assert (Util::contains (paramsMap, distId) == false);
|
||||
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
ParfactorList::iterator it = pfList->begin();
|
||||
while (it != pfList->end()) {
|
||||
assert (Util::contains (paramsMap, (*it)->distId()));
|
||||
// (*it)->setParams (paramsMap[(*it)->distId()]);
|
||||
++ it;
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setFactorsParams (void)
|
||||
{
|
||||
return TRUE; // TODO
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
assert (Util::contains (paramsMap, distId) == false);
|
||||
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
const FacNodes& facNodes = fg->facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned distId = facNodes[i]->factor().distId();
|
||||
assert (Util::contains (paramsMap, distId));
|
||||
facNodes[i]->factor().setParams (paramsMap[distId]);
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setVarsInformation (void)
|
||||
{
|
||||
Var::clearVarsInfo();
|
||||
YAP_Term labelsL = YAP_ARG1;
|
||||
vector<string> labels;
|
||||
while (labelsL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
|
||||
labels.push_back ((char*) YAP_AtomName (atom));
|
||||
labelsL = YAP_TailOfTerm (labelsL);
|
||||
}
|
||||
unsigned count = 0;
|
||||
YAP_Term stateNamesL = YAP_ARG2;
|
||||
while (stateNamesL != YAP_TermNil()) {
|
||||
States states;
|
||||
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
|
||||
while (namesL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
|
||||
states.push_back ((char*) YAP_AtomName (atom));
|
||||
namesL = YAP_TailOfTerm (namesL);
|
||||
}
|
||||
Var::addVarInfo (count, labels[count], states);
|
||||
count ++;
|
||||
stateNamesL = YAP_TailOfTerm (stateNamesL);
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setHorusFlag (void)
|
||||
{
|
||||
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
if (key == "inf_alg") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "ve") {
|
||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (value == "bp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||
} else if (value == "cbp") {
|
||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
} else if (key == "elim_heuristic") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "min_neighbors") {
|
||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_NEIGHBORS);
|
||||
} else if (value == "min_weight") {
|
||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_WEIGHT);
|
||||
} else if (value == "min_fill") {
|
||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_FILL);
|
||||
} else if (value == "weighted_min_fill") {
|
||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
} else if (key == "schedule") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "seq_fixed") {
|
||||
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||
} else if (value == "seq_random") {
|
||||
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||
} else if (value == "parallel") {
|
||||
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
|
||||
} else if (value == "max_residual") {
|
||||
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
} else if (key == "accuracy") {
|
||||
BpOptions::accuracy = (double) YAP_FloatOfTerm (YAP_ARG2);
|
||||
} else if (key == "max_iter") {
|
||||
BpOptions::maxIter = (int) YAP_IntOfTerm (YAP_ARG2);
|
||||
} else if (key == "use_logarithms") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "true") {
|
||||
Globals::logDomain = true;
|
||||
} else if (value == "false") {
|
||||
Globals::logDomain = false;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
} else if (key == "order_factor_variables") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "true") {
|
||||
FactorGraph::orderFactorVariables = true;
|
||||
} else if (value == "false") {
|
||||
FactorGraph::orderFactorVariables = false;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
} else {
|
||||
cerr << "warning: invalid key `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
freeGroundNetwork (void)
|
||||
{
|
||||
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
freeParfactors (void)
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
delete network->first;
|
||||
delete network->second;
|
||||
delete network;
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
extern "C" void
|
||||
init_predicates (void)
|
||||
{
|
||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 4);
|
||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
|
||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
|
||||
}
|
||||
|
@@ -1,296 +0,0 @@
|
||||
#ifndef HORUS_STATESINDEXER_H
|
||||
#define HORUS_STATESINDEXER_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "Var.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
class StatesIndexer
|
||||
{
|
||||
public:
|
||||
|
||||
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
|
||||
{
|
||||
size_ = 1;
|
||||
indices_.resize (ranges.size(), 0);
|
||||
ranges_ = ranges;
|
||||
for (unsigned i = 0; i < ranges.size(); i++) {
|
||||
size_ *= ranges[i];
|
||||
}
|
||||
li_ = 0;
|
||||
if (calcOffsets) {
|
||||
calculateOffsets();
|
||||
}
|
||||
}
|
||||
|
||||
StatesIndexer (const Vars& vars, bool calcOffsets = true)
|
||||
{
|
||||
size_ = 1;
|
||||
indices_.resize (vars.size(), 0);
|
||||
ranges_.reserve (vars.size());
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
ranges_.push_back (vars[i]->range());
|
||||
size_ *= vars[i]->range();
|
||||
}
|
||||
li_ = 0;
|
||||
if (calcOffsets) {
|
||||
calculateOffsets();
|
||||
}
|
||||
}
|
||||
|
||||
void increment (void)
|
||||
{
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
indices_[i] ++;
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
}
|
||||
}
|
||||
li_ ++;
|
||||
}
|
||||
|
||||
void increment (unsigned dim)
|
||||
{
|
||||
assert (dim < ranges_.size());
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
assert (indices_[dim] < ranges_[dim]);
|
||||
indices_[dim] ++;
|
||||
li_ += offsets_[dim];
|
||||
}
|
||||
|
||||
void incrementExcluding (unsigned skipDim)
|
||||
{
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
if (i != (int)skipDim) {
|
||||
indices_[i] ++;
|
||||
li_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
li_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
li_ = size_;
|
||||
}
|
||||
|
||||
unsigned linearIndex (void) const
|
||||
{
|
||||
return li_;
|
||||
}
|
||||
|
||||
const vector<unsigned>& indices (void) const
|
||||
{
|
||||
return indices_;
|
||||
}
|
||||
|
||||
StatesIndexer& operator ++ (void)
|
||||
{
|
||||
increment();
|
||||
return *this;
|
||||
}
|
||||
|
||||
operator unsigned (void) const
|
||||
{
|
||||
return li_;
|
||||
}
|
||||
|
||||
unsigned operator[] (unsigned dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return li_ < size_;
|
||||
}
|
||||
|
||||
void reset (void)
|
||||
{
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
li_ = 0;
|
||||
}
|
||||
|
||||
void reset (unsigned dim)
|
||||
{
|
||||
indices_[dim] = 0;
|
||||
li_ -= offsets_[dim] * ranges_[dim];
|
||||
}
|
||||
|
||||
unsigned size (void) const
|
||||
{
|
||||
return size_ ;
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
|
||||
{
|
||||
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
||||
os << idx.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
void calculateOffsets (void)
|
||||
{
|
||||
unsigned prod = 1;
|
||||
offsets_.resize (ranges_.size());
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
}
|
||||
|
||||
unsigned li_;
|
||||
unsigned size_;
|
||||
vector<unsigned> indices_;
|
||||
vector<unsigned> ranges_;
|
||||
vector<unsigned> offsets_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class MapIndexer
|
||||
{
|
||||
public:
|
||||
MapIndexer (const Ranges& ranges, const vector<bool>& mapDims)
|
||||
{
|
||||
assert (ranges.size() == mapDims.size());
|
||||
unsigned prod = 1;
|
||||
offsets_.resize (ranges.size());
|
||||
for (int i = ranges.size() - 1; i >= 0; i--) {
|
||||
if (mapDims[i]) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
indices_.resize (ranges.size(), 0);
|
||||
ranges_ = ranges;
|
||||
index_ = 0;
|
||||
valid_ = true;
|
||||
}
|
||||
|
||||
MapIndexer (const Ranges& ranges, unsigned ignoreDim)
|
||||
{
|
||||
unsigned prod = 1;
|
||||
offsets_.resize (ranges.size());
|
||||
for (int i = ranges.size() - 1; i >= 0; i--) {
|
||||
if (i != (int)ignoreDim) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
indices_.resize (ranges.size(), 0);
|
||||
ranges_ = ranges;
|
||||
index_ = 0;
|
||||
valid_ = true;
|
||||
}
|
||||
|
||||
/*
|
||||
MapIndexer (
|
||||
const VarIds& loopVids,
|
||||
const Ranges& loopRanges,
|
||||
const VarIds& mapVids,
|
||||
const Ranges& mapRanges)
|
||||
{
|
||||
unsigned prod = 1;
|
||||
vector<unsigned> offsets (mapRanges.size());
|
||||
for (int i = mapRanges.size() - 1; i >= 0; i--) {
|
||||
offsets[i] = prod;
|
||||
prod *= mapRanges[i];
|
||||
}
|
||||
|
||||
offsets_.reserve (loopVids.size());
|
||||
for (unsigned i = 0; i < loopVids.size(); i++) {
|
||||
VarIds::const_iterator it =
|
||||
std::find (mapVids.begin(), mapVids.end(), loopVids[i]);
|
||||
if (it != mapVids.end()) {
|
||||
offsets_.push_back (offsets[it - mapVids.begin()]);
|
||||
} else {
|
||||
offsets_.push_back (0);
|
||||
}
|
||||
}
|
||||
|
||||
indices_.resize (loopVids.size(), 0);
|
||||
ranges_ = loopRanges;
|
||||
index_ = 0;
|
||||
size_ = prod;
|
||||
}
|
||||
*/
|
||||
|
||||
MapIndexer& operator ++ (void)
|
||||
{
|
||||
assert (valid_);
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
indices_[i] ++;
|
||||
index_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return *this;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
index_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
|
||||
unsigned mappedIndex (void) const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
operator unsigned (void) const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
unsigned operator[] (unsigned dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return valid_;
|
||||
}
|
||||
|
||||
void reset (void)
|
||||
{
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
index_ = 0;
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
|
||||
{
|
||||
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
||||
os << idx.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned index_;
|
||||
bool valid_;
|
||||
vector<unsigned> ranges_;
|
||||
vector<unsigned> indices_;
|
||||
vector<unsigned> offsets_;
|
||||
};
|
||||
|
||||
|
||||
#endif // HORUS_STATESINDEXER_H
|
||||
|
@@ -1,131 +0,0 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "LiftedUtils.h"
|
||||
#include "ConstraintTree.h"
|
||||
|
||||
|
||||
namespace LiftedUtils {
|
||||
|
||||
|
||||
unordered_map<string, unsigned> symbolDict;
|
||||
|
||||
|
||||
Symbol
|
||||
getSymbol (const string& symbolName)
|
||||
{
|
||||
unordered_map<string, unsigned>::iterator it
|
||||
= symbolDict.find (symbolName);
|
||||
if (it != symbolDict.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
symbolDict[symbolName] = symbolDict.size() - 1;
|
||||
return symbolDict.size() - 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printSymbolDictionary (void)
|
||||
{
|
||||
unordered_map<string, unsigned>::const_iterator it
|
||||
= symbolDict.begin();
|
||||
while (it != symbolDict.end()) {
|
||||
cout << it->first << " -> " << it->second << endl;
|
||||
it ++;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Symbol& s)
|
||||
{
|
||||
unordered_map<string, unsigned>::const_iterator it
|
||||
= LiftedUtils::symbolDict.begin();
|
||||
while (it != LiftedUtils::symbolDict.end() && it->second != s) {
|
||||
it ++;
|
||||
}
|
||||
assert (it != LiftedUtils::symbolDict.end());
|
||||
os << it->first;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const LogVar& X)
|
||||
{
|
||||
const string labels[] = {
|
||||
"A", "B", "C", "D", "E", "F",
|
||||
"G", "H", "I", "J", "K", "M" };
|
||||
(X >= 12) ? os << "X_" << X.id_ : os << labels[X];
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Tuple& t)
|
||||
{
|
||||
os << "(" ;
|
||||
for (unsigned i = 0; i < t.size(); i++) {
|
||||
os << ((i != 0) ? "," : "") << t[i];
|
||||
}
|
||||
os << ")" ;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Ground& gr)
|
||||
{
|
||||
os << gr.functor();
|
||||
os << "(" ;
|
||||
for (unsigned i = 0; i < gr.args().size(); i++) {
|
||||
if (i != 0) os << ", " ;
|
||||
os << gr.args()[i];
|
||||
}
|
||||
os << ")" ;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LogVars
|
||||
Substitution::getDiscardedLogVars (void) const
|
||||
{
|
||||
LogVars discardedLvs;
|
||||
set<LogVar> doneLvs;
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.begin();
|
||||
while (it != subs_.end()) {
|
||||
if (Util::contains (doneLvs, it->second)) {
|
||||
discardedLvs.push_back (it->first);
|
||||
} else {
|
||||
doneLvs.insert (it->second);
|
||||
}
|
||||
it ++;
|
||||
}
|
||||
return discardedLvs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Substitution& theta)
|
||||
{
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
os << "[" ;
|
||||
it = theta.subs_.begin();
|
||||
while (it != theta.subs_.end()) {
|
||||
if (it != theta.subs_.begin()) os << ", " ;
|
||||
os << it->first << "->" << it->second ;
|
||||
++ it;
|
||||
}
|
||||
os << "]" ;
|
||||
return os;
|
||||
}
|
||||
|
@@ -1,155 +0,0 @@
|
||||
#ifndef HORUS_LIFTEDUTILS_H
|
||||
#define HORUS_LIFTEDUTILS_H
|
||||
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
#include "TinySet.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class Symbol
|
||||
{
|
||||
public:
|
||||
Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||
|
||||
Symbol (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
|
||||
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
|
||||
|
||||
static Symbol invalid (void) { return Symbol(); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Symbol& s);
|
||||
|
||||
private:
|
||||
unsigned id_;
|
||||
};
|
||||
|
||||
|
||||
class LogVar
|
||||
{
|
||||
public:
|
||||
LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||
|
||||
LogVar (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
|
||||
LogVar& operator++ (void)
|
||||
{
|
||||
assert (valid());
|
||||
id_ ++;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return id_ != numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &os, const LogVar& X);
|
||||
|
||||
private:
|
||||
unsigned id_;
|
||||
};
|
||||
|
||||
|
||||
namespace std {
|
||||
template <> struct hash<Symbol> {
|
||||
size_t operator() (const Symbol& s) const {
|
||||
return std::hash<unsigned>() (s);
|
||||
}};
|
||||
|
||||
template <> struct hash<LogVar> {
|
||||
size_t operator() (const LogVar& X) const {
|
||||
return std::hash<unsigned>() (X);
|
||||
}};
|
||||
};
|
||||
|
||||
|
||||
typedef vector<Symbol> Symbols;
|
||||
typedef vector<Symbol> Tuple;
|
||||
typedef vector<Tuple> Tuples;
|
||||
typedef vector<LogVar> LogVars;
|
||||
typedef TinySet<Symbol> SymbolSet;
|
||||
typedef TinySet<LogVar> LogVarSet;
|
||||
typedef TinySet<Tuple> TupleSet;
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Tuple& t);
|
||||
|
||||
|
||||
namespace LiftedUtils {
|
||||
Symbol getSymbol (const string&);
|
||||
void printSymbolDictionary (void);
|
||||
}
|
||||
|
||||
|
||||
|
||||
class Ground
|
||||
{
|
||||
public:
|
||||
Ground (Symbol f) : functor_(f) { }
|
||||
|
||||
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
Symbols args (void) const { return args_; }
|
||||
|
||||
unsigned arity (void) const { return args_.size(); }
|
||||
|
||||
bool isAtom (void) const { return args_.size() == 0; }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Ground& gr);
|
||||
|
||||
private:
|
||||
Symbol functor_;
|
||||
Symbols args_;
|
||||
};
|
||||
|
||||
typedef vector<Ground> Grounds;
|
||||
|
||||
|
||||
|
||||
class Substitution
|
||||
{
|
||||
public:
|
||||
void add (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old) == false);
|
||||
subs_.insert (make_pair (X_old, X_new));
|
||||
}
|
||||
|
||||
void rename (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old));
|
||||
subs_.find (X_old)->second = X_new;
|
||||
}
|
||||
|
||||
LogVar newNameFor (LogVar X) const
|
||||
{
|
||||
assert (Util::contains (subs_, X));
|
||||
return subs_.find (X)->second;
|
||||
}
|
||||
|
||||
LogVars getDiscardedLogVars (void) const;
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||
|
||||
private:
|
||||
unordered_map<LogVar, LogVar> subs_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
#endif // HORUS_LIFTEDUTILS_H
|
||||
|
@@ -1,177 +0,0 @@
|
||||
#
|
||||
# default base directory for YAP installation
|
||||
# (EROOT for architecture-dependent files)
|
||||
#
|
||||
GCC = @GCC@
|
||||
prefix = @prefix@
|
||||
exec_prefix = @exec_prefix@
|
||||
ROOTDIR = $(prefix)
|
||||
EROOTDIR = @exec_prefix@
|
||||
abs_top_builddir = @abs_top_builddir@
|
||||
#
|
||||
# where the binary should be
|
||||
#
|
||||
BINDIR = $(EROOTDIR)/bin
|
||||
#
|
||||
# where YAP should look for libraries
|
||||
#
|
||||
LIBDIR=@libdir@
|
||||
YAPLIBDIR=@libdir@/Yap
|
||||
#
|
||||
#
|
||||
CC=@CC@
|
||||
CXX=@CXX@
|
||||
|
||||
# normal
|
||||
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -DNDEBUG
|
||||
|
||||
# debug
|
||||
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -g -O0 -Wextra
|
||||
|
||||
|
||||
#
|
||||
#
|
||||
# You shouldn't need to change what follows.
|
||||
#
|
||||
INSTALL=@INSTALL@
|
||||
INSTALL_DATA=@INSTALL_DATA@
|
||||
INSTALL_PROGRAM=@INSTALL_PROGRAM@
|
||||
SHELL=/bin/sh
|
||||
RANLIB=@RANLIB@
|
||||
srcdir=@srcdir@
|
||||
SO=@SO@
|
||||
#4.1VPATH=@srcdir@:@srcdir@/OPTYap
|
||||
CWD=$(PWD)
|
||||
|
||||
|
||||
HEADERS = \
|
||||
$(srcdir)/BayesNet.h \
|
||||
$(srcdir)/BayesBall.h \
|
||||
$(srcdir)/ElimGraph.h \
|
||||
$(srcdir)/FactorGraph.h \
|
||||
$(srcdir)/Factor.h \
|
||||
$(srcdir)/CFactorGraph.h \
|
||||
$(srcdir)/ConstraintTree.h \
|
||||
$(srcdir)/Solver.h \
|
||||
$(srcdir)/VarElimSolver.h \
|
||||
$(srcdir)/BpSolver.h \
|
||||
$(srcdir)/CbpSolver.h \
|
||||
$(srcdir)/FoveSolver.h \
|
||||
$(srcdir)/Var.h \
|
||||
$(srcdir)/Indexer.h \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
$(srcdir)/Histogram.h \
|
||||
$(srcdir)/ParfactorList.h \
|
||||
$(srcdir)/LiftedUtils.h \
|
||||
$(srcdir)/TinySet.h \
|
||||
$(srcdir)/Util.h \
|
||||
$(srcdir)/Horus.h
|
||||
|
||||
CPP_SOURCES = \
|
||||
$(srcdir)/BayesNet.cpp \
|
||||
$(srcdir)/BayesBall.cpp \
|
||||
$(srcdir)/ElimGraph.cpp \
|
||||
$(srcdir)/FactorGraph.cpp \
|
||||
$(srcdir)/Factor.cpp \
|
||||
$(srcdir)/CFactorGraph.cpp \
|
||||
$(srcdir)/ConstraintTree.cpp \
|
||||
$(srcdir)/Var.cpp \
|
||||
$(srcdir)/Solver.cpp \
|
||||
$(srcdir)/VarElimSolver.cpp \
|
||||
$(srcdir)/BpSolver.cpp \
|
||||
$(srcdir)/CbpSolver.cpp \
|
||||
$(srcdir)/FoveSolver.cpp \
|
||||
$(srcdir)/Parfactor.cpp \
|
||||
$(srcdir)/ProbFormula.cpp \
|
||||
$(srcdir)/Histogram.cpp \
|
||||
$(srcdir)/ParfactorList.cpp \
|
||||
$(srcdir)/LiftedUtils.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/HorusYap.cpp \
|
||||
$(srcdir)/HorusCli.cpp
|
||||
|
||||
OBJS = \
|
||||
BayesNet.o \
|
||||
BayesBall.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
CFactorGraph.o \
|
||||
ConstraintTree.o \
|
||||
Var.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
BpSolver.o \
|
||||
CbpSolver.o \
|
||||
FoveSolver.o \
|
||||
Parfactor.o \
|
||||
ProbFormula.o \
|
||||
Histogram.o \
|
||||
ParfactorList.o \
|
||||
LiftedUtils.o \
|
||||
Util.o \
|
||||
HorusYap.o
|
||||
|
||||
HCLI_OBJS = \
|
||||
BayesNet.o \
|
||||
BayesBall.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
CFactorGraph.o \
|
||||
ConstraintTree.o \
|
||||
Var.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
BpSolver.o \
|
||||
CbpSolver.o \
|
||||
FoveSolver.o \
|
||||
Parfactor.o \
|
||||
ProbFormula.o \
|
||||
Histogram.o \
|
||||
ParfactorList.o \
|
||||
LiftedUtils.o \
|
||||
Util.o \
|
||||
HorusCli.o
|
||||
|
||||
SOBJS=horus.@SO@
|
||||
|
||||
|
||||
all: $(SOBJS) hcli
|
||||
|
||||
# default rule
|
||||
%.o : $(srcdir)/%.cpp
|
||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||
|
||||
|
||||
@DO_SECOND_LD@horus.@SO@: $(OBJS)
|
||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||
|
||||
|
||||
hcli: $(HCLI_OBJS)
|
||||
$(CXX) -o hcli $(HCLI_OBJS)
|
||||
|
||||
|
||||
install: all
|
||||
$(INSTALL_PROGRAM) $(SOBJS) $(DESTDIR) $(YAPLIBDIR)
|
||||
|
||||
|
||||
clean:
|
||||
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli
|
||||
|
||||
|
||||
erase_dots:
|
||||
rm -f *.dot *.png
|
||||
|
||||
|
||||
depend: $(HEADERS) $(CPP_SOURCES)
|
||||
-@if test "$(GCC)" = yes; then\
|
||||
$(CC) -std=c++0x -MM -MG $(CFLAGS) -I$(srcdir) -I$(srcdir)/../../../../include -I$(srcdir)/../../../../H $(CPP_SOURCES) >> Makefile;\
|
||||
else\
|
||||
makedepend -f - -- $(CFLAGS) -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include -- $(CPP_SOURCES) |\
|
||||
sed 's|.*/\([^:]*\):|\1:|' >> Makefile ;\
|
||||
fi
|
||||
|
||||
# DO NOT DELETE THIS LINE -- make depend depends on it.
|
||||
|
@@ -1,685 +0,0 @@
|
||||
|
||||
#include "Parfactor.h"
|
||||
#include "Histogram.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
Parfactor::Parfactor (
|
||||
const ProbFormulas& formulas,
|
||||
const Params& params,
|
||||
const Tuples& tuples,
|
||||
unsigned distId)
|
||||
{
|
||||
args_ = formulas;
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
|
||||
LogVars logVars;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
ranges_.push_back (args_[i].range());
|
||||
const LogVars& lvs = args_[i].logVars();
|
||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||
if (Util::contains (logVars, lvs[j]) == false) {
|
||||
logVars.push_back (lvs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
constr_ = new ConstraintTree (logVars, tuples);
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
||||
{
|
||||
args_ = g->arguments();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
||||
{
|
||||
args_ = g->arguments();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = constr;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::Parfactor (const Parfactor& g)
|
||||
{
|
||||
args_ = g.arguments();
|
||||
params_ = g.params();
|
||||
ranges_ = g.ranges();
|
||||
distId_ = g.distId();
|
||||
constr_ = new ConstraintTree (*g.constr());
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::~Parfactor (void)
|
||||
{
|
||||
delete constr_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::countedLogVars (void) const
|
||||
{
|
||||
LogVarSet set;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].isCounting()) {
|
||||
set.insert (args_[i].countedLogVar());
|
||||
}
|
||||
}
|
||||
return set;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::uncountedLogVars (void) const
|
||||
{
|
||||
return constr_->logVarSet() - countedLogVars();
|
||||
}
|
||||
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::elimLogVars (void) const
|
||||
{
|
||||
LogVarSet requiredToElim = constr_->logVarSet();
|
||||
requiredToElim -= constr_->singletons();
|
||||
requiredToElim -= countedLogVars();
|
||||
return requiredToElim;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::exclusiveLogVars (unsigned fIdx) const
|
||||
{
|
||||
assert (fIdx < args_.size());
|
||||
LogVarSet remaining;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != fIdx) {
|
||||
remaining |= args_[i].logVarSet();
|
||||
}
|
||||
}
|
||||
return args_[fIdx].logVarSet() - remaining;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setConstraintTree (ConstraintTree* newTree)
|
||||
{
|
||||
delete constr_;
|
||||
constr_ = newTree;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::sumOut (unsigned fIdx)
|
||||
{
|
||||
assert (fIdx < args_.size());
|
||||
assert (args_[fIdx].contains (elimLogVars()));
|
||||
|
||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||
if (args_[fIdx].isCounting()) {
|
||||
LogAware::pow (params_, constr_->getConditionalCount (
|
||||
excl - args_[fIdx].countedLogVar()));
|
||||
} else {
|
||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||
}
|
||||
|
||||
if (args_[fIdx].isCounting()) {
|
||||
unsigned N = constr_->getConditionalCount (
|
||||
args_[fIdx].countedLogVar());
|
||||
unsigned R = args_[fIdx].range();
|
||||
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||
StatesIndexer sindexer (ranges_, fIdx);
|
||||
while (sindexer.valid()) {
|
||||
unsigned h = sindexer[fIdx];
|
||||
if (Globals::logDomain) {
|
||||
params_[sindexer] += numAssigns[h];
|
||||
} else {
|
||||
params_[sindexer] *= numAssigns[h];
|
||||
}
|
||||
++ sindexer;
|
||||
}
|
||||
}
|
||||
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
|
||||
MapIndexer indexer (ranges_, fIdx);
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
|
||||
++ indexer;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
params_[indexer] += copy[i];
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
|
||||
args_.erase (args_.begin() + fIdx);
|
||||
ranges_.erase (ranges_.begin() + fIdx);
|
||||
constr_->remove (excl);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::multiply (Parfactor& g)
|
||||
{
|
||||
alignAndExponentiate (this, &g);
|
||||
TFactor<ProbFormula>::multiply (g);
|
||||
constr_->join (g.constr(), true);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::countConvert (LogVar X)
|
||||
{
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (constr_->isCountNormalized (X));
|
||||
assert (constr_->getConditionalCount (X) > 1);
|
||||
assert (constr_->isCarteesianProduct (countedLogVars() | X));
|
||||
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = ranges_[fIdx];
|
||||
unsigned H = HistogramSet::nrHistograms (N, R);
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
|
||||
StatesIndexer indexer (ranges_);
|
||||
vector<Params> sumout (params_.size() / R);
|
||||
unsigned count = 0;
|
||||
while (indexer.valid()) {
|
||||
sumout[count].reserve (R);
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
sumout[count].push_back (params_[indexer]);
|
||||
indexer.increment (fIdx);
|
||||
}
|
||||
count ++;
|
||||
indexer.reset (fIdx);
|
||||
indexer.incrementExcluding (fIdx);
|
||||
}
|
||||
|
||||
params_.clear();
|
||||
params_.reserve (sumout.size() * H);
|
||||
|
||||
ranges_[fIdx] = H;
|
||||
MapIndexer mapIndexer (ranges_, fIdx);
|
||||
while (mapIndexer.valid()) {
|
||||
double prod = LogAware::multIdenty();
|
||||
unsigned i = mapIndexer.mappedIndex();
|
||||
unsigned h = mapIndexer[fIdx];
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
if (Globals::logDomain) {
|
||||
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||
} else {
|
||||
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||
}
|
||||
}
|
||||
params_.push_back (prod);
|
||||
++ mapIndexer;
|
||||
}
|
||||
args_[fIdx].setCountedLogVar (X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
|
||||
{
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (args_[fIdx].isCounting());
|
||||
|
||||
unsigned N1 = constr_->getConditionalCount (X_new1);
|
||||
unsigned N2 = constr_->getConditionalCount (X_new2);
|
||||
unsigned N = N1 + N2;
|
||||
unsigned R = args_[fIdx].range();
|
||||
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
||||
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
||||
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||
vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
|
||||
|
||||
vector<unsigned> sumIndexes;
|
||||
sumIndexes.reserve (H1 * H2);
|
||||
for (unsigned i = 0; i < H1; i++) {
|
||||
for (unsigned j = 0; j < H2; j++) {
|
||||
Histogram hist = histograms1[i];
|
||||
std::transform (
|
||||
hist.begin(), hist.end(),
|
||||
histograms2[j].begin(),
|
||||
hist.begin(),
|
||||
plus<int>());
|
||||
sumIndexes.push_back (HistogramSet::findIndex (hist, histograms));
|
||||
}
|
||||
}
|
||||
|
||||
expandPotential (fIdx, H1 * H2, sumIndexes);
|
||||
|
||||
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
|
||||
args_[fIdx].rename (X, X_new1);
|
||||
args_[fIdx + 1].rename (X, X_new2);
|
||||
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
|
||||
ranges_[fIdx] = H1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::fullExpand (LogVar X)
|
||||
{
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (args_[fIdx].isCounting());
|
||||
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = args_[fIdx].range();
|
||||
|
||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||
|
||||
vector<unsigned> sumIndexes;
|
||||
sumIndexes.reserve (N * R);
|
||||
|
||||
Ranges expandRanges (N, R);
|
||||
StatesIndexer indexer (expandRanges);
|
||||
while (indexer.valid()) {
|
||||
vector<unsigned> hist (R, 0);
|
||||
for (unsigned n = 0; n < N; n++) {
|
||||
std::transform (
|
||||
hist.begin(), hist.end(),
|
||||
expandHists[indexer[n]].begin(),
|
||||
hist.begin(),
|
||||
plus<int>());
|
||||
}
|
||||
sumIndexes.push_back (HistogramSet::findIndex (hist, originHists));
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
expandPotential (fIdx, std::pow (R, N), sumIndexes);
|
||||
|
||||
ProbFormula f = args_[fIdx];
|
||||
args_.erase (args_.begin() + fIdx);
|
||||
ranges_.erase (ranges_.begin() + fIdx);
|
||||
LogVars newLvs = constr_->expand (X);
|
||||
assert (newLvs.size() == N);
|
||||
for (unsigned i = 0 ; i < N; i++) {
|
||||
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
|
||||
newFormula.rename (X, newLvs[i]);
|
||||
args_.insert (args_.begin() + fIdx + i, newFormula);
|
||||
ranges_.insert (ranges_.begin() + fIdx + i, R);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::reorderAccordingGrounds (const Grounds& grounds)
|
||||
{
|
||||
ProbFormulas newFormulas;
|
||||
for (unsigned i = 0; i < grounds.size(); i++) {
|
||||
for (unsigned j = 0; j < args_.size(); j++) {
|
||||
if (grounds[i].functor() == args_[j].functor() &&
|
||||
grounds[i].arity() == args_[j].arity()) {
|
||||
constr_->moveToTop (args_[j].logVars());
|
||||
if (constr_->containsTuple (grounds[i].args())) {
|
||||
newFormulas.push_back (args_[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
assert (newFormulas.size() == i + 1);
|
||||
}
|
||||
reorderArguments (newFormulas);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
||||
{
|
||||
int fIdx = indexOf (formula);
|
||||
assert (fIdx != -1);
|
||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||
assert (args_[fIdx].isCounting() == false);
|
||||
assert (constr_->isCountNormalized (excl));
|
||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||
constr_->remove (excl);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setNewGroups (void)
|
||||
{
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
args_[i].setGroup (ProbFormula::getNewGroup());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::applySubstitution (const Substitution& theta)
|
||||
{
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
LogVars& lvs = args_[i].logVars();
|
||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||
lvs[j] = theta.newNameFor (lvs[j]);
|
||||
}
|
||||
if (args_[i].isCounting()) {
|
||||
LogVar clv = args_[i].countedLogVar();
|
||||
args_[i].setCountedLogVar (theta.newNameFor (clv));
|
||||
}
|
||||
}
|
||||
constr_->applySubstitution (theta);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Parfactor::findGroup (const Ground& ground) const
|
||||
{
|
||||
int group = -1;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].functor() == ground.functor() &&
|
||||
args_[i].arity() == ground.arity()) {
|
||||
constr_->moveToTop (args_[i].logVars());
|
||||
if (constr_->containsTuple (ground.args())) {
|
||||
group = args_[i].group();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return group;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGround (const Ground& ground) const
|
||||
{
|
||||
return findGroup (ground) != -1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGroup (unsigned group) const
|
||||
{
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].group() == group) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Parfactor::nrFormulas (LogVar X) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].contains (X)) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Parfactor::indexOfLogVar (LogVar X) const
|
||||
{
|
||||
int idx = -1;
|
||||
assert (nrFormulas (X) == 1);
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].contains (X)) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Parfactor::indexOfGroup (unsigned group) const
|
||||
{
|
||||
int pos = -1;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].group() == group) {
|
||||
pos = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
Parfactor::getAllGroups (void) const
|
||||
{
|
||||
vector<unsigned> groups (args_.size());
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
groups[i] = args_[i].group();
|
||||
}
|
||||
return groups;
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Parfactor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "phi(" ;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << args_[i];
|
||||
}
|
||||
ss << ")" ;
|
||||
ConstraintTree copy (*constr_);
|
||||
copy.moveToTop (copy.logVarSet().elements());
|
||||
ss << "|" << copy.tupleSet();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::print (bool printParams) const
|
||||
{
|
||||
cout << "Formulas: " ;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
cout << args_[i];
|
||||
}
|
||||
cout << endl;
|
||||
if (args_[0].group() != Util::maxUnsigned()) {
|
||||
vector<string> groups;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
||||
}
|
||||
cout << "Groups: " << groups << endl;
|
||||
}
|
||||
cout << "LogVars: " << constr_->logVarSet() << endl;
|
||||
cout << "Ranges: " << ranges_ << endl;
|
||||
if (printParams == false) {
|
||||
cout << "Params: " << params_ << endl;
|
||||
}
|
||||
ConstraintTree copy (*constr_);
|
||||
copy.moveToTop (copy.logVarSet().elements());
|
||||
cout << "Tuples: " << copy.tupleSet() << endl;
|
||||
if (printParams) {
|
||||
vector<string> jointStrings;
|
||||
StatesIndexer indexer (ranges_);
|
||||
while (indexer.valid()) {
|
||||
stringstream ss;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
if (args_[i].isCounting()) {
|
||||
unsigned N = constr_->getConditionalCount (
|
||||
args_[i].countedLogVar());
|
||||
HistogramSet hs (N, args_[i].range());
|
||||
unsigned c = 0;
|
||||
while (c < indexer[i]) {
|
||||
hs.nextHistogram();
|
||||
c ++;
|
||||
}
|
||||
ss << hs;
|
||||
} else {
|
||||
ss << indexer[i];
|
||||
}
|
||||
}
|
||||
jointStrings.push_back (ss.str());
|
||||
++ indexer;
|
||||
}
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << params_[i] << endl;
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::expandPotential (
|
||||
int fIdx,
|
||||
unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes)
|
||||
{
|
||||
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (size);
|
||||
|
||||
unsigned prod = 1;
|
||||
vector<unsigned> offsets_ (ranges_.size());
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
|
||||
unsigned index = 0;
|
||||
ranges_[fIdx] = newRange;
|
||||
vector<unsigned> indices (ranges_.size(), 0);
|
||||
for (unsigned k = 0; k < size; k++) {
|
||||
params_.push_back (copy[index]);
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
indices[i] ++;
|
||||
if (i == fIdx) {
|
||||
assert (indices[i] - 1 < sumIndexes.size());
|
||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index += offsets_[i];
|
||||
}
|
||||
if (indices[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
LogVars X_1, X_2;
|
||||
const ProbFormulas& formulas1 = g1->arguments();
|
||||
const ProbFormulas& formulas2 = g2->arguments();
|
||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||
if (formulas1[i].group() == formulas2[j].group()) {
|
||||
Util::addToVector (X_1, formulas1[i].logVars());
|
||||
Util::addToVector (X_2, formulas2[j].logVars());
|
||||
}
|
||||
}
|
||||
}
|
||||
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
||||
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
||||
assert (g1->constr()->isCountNormalized (Y_1));
|
||||
assert (g2->constr()->isCountNormalized (Y_2));
|
||||
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
||||
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
||||
LogAware::pow (g1->params(), 1.0 / condCount2);
|
||||
LogAware::pow (g2->params(), 1.0 / condCount1);
|
||||
// this must be done in the end or else X_1 and X_2
|
||||
// will refer the old log var names in the code above
|
||||
align (g1, X_1, g2, X_2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::align (
|
||||
Parfactor* g1, const LogVars& alignLvs1,
|
||||
Parfactor* g2, const LogVars& alignLvs2)
|
||||
{
|
||||
LogVar freeLogVar = 0;
|
||||
Substitution theta1;
|
||||
Substitution theta2;
|
||||
const LogVarSet& allLvs1 = g1->logVarSet();
|
||||
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
||||
theta1.add (allLvs1[i], freeLogVar);
|
||||
++ freeLogVar;
|
||||
}
|
||||
|
||||
const LogVarSet& allLvs2 = g2->logVarSet();
|
||||
for (unsigned i = 0; i < allLvs2.size(); i++) {
|
||||
theta2.add (allLvs2[i], freeLogVar);
|
||||
++ freeLogVar;
|
||||
}
|
||||
|
||||
assert (alignLvs1.size() == alignLvs2.size());
|
||||
for (unsigned i = 0; i < alignLvs1.size(); i++) {
|
||||
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
|
||||
}
|
||||
g1->applySubstitution (theta1);
|
||||
g2->applySubstitution (theta2);
|
||||
}
|
||||
|
@@ -1,101 +0,0 @@
|
||||
#ifndef HORUS_PARFACTOR_H
|
||||
#define HORUS_PARFACTOR_H
|
||||
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ProbFormula.h"
|
||||
#include "ConstraintTree.h"
|
||||
#include "LiftedUtils.h"
|
||||
#include "Horus.h"
|
||||
|
||||
#include "Factor.h"
|
||||
|
||||
class Parfactor : public TFactor<ProbFormula>
|
||||
{
|
||||
public:
|
||||
Parfactor (
|
||||
const ProbFormulas&,
|
||||
const Params&,
|
||||
const Tuples&,
|
||||
unsigned);
|
||||
|
||||
Parfactor (const Parfactor*, const Tuple&);
|
||||
|
||||
Parfactor (const Parfactor*, ConstraintTree*);
|
||||
|
||||
Parfactor (const Parfactor&);
|
||||
|
||||
~Parfactor (void);
|
||||
|
||||
ConstraintTree* constr (void) { return constr_; }
|
||||
|
||||
const ConstraintTree* constr (void) const { return constr_; }
|
||||
|
||||
const LogVars& logVars (void) const { return constr_->logVars(); }
|
||||
|
||||
const LogVarSet& logVarSet (void) const { return constr_->logVarSet(); }
|
||||
|
||||
LogVarSet countedLogVars (void) const;
|
||||
|
||||
LogVarSet uncountedLogVars (void) const;
|
||||
|
||||
LogVarSet elimLogVars (void) const;
|
||||
|
||||
LogVarSet exclusiveLogVars (unsigned) const;
|
||||
|
||||
void setConstraintTree (ConstraintTree*);
|
||||
|
||||
void sumOut (unsigned fIdx);
|
||||
|
||||
void multiply (Parfactor&);
|
||||
|
||||
void countConvert (LogVar);
|
||||
|
||||
void expand (LogVar, LogVar, LogVar);
|
||||
|
||||
void fullExpand (LogVar);
|
||||
|
||||
void reorderAccordingGrounds (const Grounds&);
|
||||
|
||||
void absorveEvidence (const ProbFormula&, unsigned);
|
||||
|
||||
void setNewGroups (void);
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
int findGroup (const Ground&) const;
|
||||
|
||||
bool containsGround (const Ground&) const;
|
||||
|
||||
bool containsGroup (unsigned) const;
|
||||
|
||||
unsigned nrFormulas (LogVar) const;
|
||||
|
||||
int indexOfLogVar (LogVar) const;
|
||||
|
||||
int indexOfGroup (unsigned) const;
|
||||
|
||||
vector<unsigned> getAllGroups (void) const;
|
||||
|
||||
void print (bool = false) const;
|
||||
|
||||
string getLabel (void) const;
|
||||
|
||||
private:
|
||||
void expandPotential (int fIdx, unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes);
|
||||
|
||||
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||
|
||||
static void align (
|
||||
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
|
||||
|
||||
ConstraintTree* constr_;
|
||||
};
|
||||
|
||||
|
||||
typedef vector<Parfactor*> Parfactors;
|
||||
|
||||
#endif // HORUS_PARFACTOR_H
|
||||
|
@@ -1,427 +0,0 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
ParfactorList::ParfactorList (const ParfactorList& pfList)
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
addShattered (new Parfactor (**it));
|
||||
++ it;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParfactorList::ParfactorList (const Parfactors& pfs)
|
||||
{
|
||||
add (pfs);
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParfactorList::~ParfactorList (void)
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
delete *it;
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::add (Parfactor* pf)
|
||||
{
|
||||
pf->setNewGroups();
|
||||
addToShatteredList (pf);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::add (const Parfactors& pfs)
|
||||
{
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
pfs[i]->setNewGroups();
|
||||
addToShatteredList (pfs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::addShattered (Parfactor* pf)
|
||||
{
|
||||
assert (isAllShattered());
|
||||
pfList_.push_back (pf);
|
||||
assert (isAllShattered());
|
||||
}
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::insertShattered (
|
||||
list<Parfactor*>::iterator it,
|
||||
Parfactor* pf)
|
||||
{
|
||||
return pfList_.insert (it, pf);
|
||||
assert (isAllShattered());
|
||||
}
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::remove (list<Parfactor*>::iterator it)
|
||||
{
|
||||
return pfList_.erase (it);
|
||||
}
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||
{
|
||||
delete *it;
|
||||
return pfList_.erase (it);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::isAllShattered (void) const
|
||||
{
|
||||
if (pfList_.size() <= 1) {
|
||||
return true;
|
||||
}
|
||||
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
|
||||
for (unsigned i = 0; i < pfs.size() - 1; i++) {
|
||||
for (unsigned j = i + 1; j < pfs.size(); j++) {
|
||||
if (isShattered (pfs[i], pfs[j]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::print (void) const
|
||||
{
|
||||
list<Parfactor*>::const_iterator it;
|
||||
for (it = pfList_.begin(); it != pfList_.end(); ++it) {
|
||||
(*it)->print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::isShattered (
|
||||
const Parfactor* g1,
|
||||
const Parfactor* g2) const
|
||||
{
|
||||
assert (g1 != g2);
|
||||
const ProbFormulas& fms1 = g1->arguments();
|
||||
const ProbFormulas& fms2 = g2->arguments();
|
||||
for (unsigned i = 0; i < fms1.size(); i++) {
|
||||
for (unsigned j = 0; j < fms2.size(); j++) {
|
||||
if (fms1[i].group() == fms2[j].group()) {
|
||||
if (identical (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (disjoint (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::addToShatteredList (Parfactor* g)
|
||||
{
|
||||
queue<Parfactor*> residuals;
|
||||
residuals.push (g);
|
||||
while (residuals.empty() == false) {
|
||||
Parfactor* pf = residuals.front();
|
||||
bool pfSplitted = false;
|
||||
list<Parfactor*>::iterator pfIter;
|
||||
pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
std::pair<Parfactors, Parfactors> shattRes;
|
||||
shattRes = shatter (*pfIter, pf);
|
||||
if (shattRes.first.empty() == false) {
|
||||
pfIter = removeAndDelete (pfIter);
|
||||
Util::addToQueue (residuals, shattRes.first);
|
||||
} else {
|
||||
++ pfIter;
|
||||
}
|
||||
if (shattRes.second.empty() == false) {
|
||||
delete pf;
|
||||
Util::addToQueue (residuals, shattRes.second);
|
||||
pfSplitted = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
residuals.pop();
|
||||
if (pfSplitted == false) {
|
||||
addShattered (pf);
|
||||
}
|
||||
}
|
||||
assert (isAllShattered());
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::pair<Parfactors, Parfactors>
|
||||
ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
ProbFormulas& formulas1 = g1->arguments();
|
||||
ProbFormulas& formulas2 = g2->arguments();
|
||||
assert (g1 != 0 && g2 != 0 && g1 != g2);
|
||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
|
||||
std::pair<Parfactors, Parfactors> res;
|
||||
res = shatter (i, g1, j, g2);
|
||||
if (res.first.empty() == false ||
|
||||
res.second.empty() == false) {
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return make_pair (Parfactors(), Parfactors());
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::pair<Parfactors, Parfactors>
|
||||
ParfactorList::shatter (
|
||||
unsigned fIdx1, Parfactor* g1,
|
||||
unsigned fIdx2, Parfactor* g2)
|
||||
{
|
||||
ProbFormula& f1 = g1->argument (fIdx1);
|
||||
ProbFormula& f2 = g2->argument (fIdx2);
|
||||
// cout << endl;
|
||||
// Util::printDashedLine();
|
||||
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
||||
// g1->print();
|
||||
// cout << "-> WITH" << endl;
|
||||
// g2->print();
|
||||
// cout << "-> ON: " << f1 << "|" ;
|
||||
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
||||
// cout << "-> ON: " << f2 << "|" ;
|
||||
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
||||
// Util::printDashedLine();
|
||||
if (f1.isAtom()) {
|
||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
||||
f1.setGroup (group);
|
||||
f2.setGroup (group);
|
||||
return { };
|
||||
}
|
||||
assert (g1->constr()->empty() == false);
|
||||
assert (g2->constr()->empty() == false);
|
||||
if (f1.group() == f2.group()) {
|
||||
assert (identical (f1, *(g1->constr()), f2, *(g2->constr())));
|
||||
return { };
|
||||
}
|
||||
|
||||
g1->constr()->moveToTop (f1.logVars());
|
||||
g2->constr()->moveToTop (f2.logVars());
|
||||
|
||||
std::pair<ConstraintTree*,ConstraintTree*> split1 =
|
||||
g1->constr()->split (g2->constr(), f1.arity());
|
||||
ConstraintTree* commCt1 = split1.first;
|
||||
ConstraintTree* exclCt1 = split1.second;
|
||||
|
||||
if (commCt1->empty()) {
|
||||
// disjoint
|
||||
delete commCt1;
|
||||
delete exclCt1;
|
||||
return { };
|
||||
}
|
||||
|
||||
std::pair<ConstraintTree*,ConstraintTree*> split2 =
|
||||
g2->constr()->split (g1->constr(), f2.arity());
|
||||
ConstraintTree* commCt2 = split2.first;
|
||||
ConstraintTree* exclCt2 = split2.second;
|
||||
|
||||
assert (commCt1->tupleSet (f1.arity()) ==
|
||||
commCt2->tupleSet (f2.arity()));
|
||||
|
||||
// unsigned static count = 0; count ++;
|
||||
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
|
||||
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
|
||||
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
||||
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
|
||||
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
|
||||
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
||||
|
||||
if (exclCt1->empty() && exclCt2->empty()) {
|
||||
unsigned group = (f1.group() < f2.group())
|
||||
? f1.group()
|
||||
: f2.group();
|
||||
// identical
|
||||
f1.setGroup (group);
|
||||
f2.setGroup (group);
|
||||
// unifyGroups
|
||||
delete commCt1;
|
||||
delete exclCt1;
|
||||
delete commCt2;
|
||||
delete exclCt2;
|
||||
return { };
|
||||
}
|
||||
|
||||
unsigned group;
|
||||
if (exclCt1->empty()) {
|
||||
group = f1.group();
|
||||
} else if (exclCt2->empty()) {
|
||||
group = f2.group();
|
||||
} else {
|
||||
group = ProbFormula::getNewGroup();
|
||||
}
|
||||
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
|
||||
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
|
||||
return make_pair (res1, res2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactors
|
||||
ParfactorList::shatter (
|
||||
Parfactor* g,
|
||||
unsigned fIdx,
|
||||
ConstraintTree* commCt,
|
||||
ConstraintTree* exclCt,
|
||||
unsigned commGroup)
|
||||
{
|
||||
ProbFormula& f = g->argument (fIdx);
|
||||
if (exclCt->empty()) {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
f.setGroup (commGroup);
|
||||
return { };
|
||||
}
|
||||
|
||||
Parfactors result;
|
||||
if (f.isCounting()) {
|
||||
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
|
||||
LogVar X_new2 = g->constr()->logVarSet().back() + 2;
|
||||
ConstraintTrees cts = g->constr()->jointCountNormalize (
|
||||
commCt, exclCt, f.countedLogVar(), X_new1, X_new2);
|
||||
for (unsigned i = 0; i < cts.size(); i++) {
|
||||
Parfactor* newPf = new Parfactor (g, cts[i]);
|
||||
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
|
||||
newPf->expand (f.countedLogVar(), X_new1, X_new2);
|
||||
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
||||
cts[i]->getConditionalCount (X_new1) +
|
||||
cts[i]->getConditionalCount (X_new2));
|
||||
} else {
|
||||
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
||||
cts[i]->getConditionalCount (f.countedLogVar()));
|
||||
}
|
||||
newPf->setNewGroups();
|
||||
result.push_back (newPf);
|
||||
}
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
} else {
|
||||
Parfactor* newPf = new Parfactor (g, commCt);
|
||||
newPf->setNewGroups();
|
||||
newPf->argument (fIdx).setGroup (commGroup);
|
||||
result.push_back (newPf);
|
||||
newPf = new Parfactor (g, exclCt);
|
||||
newPf->setNewGroups();
|
||||
result.push_back (newPf);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::unifyGroups (unsigned group1, unsigned group2)
|
||||
{
|
||||
unsigned newGroup = ProbFormula::getNewGroup();
|
||||
for (ParfactorList::iterator it = pfList_.begin();
|
||||
it != pfList_.end(); it++) {
|
||||
ProbFormulas& formulas = (*it)->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
if (formulas[i].group() == group1 ||
|
||||
formulas[i].group() == group2) {
|
||||
formulas[i].setGroup (newGroup);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::proper (
|
||||
const ProbFormula& f1, ConstraintTree c1,
|
||||
const ProbFormula& f2, ConstraintTree c2) const
|
||||
{
|
||||
return disjoint (f1, c1, f2, c2)
|
||||
|| identical (f1, c1, f2, c2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::identical (
|
||||
const ProbFormula& f1, ConstraintTree c1,
|
||||
const ProbFormula& f2, ConstraintTree c2) const
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return false;
|
||||
}
|
||||
if (f1.isAtom()) {
|
||||
return true;
|
||||
}
|
||||
c1.moveToTop (f1.logVars());
|
||||
c2.moveToTop (f2.logVars());
|
||||
return ConstraintTree::identical (
|
||||
&c1, &c2, f1.logVars().size());
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::disjoint (
|
||||
const ProbFormula& f1, ConstraintTree c1,
|
||||
const ProbFormula& f2, ConstraintTree c2) const
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return true;
|
||||
}
|
||||
if (f1.isAtom()) {
|
||||
return true;
|
||||
}
|
||||
c1.moveToTop (f1.logVars());
|
||||
c2.moveToTop (f2.logVars());
|
||||
return ConstraintTree::overlap (
|
||||
&c1, &c2, f1.arity()) == false;
|
||||
}
|
||||
|
@@ -1,97 +0,0 @@
|
||||
#ifndef HORUS_PARFACTORLIST_H
|
||||
#define HORUS_PARFACTORLIST_H
|
||||
|
||||
#include <list>
|
||||
#include <queue>
|
||||
|
||||
#include "Parfactor.h"
|
||||
#include "ProbFormula.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class ParfactorList
|
||||
{
|
||||
public:
|
||||
ParfactorList (void) { }
|
||||
|
||||
ParfactorList (const ParfactorList&);
|
||||
|
||||
ParfactorList (const Parfactors&);
|
||||
|
||||
~ParfactorList (void);
|
||||
|
||||
const list<Parfactor*>& parfactors (void) const { return pfList_; }
|
||||
|
||||
void clear (void) { pfList_.clear(); }
|
||||
|
||||
unsigned size (void) const { return pfList_.size(); }
|
||||
|
||||
typedef std::list<Parfactor*>::iterator iterator;
|
||||
|
||||
iterator begin (void) { return pfList_.begin(); }
|
||||
|
||||
iterator end (void) { return pfList_.end(); }
|
||||
|
||||
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
||||
|
||||
const_iterator begin (void) const { return pfList_.begin(); }
|
||||
|
||||
const_iterator end (void) const { return pfList_.end(); }
|
||||
|
||||
void add (Parfactor* pf);
|
||||
|
||||
void add (const Parfactors& pfs);
|
||||
|
||||
void addShattered (Parfactor* pf);
|
||||
|
||||
list<Parfactor*>::iterator insertShattered (
|
||||
list<Parfactor*>::iterator, Parfactor*);
|
||||
|
||||
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||
|
||||
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
|
||||
|
||||
bool isAllShattered (void) const;
|
||||
|
||||
void print (void) const;
|
||||
|
||||
private:
|
||||
|
||||
bool isShattered (const Parfactor*, const Parfactor*) const;
|
||||
|
||||
void addToShatteredList (Parfactor*);
|
||||
|
||||
std::pair<Parfactors, Parfactors> shatter (
|
||||
Parfactor*, Parfactor*);
|
||||
|
||||
std::pair<Parfactors, Parfactors> shatter (
|
||||
unsigned, Parfactor*, unsigned, Parfactor*);
|
||||
|
||||
Parfactors shatter (
|
||||
Parfactor*,
|
||||
unsigned,
|
||||
ConstraintTree*,
|
||||
ConstraintTree*,
|
||||
unsigned);
|
||||
|
||||
void unifyGroups (unsigned group1, unsigned group2);
|
||||
|
||||
bool proper (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
bool identical (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
bool disjoint (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
list<Parfactor*> pfList_;
|
||||
};
|
||||
|
||||
#endif // HORUS_PARFACTORLIST_H
|
||||
|
@@ -1,123 +0,0 @@
|
||||
#include "ProbFormula.h"
|
||||
|
||||
|
||||
int ProbFormula::freeGroup_ = 0;
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::sameSkeletonAs (const ProbFormula& f) const
|
||||
{
|
||||
return functor_ == f.functor() && logVars_.size() == f.arity();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::contains (LogVar lv) const
|
||||
{
|
||||
return Util::contains (logVars_, lv);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::contains (LogVarSet s) const
|
||||
{
|
||||
return LogVarSet (logVars_).contains (s);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::isAtom (void) const
|
||||
{
|
||||
return logVars_.size() == 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::isCounting (void) const
|
||||
{
|
||||
return countedLogVar_.valid();
|
||||
}
|
||||
|
||||
|
||||
|
||||
LogVar
|
||||
ProbFormula::countedLogVar (void) const
|
||||
{
|
||||
assert (isCounting());
|
||||
return countedLogVar_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ProbFormula::setCountedLogVar (LogVar lv)
|
||||
{
|
||||
countedLogVar_ = lv;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ProbFormula::rename (LogVar oldName, LogVar newName)
|
||||
{
|
||||
for (unsigned i = 0; i < logVars_.size(); i++) {
|
||||
if (logVars_[i] == oldName) {
|
||||
logVars_[i] = newName;
|
||||
}
|
||||
}
|
||||
if (isCounting() && countedLogVar_ == oldName) {
|
||||
countedLogVar_ = newName;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||
{
|
||||
return f1.group_ == f2.group_;
|
||||
//return functor_ == f.functor_ && logVars_ == f.logVars_ ;
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
{
|
||||
os << f.functor_;
|
||||
if (f.isAtom() == false) {
|
||||
os << "(" ;
|
||||
for (unsigned i = 0; i < f.logVars_.size(); i++) {
|
||||
if (i != 0) os << ",";
|
||||
if (f.isCounting() && f.logVars_[i] == f.countedLogVar_) {
|
||||
os << "#" ;
|
||||
}
|
||||
os << f.logVars_[i];
|
||||
}
|
||||
os << ")" ;
|
||||
}
|
||||
os << "::" << f.range_;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ProbFormula::getNewGroup (void)
|
||||
{
|
||||
freeGroup_ ++;
|
||||
return freeGroup_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||
{
|
||||
os << of.functor_ << "/" << of.arity_;
|
||||
os << "|" << of.constr_.tupleSet();
|
||||
os << " [evidence=" << of.evidence_ << "]";
|
||||
return os;
|
||||
}
|
||||
|
@@ -1,108 +0,0 @@
|
||||
#ifndef HORUS_PROBFORMULA_H
|
||||
#define HORUS_PROBFORMULA_H
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "ConstraintTree.h"
|
||||
#include "LiftedUtils.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
|
||||
class ProbFormula
|
||||
{
|
||||
public:
|
||||
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
||||
: functor_(f), logVars_(lvs), range_(range),
|
||||
countedLogVar_(), group_(Util::maxUnsigned()) { }
|
||||
|
||||
ProbFormula (Symbol f, unsigned r)
|
||||
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
unsigned arity (void) const { return logVars_.size(); }
|
||||
|
||||
unsigned range (void) const { return range_; }
|
||||
|
||||
LogVars& logVars (void) { return logVars_; }
|
||||
|
||||
const LogVars& logVars (void) const { return logVars_; }
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
|
||||
unsigned group (void) const { return group_; }
|
||||
|
||||
void setGroup (unsigned g) { group_ = g; }
|
||||
|
||||
bool sameSkeletonAs (const ProbFormula&) const;
|
||||
|
||||
bool contains (LogVar) const;
|
||||
|
||||
bool contains (LogVarSet) const;
|
||||
|
||||
bool isAtom (void) const;
|
||||
|
||||
bool isCounting (void) const;
|
||||
|
||||
LogVar countedLogVar (void) const;
|
||||
|
||||
void setCountedLogVar (LogVar);
|
||||
|
||||
void rename (LogVar, LogVar);
|
||||
|
||||
static unsigned getNewGroup (void);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
|
||||
|
||||
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
|
||||
|
||||
private:
|
||||
Symbol functor_;
|
||||
LogVars logVars_;
|
||||
unsigned range_;
|
||||
LogVar countedLogVar_;
|
||||
unsigned group_;
|
||||
static int freeGroup_;
|
||||
};
|
||||
|
||||
typedef vector<ProbFormula> ProbFormulas;
|
||||
|
||||
|
||||
class ObservedFormula
|
||||
{
|
||||
public:
|
||||
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
|
||||
|
||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||
{
|
||||
constr_.addTuple (tuple);
|
||||
}
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
unsigned arity (void) const { return arity_; }
|
||||
|
||||
unsigned evidence (void) const { return evidence_; }
|
||||
|
||||
ConstraintTree& constr (void) { return constr_; }
|
||||
|
||||
bool isAtom (void) const { return arity_ == 0; }
|
||||
|
||||
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
|
||||
|
||||
private:
|
||||
Symbol functor_;
|
||||
unsigned arity_;
|
||||
unsigned evidence_;
|
||||
ConstraintTree constr_;
|
||||
};
|
||||
|
||||
typedef vector<ObservedFormula> ObservedFormulas;
|
||||
|
||||
#endif // HORUS_PROBFORMULA_H
|
||||
|
@@ -1,37 +0,0 @@
|
||||
#include "Solver.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
void
|
||||
Solver::printAnswer (const VarIds& vids)
|
||||
{
|
||||
Vars unobservedVars;
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
VarNode* vn = fg.getVarNode (vids[i]);
|
||||
if (vn->hasEvidence() == false) {
|
||||
unobservedVars.push_back (vn);
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
Params res = solveQuery (unobservedVids);
|
||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||
for (unsigned i = 0; i < res.size(); i++) {
|
||||
cout << "P(" << stateLines[i] << ") = " ;
|
||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printAllPosterioris (void)
|
||||
{
|
||||
const VarNodes& vars = fg.varNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printAnswer ({vars[i]->varId()});
|
||||
}
|
||||
}
|
||||
|
@@ -1,30 +0,0 @@
|
||||
#ifndef HORUS_SOLVER_H
|
||||
#define HORUS_SOLVER_H
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "Var.h"
|
||||
#include "FactorGraph.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Solver
|
||||
{
|
||||
public:
|
||||
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||
|
||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||
|
||||
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||
|
||||
void printAnswer (const VarIds& vids);
|
||||
|
||||
void printAllPosterioris (void);
|
||||
|
||||
protected:
|
||||
const FactorGraph& fg;
|
||||
};
|
||||
|
||||
#endif // HORUS_SOLVER_H
|
||||
|
@@ -1,4 +0,0 @@
|
||||
TODO
|
||||
- add a way to calculate combinations and factorials with large numbers
|
||||
- refactor sumOut in parfactor -> is really ugly code
|
||||
- Indexer: start receiving ranges as constant reference
|
@@ -1,200 +0,0 @@
|
||||
#ifndef HORUS_TINYSET_H
|
||||
#define HORUS_TINYSET_H
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
template <typename T>
|
||||
class TinySet
|
||||
{
|
||||
public:
|
||||
TinySet (void) {}
|
||||
|
||||
TinySet (const T& t)
|
||||
{
|
||||
elements_.push_back (t);
|
||||
}
|
||||
|
||||
TinySet (const vector<T>& elements)
|
||||
{
|
||||
elements_.reserve (elements.size());
|
||||
for (unsigned i = 0; i < elements.size(); i++) {
|
||||
insert (elements[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TinySet (const TinySet<T>& s) : elements_(s.elements_) { }
|
||||
|
||||
void insert (const T& t)
|
||||
{
|
||||
typename vector<T>::iterator it =
|
||||
std::lower_bound (elements_.begin(), elements_.end(), t);
|
||||
if (it == elements_.end() || *it != t) {
|
||||
elements_.insert (it, t);
|
||||
}
|
||||
}
|
||||
|
||||
void remove (const T& t)
|
||||
{
|
||||
typename vector<T>::iterator it =
|
||||
std::lower_bound (elements_.begin(), elements_.end(), t);
|
||||
if (it != elements_.end()) {
|
||||
elements_.erase (it);
|
||||
}
|
||||
}
|
||||
|
||||
/* set union */
|
||||
TinySet operator| (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_union (
|
||||
elements_.begin(),
|
||||
elements_.end(),
|
||||
s.elements_.begin(),
|
||||
s.elements_.end(),
|
||||
std::back_inserter (res.elements_));
|
||||
return res;
|
||||
}
|
||||
|
||||
/* set intersection */
|
||||
TinySet operator& (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_intersection (
|
||||
elements_.begin(),
|
||||
elements_.end(),
|
||||
s.elements_.begin(),
|
||||
s.elements_.end(),
|
||||
std::back_inserter (res.elements_));
|
||||
return res;
|
||||
}
|
||||
|
||||
/* set difference */
|
||||
TinySet operator- (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_difference (
|
||||
elements_.begin(),
|
||||
elements_.end(),
|
||||
s.elements_.begin(),
|
||||
s.elements_.end(),
|
||||
std::back_inserter (res.elements_));
|
||||
return res;
|
||||
}
|
||||
|
||||
TinySet& operator|= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this | s);
|
||||
}
|
||||
|
||||
TinySet& operator&= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this & s);
|
||||
}
|
||||
|
||||
TinySet& operator-= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this - s);
|
||||
}
|
||||
|
||||
bool contains (const T& t) const
|
||||
{
|
||||
return std::binary_search (
|
||||
elements_.begin(), elements_.end(), t);
|
||||
}
|
||||
|
||||
bool contains (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
elements_.begin(),
|
||||
elements_.end(),
|
||||
s.elements_.begin(),
|
||||
s.elements_.end());
|
||||
}
|
||||
|
||||
bool in (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
s.elements_.begin(),
|
||||
s.elements_.end(),
|
||||
elements_.begin(),
|
||||
elements_.end());
|
||||
}
|
||||
|
||||
bool intersects (const TinySet& s) const
|
||||
{
|
||||
return (*this & s).size() > 0;
|
||||
}
|
||||
|
||||
T operator[] (unsigned i) const
|
||||
{
|
||||
return elements_[i];
|
||||
}
|
||||
|
||||
const vector<T>& elements (void) const
|
||||
{
|
||||
return elements_;
|
||||
}
|
||||
|
||||
T front (void) const
|
||||
{
|
||||
return elements_.front();
|
||||
}
|
||||
|
||||
T back (void) const
|
||||
{
|
||||
return elements_.back();
|
||||
}
|
||||
|
||||
unsigned size (void) const
|
||||
{
|
||||
return elements_.size();
|
||||
}
|
||||
|
||||
bool empty (void) const
|
||||
{
|
||||
return elements_.size() == 0;
|
||||
}
|
||||
|
||||
typedef typename std::vector<T>::const_iterator const_iterator;
|
||||
|
||||
const_iterator begin (void) const
|
||||
{
|
||||
return elements_.begin();
|
||||
}
|
||||
|
||||
const_iterator end (void) const
|
||||
{
|
||||
return elements_.end();
|
||||
}
|
||||
|
||||
friend bool operator== (const TinySet& s1, const TinySet& s2)
|
||||
{
|
||||
return s1.elements_ == s2.elements_;
|
||||
}
|
||||
|
||||
friend bool operator!= (const TinySet& s1, const TinySet& s2)
|
||||
{
|
||||
return s1.elements_ != s2.elements_;
|
||||
}
|
||||
|
||||
friend std::ostream& operator << (std::ostream& out, const TinySet<T>& s)
|
||||
{
|
||||
out << "{" ;
|
||||
for (unsigned i = 0; i < s.size(); i++) {
|
||||
out << ((i != 0) ? "," : "") << s.elements()[i];
|
||||
}
|
||||
out << "}" ;
|
||||
return out;
|
||||
}
|
||||
|
||||
protected:
|
||||
vector<T> elements_;
|
||||
};
|
||||
|
||||
|
||||
#endif // HORUS_TINYSET_H
|
||||
|
@@ -1,502 +0,0 @@
|
||||
#include <limits>
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
namespace Globals {
|
||||
bool logDomain = false;
|
||||
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
||||
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
namespace BpOptions {
|
||||
Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
||||
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||
double accuracy = 0.0001;
|
||||
unsigned maxIter = 1000;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<NetInfo> Statistics::netInfo_;
|
||||
vector<CompressInfo> Statistics::compressInfo_;
|
||||
unsigned Statistics::primaryNetCount_;
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
void
|
||||
toLog (Params& v)
|
||||
{
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = log (v[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
fromLog (Params& v)
|
||||
{
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = exp (v[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
factorial (double num)
|
||||
{
|
||||
double result = 1.0;
|
||||
for (int i = 1; i <= num; i++) {
|
||||
result *= i;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
nrCombinations (unsigned n, unsigned r)
|
||||
{
|
||||
assert (n >= r);
|
||||
unsigned prod = 1;
|
||||
for (int i = (int)n; i > (int)(n - r); i--) {
|
||||
prod *= i;
|
||||
}
|
||||
return (prod / factorial (r));
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
expectedSize (const Ranges& ranges)
|
||||
{
|
||||
unsigned prod = 1;
|
||||
for (unsigned i = 0; i < ranges.size(); i++) {
|
||||
prod *= ranges[i];
|
||||
}
|
||||
return prod;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
getNumberOfDigits (int number)
|
||||
{
|
||||
unsigned count = 1;
|
||||
while (number >= 10) {
|
||||
number /= 10;
|
||||
count ++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
isInteger (const string& s)
|
||||
{
|
||||
stringstream ss1 (s);
|
||||
stringstream ss2;
|
||||
int integer;
|
||||
ss1 >> integer;
|
||||
ss2 << integer;
|
||||
return (ss1.str() == ss2.str());
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
parametersToString (const Params& v, unsigned precision)
|
||||
{
|
||||
stringstream ss;
|
||||
ss.precision (precision);
|
||||
ss << "[" ;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << v[i];
|
||||
}
|
||||
ss << "]" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<string>
|
||||
getStateLines (const Vars& vars)
|
||||
{
|
||||
StatesIndexer idx (vars);
|
||||
vector<string> jointStrings;
|
||||
while (idx.valid()) {
|
||||
stringstream ss;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
|
||||
}
|
||||
jointStrings.push_back (ss.str());
|
||||
++ idx;
|
||||
}
|
||||
return jointStrings;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printHeader (string header, std::ostream& os)
|
||||
{
|
||||
printAsteriskLine (os);
|
||||
os << header << endl;
|
||||
printAsteriskLine (os);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printSubHeader (string header, std::ostream& os)
|
||||
{
|
||||
printDashedLine (os);
|
||||
os << header << endl;
|
||||
printDashedLine (os);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printAsteriskLine (std::ostream& os)
|
||||
{
|
||||
os << "********************************" ;
|
||||
os << "********************************" ;
|
||||
os << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printDashedLine (std::ostream& os)
|
||||
{
|
||||
os << "--------------------------------" ;
|
||||
os << "--------------------------------" ;
|
||||
os << endl;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace LogAware {
|
||||
|
||||
void
|
||||
normalize (Params& v)
|
||||
{
|
||||
double sum;
|
||||
if (Globals::logDomain) {
|
||||
sum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum = Util::logSum (sum, v[i]);
|
||||
}
|
||||
assert (sum != -numeric_limits<double>::infinity());
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] -= sum;
|
||||
}
|
||||
} else {
|
||||
sum = 0.0;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum += v[i];
|
||||
}
|
||||
assert (sum != 0.0);
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
getL1Distance (const Params& v1, const Params& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double dist = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (exp(v1[i]) - exp(v2[i]));
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (v1[i] - v2[i]);
|
||||
}
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
getMaxNorm (const Params& v1, const Params& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double max = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (exp(v1[i]) - exp(v2[i]));
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (v1[i] - v2[i]);
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
|
||||
double
|
||||
pow (double p, unsigned expoent)
|
||||
{
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
pow (double p, double expoent)
|
||||
{
|
||||
// assumes that `expoent' is never in log domain
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, unsigned expoent)
|
||||
{
|
||||
if (expoent == 1) {
|
||||
return;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, double expoent)
|
||||
{
|
||||
// assumes that `expoent' is never in log domain
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Statistics::getSolvedNetworksCounting (void)
|
||||
{
|
||||
return netInfo_.size();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::incrementPrimaryNetworksCounting (void)
|
||||
{
|
||||
primaryNetCount_ ++;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Statistics::getPrimaryNetworksCounting (void)
|
||||
{
|
||||
return primaryNetCount_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::updateStatistics (
|
||||
unsigned size,
|
||||
bool loopy,
|
||||
unsigned nIters,
|
||||
double time)
|
||||
{
|
||||
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::printStatistics (void)
|
||||
{
|
||||
cout << getStatisticString();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::writeStatistics (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Statistics::writeStats()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << getStatisticString();
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::updateCompressingStatistics (
|
||||
unsigned nrGroundVars,
|
||||
unsigned nrGroundFactors,
|
||||
unsigned nrClusterVars,
|
||||
unsigned nrClusterFactors,
|
||||
unsigned nrNeighborless) {
|
||||
compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
|
||||
nrClusterVars, nrClusterFactors, nrNeighborless));
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Statistics::getStatisticString (void)
|
||||
{
|
||||
stringstream ss2, ss3, ss4, ss1;
|
||||
ss1 << "running mode: " ;
|
||||
switch (Globals::infAlgorithm) {
|
||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
|
||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
||||
}
|
||||
ss1 << "message schedule: " ;
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
ss1 << "sequential fixed" << endl;
|
||||
break;
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
ss1 << "sequential random" << endl;
|
||||
break;
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
ss1 << "parallel" << endl;
|
||||
break;
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
ss1 << "max residual" << endl;
|
||||
break;
|
||||
}
|
||||
ss1 << "max iterations: " << BpOptions::maxIter << endl;
|
||||
ss1 << "accuracy " << BpOptions::accuracy << endl;
|
||||
ss1 << endl << endl;
|
||||
Util::printSubHeader ("Network information", ss2);
|
||||
ss2 << left;
|
||||
ss2 << setw (15) << "Network Size" ;
|
||||
ss2 << setw (9) << "Loopy" ;
|
||||
ss2 << setw (15) << "Iterations" ;
|
||||
ss2 << setw (15) << "Solving Time" ;
|
||||
ss2 << endl;
|
||||
unsigned nLoopyNets = 0;
|
||||
unsigned nUnconvergedRuns = 0;
|
||||
double totalSolvingTime = 0.0;
|
||||
for (unsigned i = 0; i < netInfo_.size(); i++) {
|
||||
ss2 << setw (15) << netInfo_[i].size;
|
||||
if (netInfo_[i].loopy) {
|
||||
ss2 << setw (9) << "yes";
|
||||
nLoopyNets ++;
|
||||
} else {
|
||||
ss2 << setw (9) << "no";
|
||||
}
|
||||
if (netInfo_[i].nIters == 0) {
|
||||
ss2 << setw (15) << "n/a" ;
|
||||
} else {
|
||||
ss2 << setw (15) << netInfo_[i].nIters;
|
||||
if (netInfo_[i].nIters > BpOptions::maxIter) {
|
||||
nUnconvergedRuns ++;
|
||||
}
|
||||
}
|
||||
ss2 << setw (15) << netInfo_[i].time;
|
||||
totalSolvingTime += netInfo_[i].time;
|
||||
ss2 << endl;
|
||||
}
|
||||
ss2 << endl << endl;
|
||||
|
||||
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
|
||||
if (compressInfo_.size() > 0) {
|
||||
Util::printSubHeader ("Compress information", ss3);
|
||||
ss3 << left;
|
||||
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
||||
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
||||
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
||||
ss3 << setw (9) << compressInfo_[i].nrGroundVars;
|
||||
ss3 << setw (10) << compressInfo_[i].nrClusterVars;
|
||||
ss3 << setw (10) << compressInfo_[i].nrGroundFactors;
|
||||
ss3 << setw (10) << compressInfo_[i].nrClusterFactors;
|
||||
ss3 << setw (10) << compressInfo_[i].nrNeighborless;
|
||||
ss3 << endl;
|
||||
c1 += compressInfo_[i].nrGroundVars - compressInfo_[i].nrNeighborless;
|
||||
c2 += compressInfo_[i].nrClusterVars;
|
||||
c3 += compressInfo_[i].nrGroundFactors - compressInfo_[i].nrNeighborless;
|
||||
c4 += compressInfo_[i].nrClusterFactors;
|
||||
if (compressInfo_[i].nrNeighborless != 0) {
|
||||
c2 --;
|
||||
c4 --;
|
||||
}
|
||||
}
|
||||
ss3 << endl << endl;
|
||||
}
|
||||
|
||||
ss4 << "primary networks: " << primaryNetCount_ << endl;
|
||||
ss4 << "solved networks: " << netInfo_.size() << endl;
|
||||
ss4 << "loopy networks: " << nLoopyNets << endl;
|
||||
ss4 << "unconverged runs: " << nUnconvergedRuns << endl;
|
||||
ss4 << "total solving time: " << totalSolvingTime << endl;
|
||||
if (compressInfo_.size() > 0) {
|
||||
double pc1 = (1.0 - (c2 / (double)c1)) * 100.0;
|
||||
double pc2 = (1.0 - (c4 / (double)c3)) * 100.0;
|
||||
ss4 << setprecision (5);
|
||||
ss4 << "variable compression: " << pc1 << "%" << endl;
|
||||
ss4 << "factor compression: " << pc2 << "%" << endl;
|
||||
}
|
||||
ss4 << endl << endl;
|
||||
|
||||
ss1 << ss4.str() << ss2.str() << ss3.str();
|
||||
return ss1.str();
|
||||
}
|
||||
|
@@ -1,376 +0,0 @@
|
||||
#ifndef HORUS_UTIL_H
|
||||
#define HORUS_UTIL_H
|
||||
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> void addToSet (set<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> bool contains (const vector<T>&, const T&);
|
||||
|
||||
template <typename T> bool contains (const set<T>&, const T&);
|
||||
|
||||
template <typename K, typename V> bool contains (
|
||||
const unordered_map<K, V>&, const K&);
|
||||
|
||||
template <typename T> std::string toString (const T&);
|
||||
|
||||
void toLog (Params&);
|
||||
|
||||
void fromLog (Params&);
|
||||
|
||||
double logSum (double, double);
|
||||
|
||||
void multiply (Params&, const Params&);
|
||||
|
||||
void multiply (Params&, const Params&, unsigned);
|
||||
|
||||
void add (Params&, const Params&);
|
||||
|
||||
void add (Params&, const Params&, unsigned);
|
||||
|
||||
double factorial (double);
|
||||
|
||||
unsigned nrCombinations (unsigned, unsigned);
|
||||
|
||||
unsigned expectedSize (const Ranges&);
|
||||
|
||||
unsigned getNumberOfDigits (int);
|
||||
|
||||
bool isInteger (const string&);
|
||||
|
||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||
|
||||
vector<string> getStateLines (const Vars&);
|
||||
|
||||
void printHeader (string, std::ostream& os = std::cout);
|
||||
|
||||
void printSubHeader (string, std::ostream& os = std::cout);
|
||||
|
||||
void printAsteriskLine (std::ostream& os = std::cout);
|
||||
|
||||
void printDashedLine (std::ostream& os = std::cout);
|
||||
|
||||
unsigned maxUnsigned (void);
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||
{
|
||||
v.insert (v.end(), elements.begin(), elements.end());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||
{
|
||||
s.insert (elements.begin(), elements.end());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||
{
|
||||
for (unsigned i = 0; i < elements.size(); i++) {
|
||||
q.push (elements[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const vector<T>& v, const T& e)
|
||||
{
|
||||
return std::find (v.begin(), v.end(), e) != v.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const set<T>& s, const T& e)
|
||||
{
|
||||
return s.find (e) != s.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename K, typename V> bool
|
||||
Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||
{
|
||||
return m.find (k) != m.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> std::string
|
||||
Util::toString (const T& t)
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << t;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator << (std::ostream& os, const vector<T>& v)
|
||||
{
|
||||
os << "[" ;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
os << ((i != 0) ? ", " : "") << v[i];
|
||||
}
|
||||
os << "]" ;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
const double INF = -numeric_limits<double>::infinity();
|
||||
};
|
||||
|
||||
|
||||
inline double
|
||||
Util::logSum (double x, double y)
|
||||
{
|
||||
return log (exp (x) + exp (y));
|
||||
assert (isfinite (x) && isfinite (y));
|
||||
// If one value is much smaller than the other, keep the larger value.
|
||||
if (x < (y - log (1e200))) {
|
||||
return y;
|
||||
}
|
||||
if (y < (x - log (1e200))) {
|
||||
return x;
|
||||
}
|
||||
double diff = x - y;
|
||||
assert (isfinite (diff) && isfinite (x) && isfinite (y));
|
||||
if (!isfinite (exp (diff))) {
|
||||
// difference is too large
|
||||
return x > y ? x : y;
|
||||
}
|
||||
// otherwise return the sum.
|
||||
return y + log (static_cast<double>(1.0) + exp (diff));
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::multiply (Params& v1, const Params& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
v1[i] *= v2[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
|
||||
{
|
||||
for (unsigned count = 0; count < v1.size(); ) {
|
||||
for (unsigned i = 0; i < v2.size(); i++) {
|
||||
for (unsigned r = 0; r < repetitions; r++) {
|
||||
v1[count] *= v2[i];
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::add (Params& v1, const Params& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
v1[i] += v2[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
||||
{
|
||||
for (unsigned count = 0; count < v1.size(); ) {
|
||||
for (unsigned i = 0; i < v2.size(); i++) {
|
||||
for (unsigned r = 0; r < repetitions; r++) {
|
||||
v1[count] += v2[i];
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
Util::maxUnsigned (void)
|
||||
{
|
||||
return numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace LogAware {
|
||||
|
||||
inline double
|
||||
one()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
zero() {
|
||||
return Globals::logDomain ? INF : 0.0 ;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
addIdenty()
|
||||
{
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
multIdenty()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
withEvidence()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
noEvidence() {
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
tl (double v)
|
||||
{
|
||||
return Globals::logDomain ? log (v) : v;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
fl (double v)
|
||||
{
|
||||
return Globals::logDomain ? exp (v) : v;
|
||||
}
|
||||
|
||||
|
||||
void normalize (Params&);
|
||||
|
||||
double getL1Distance (const Params&, const Params&);
|
||||
|
||||
double getMaxNorm (const Params&, const Params&);
|
||||
|
||||
double pow (double, unsigned);
|
||||
|
||||
double pow (double, double);
|
||||
|
||||
void pow (Params&, unsigned);
|
||||
|
||||
void pow (Params&, double);
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct NetInfo
|
||||
{
|
||||
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
|
||||
{
|
||||
this->size = size;
|
||||
this->loopy = loopy;
|
||||
this->nIters = nIters;
|
||||
this->time = time;
|
||||
}
|
||||
unsigned size;
|
||||
bool loopy;
|
||||
unsigned nIters;
|
||||
double time;
|
||||
};
|
||||
|
||||
|
||||
struct CompressInfo
|
||||
{
|
||||
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
|
||||
{
|
||||
nrGroundVars = a;
|
||||
nrGroundFactors = b;
|
||||
nrClusterVars = c;
|
||||
nrClusterFactors = d;
|
||||
nrNeighborless = e;
|
||||
}
|
||||
unsigned nrGroundVars;
|
||||
unsigned nrGroundFactors;
|
||||
unsigned nrClusterVars;
|
||||
unsigned nrClusterFactors;
|
||||
unsigned nrNeighborless;
|
||||
};
|
||||
|
||||
|
||||
class Statistics
|
||||
{
|
||||
public:
|
||||
static unsigned getSolvedNetworksCounting (void);
|
||||
|
||||
static void incrementPrimaryNetworksCounting (void);
|
||||
|
||||
static unsigned getPrimaryNetworksCounting (void);
|
||||
|
||||
static void updateStatistics (unsigned, bool, unsigned, double);
|
||||
|
||||
static void printStatistics (void);
|
||||
|
||||
static void writeStatistics (const char*);
|
||||
|
||||
static void updateCompressingStatistics (
|
||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
||||
|
||||
private:
|
||||
static string getStatisticString (void);
|
||||
|
||||
static vector<NetInfo> netInfo_;
|
||||
static vector<CompressInfo> compressInfo_;
|
||||
static unsigned primaryNetCount_;
|
||||
};
|
||||
|
||||
#endif // HORUS_UTIL_H
|
||||
|
@@ -1,102 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
#include "Var.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
unordered_map<VarId, VarInfo> Var::varsInfo_;
|
||||
|
||||
|
||||
Var::Var (const Var* v)
|
||||
{
|
||||
varId_ = v->varId();
|
||||
range_ = v->range();
|
||||
evidence_ = v->getEvidence();
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Var::Var (VarId varId, unsigned range, int evidence)
|
||||
{
|
||||
assert (range != 0);
|
||||
assert (evidence < (int) range);
|
||||
varId_ = varId;
|
||||
range_ = range;
|
||||
evidence_ = evidence;
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Var::isValidState (int stateIndex)
|
||||
{
|
||||
return stateIndex >= 0 && stateIndex < (int) range_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Var::isValidState (const string& stateName)
|
||||
{
|
||||
States states = Var::getVarInfo (varId_).states;
|
||||
return Util::contains (states, stateName);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::setEvidence (int ev)
|
||||
{
|
||||
assert (ev < (int) range_);
|
||||
evidence_ = ev;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::setEvidence (const string& ev)
|
||||
{
|
||||
States states = Var::getVarInfo (varId_).states;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
if (states[i] == ev) {
|
||||
evidence_ = i;
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Var::label (void) const
|
||||
{
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).label;
|
||||
}
|
||||
stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
States
|
||||
Var::states (void) const
|
||||
{
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).states;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < range_; i++) {
|
||||
stringstream ss;
|
||||
ss << i ;
|
||||
states.push_back (ss.str());
|
||||
}
|
||||
return states;
|
||||
}
|
||||
|
@@ -1,108 +0,0 @@
|
||||
#ifndef HORUS_Var_H
|
||||
#define HORUS_Var_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct VarInfo
|
||||
{
|
||||
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
||||
string label;
|
||||
States states;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class Var
|
||||
{
|
||||
public:
|
||||
Var (const Var*);
|
||||
|
||||
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||
|
||||
virtual ~Var (void) { };
|
||||
|
||||
unsigned varId (void) const { return varId_; }
|
||||
|
||||
unsigned range (void) const { return range_; }
|
||||
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
|
||||
operator unsigned () const { return index_; }
|
||||
|
||||
bool hasEvidence (void) const
|
||||
{
|
||||
return evidence_ != Constants::NO_EVIDENCE;
|
||||
}
|
||||
|
||||
bool operator== (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
|
||||
bool operator!= (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ != var.varId();
|
||||
}
|
||||
|
||||
bool isValidState (int);
|
||||
|
||||
bool isValidState (const string&);
|
||||
|
||||
void setEvidence (int);
|
||||
|
||||
void setEvidence (const string&);
|
||||
|
||||
string label (void) const;
|
||||
|
||||
States states (void) const;
|
||||
|
||||
static void addVarInfo (
|
||||
VarId vid, string label, const States& states)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
|
||||
static VarInfo getVarInfo (VarId vid)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid));
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
|
||||
static bool varsHaveInfo (void)
|
||||
{
|
||||
return varsInfo_.size() != 0;
|
||||
}
|
||||
|
||||
static void clearVarsInfo (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
VarId varId_;
|
||||
unsigned range_;
|
||||
int evidence_;
|
||||
unsigned index_;
|
||||
|
||||
static unordered_map<VarId, VarInfo> varsInfo_;
|
||||
|
||||
};
|
||||
|
||||
#endif // BP_Var_H
|
||||
|
@@ -1,158 +0,0 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "VarElimSolver.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
VarElimSolver::~VarElimSolver (void)
|
||||
{
|
||||
delete factorList_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
VarElimSolver::solveQuery (VarIds queryVids)
|
||||
{
|
||||
factorList_.clear();
|
||||
varFactors_.clear();
|
||||
elimOrder_.clear();
|
||||
createFactorList();
|
||||
absorveEvidence();
|
||||
findEliminationOrder (queryVids);
|
||||
processFactorList (queryVids);
|
||||
Params params = factorList_.back()->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::createFactorList (void)
|
||||
{
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
factorList_.reserve (facNodes.size() * 2);
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||
= varFactors_.find (neighs[j]->varId());
|
||||
if (it == varFactors_.end()) {
|
||||
it = varFactors_.insert (make_pair (
|
||||
neighs[j]->varId(), vector<unsigned>())).first;
|
||||
}
|
||||
it->second.push_back (i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::absorveEvidence (void)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
const vector<unsigned>& idxs =
|
||||
varFactors_.find (varNodes[i]->varId())->second;
|
||||
for (unsigned j = 0; j < idxs.size(); j++) {
|
||||
Factor* factor = factorList_[idxs[j]];
|
||||
if (factor->nrArguments() == 1) {
|
||||
factorList_[idxs[j]] = 0;
|
||||
} else {
|
||||
factorList_[idxs[j]]->absorveEvidence (
|
||||
varNodes[i]->varId(), varNodes[i]->getEvidence());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::findEliminationOrder (const VarIds& vids)
|
||||
{
|
||||
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::processFactorList (const VarIds& vids)
|
||||
{
|
||||
for (unsigned i = 0; i < elimOrder_.size(); i++) {
|
||||
eliminate (elimOrder_[i]);
|
||||
}
|
||||
|
||||
Factor* finalFactor = new Factor();
|
||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i]) {
|
||||
finalFactor->multiply (*factorList_[i]);
|
||||
delete factorList_[i];
|
||||
factorList_[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
VarIds unobservedVids;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
|
||||
unobservedVids.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
finalFactor->reorderArguments (unobservedVids);
|
||||
finalFactor->normalize();
|
||||
factorList_.push_back (finalFactor);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::eliminate (VarId elimVar)
|
||||
{
|
||||
Factor* result = 0;
|
||||
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
|
||||
for (unsigned i = 0; i < idxs.size(); i++) {
|
||||
unsigned idx = idxs[i];
|
||||
if (factorList_[idx]) {
|
||||
if (result == 0) {
|
||||
result = new Factor (*factorList_[idx]);
|
||||
} else {
|
||||
result->multiply (*factorList_[idx]);
|
||||
}
|
||||
delete factorList_[idx];
|
||||
factorList_[idx] = 0;
|
||||
}
|
||||
}
|
||||
if (result != 0 && result->nrArguments() != 1) {
|
||||
result->sumOut (elimVar);
|
||||
factorList_.push_back (result);
|
||||
const VarIds& resultVarIds = result->arguments();
|
||||
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
||||
vector<unsigned>& idxs =
|
||||
varFactors_.find (resultVarIds[i])->second;
|
||||
idxs.push_back (factorList_.size() - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::printActiveFactors (void)
|
||||
{
|
||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i] != 0) {
|
||||
factorList_[i]->print();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,42 +0,0 @@
|
||||
#ifndef HORUS_VARELIMSOLVER_H
|
||||
#define HORUS_VARELIMSOLVER_H
|
||||
|
||||
#include "unordered_map"
|
||||
|
||||
#include "Solver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class VarElimSolver : public Solver
|
||||
{
|
||||
public:
|
||||
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
|
||||
|
||||
~VarElimSolver (void);
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
private:
|
||||
void createFactorList (void);
|
||||
|
||||
void absorveEvidence (void);
|
||||
|
||||
void findEliminationOrder (const VarIds&);
|
||||
|
||||
void processFactorList (const VarIds&);
|
||||
|
||||
void eliminate (VarId);
|
||||
|
||||
void printActiveFactors (void);
|
||||
|
||||
vector<Factor*> factorList_;
|
||||
VarIds elimOrder_;
|
||||
unordered_map<VarId, vector<unsigned>> varFactors_;
|
||||
};
|
||||
|
||||
#endif // HORUS_VARELIMSOLVER_H
|
||||
|
@@ -1,35 +0,0 @@
|
||||
|
||||
if [ $1 ] && [ $1 == "clear" ]; then
|
||||
rm *~
|
||||
rm -f school/*.log school/*~
|
||||
rm -f city/*.log city/*~
|
||||
rm -f workshop_attrs/*.log workshop_attrs/*~
|
||||
fi
|
||||
|
||||
function run_solver
|
||||
{
|
||||
constraint=$1
|
||||
solver_flag=true
|
||||
if [ -n "$2" ]; then
|
||||
if [ $SOLVER = hve ]; then
|
||||
extra_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
|
||||
elif [ $SOLVER = bp ]; then
|
||||
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||
elif [ $SOLVER = cbp ]; then
|
||||
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||
else
|
||||
echo "unknow flag $2"
|
||||
fi
|
||||
fi
|
||||
/usr/bin/time -o $LOG_FILE -a -f "real:%E\tuser:%U\tsys:%S" \
|
||||
$YAP << EOF >> $LOG_FILE 2>> ignore.$LOG_FILE
|
||||
[$NETWORK].
|
||||
[$constraint].
|
||||
clpbn_horus:set_solver($SOLVER).
|
||||
clpbn_horus:set_horus_flag(use_logarithms, true).
|
||||
$solver_flag.
|
||||
$QUERY.
|
||||
open("$LOG_FILE", 'append', S), format(S, '$constraint: ~15+ ', []), close(S).
|
||||
EOF
|
||||
}
|
||||
|
@@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="bp"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
LOG_FILE=$SOLVER.log
|
||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||
|
||||
rm -f $LOG_FILE
|
||||
rm -f ignore.$LOG_FILE
|
||||
|
||||
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||
|
@@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="cbp"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
LOG_FILE=$SOLVER.log
|
||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||
|
||||
rm -f $LOG_FILE
|
||||
rm -f ignore.$LOG_FILE
|
||||
|
||||
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||
|
@@ -1,25 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
NETWORK="'../../examples/city'"
|
||||
SHORTNAME="city"
|
||||
QUERY="is_joe_guilty(X)"
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
cp ~/bin/yap $YAP
|
||||
echo -n "**********************************" >> $LOG_FILE
|
||||
echo "**********************************" >> $LOG_FILE
|
||||
echo "results for solver $1" >> $LOG_FILE
|
||||
echo -n "**********************************" >> $LOG_FILE
|
||||
echo "**********************************" >> $LOG_FILE
|
||||
run_solver city_5 $2
|
||||
#run_solver city_1000 $2
|
||||
#run_solver city_5000 $2
|
||||
#run_solver city_10000 $2
|
||||
#run_solver city_50000 $2
|
||||
#run_solver city_100000 $2
|
||||
#run_solver city_500000 $2
|
||||
#run_solver city_1000000 $2
|
||||
}
|
||||
|
@@ -1,37 +0,0 @@
|
||||
#!/home/tiago/bin/yap -L --
|
||||
|
||||
|
||||
:- initialization(main).
|
||||
|
||||
|
||||
main :-
|
||||
unix(argv([H])),
|
||||
generate_town(H).
|
||||
|
||||
|
||||
generate_town(N) :-
|
||||
atomic_concat(['city_', N, '.yap'], FileName),
|
||||
open(FileName, 'write', S),
|
||||
atom_number(N, N2),
|
||||
generate_people(S, N2, 4),
|
||||
write(S, '\n'),
|
||||
generate_query(S, N2, 4),
|
||||
write(S, '\n'),
|
||||
close(S).
|
||||
|
||||
|
||||
generate_people(S, N, Counting) :-
|
||||
Counting > N, !.
|
||||
generate_people(S, N, Counting) :-
|
||||
format(S, 'people(p~w, nyc).~n', [Counting]),
|
||||
Counting1 is Counting + 1,
|
||||
generate_people(S, N, Counting1).
|
||||
|
||||
|
||||
generate_query(S, N, Counting) :-
|
||||
Counting > N, !.
|
||||
generate_query(S, N, Counting) :- !,
|
||||
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
|
||||
Counting1 is Counting + 1,
|
||||
generate_query(S, N, Counting1).
|
||||
|
@@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="fove"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
LOG_FILE=$SOLVER.log
|
||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||
|
||||
rm -f $LOG_FILE
|
||||
rm -f ignore.$LOG_FILEE
|
||||
|
||||
run_all_graphs "fove "
|
||||
|
@@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="hve"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
LOG_FILE=$SOLVER.log
|
||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||
|
||||
rm -f $LOG_FILE
|
||||
rm -f ignore.$LOG_FILE
|
||||
|
||||
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,95 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
#cp ~/bin/yap ~/bin/school_all
|
||||
#YAP=~/bin/school_all
|
||||
YAP=~/bin/yap
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=results.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg, bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
if [ $4 = ve ]
|
||||
then
|
||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_horus:set_horus_flag\(elim_heuristic,$5\)
|
||||
else
|
||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||
fi
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
fi
|
||||
/usr/bin/time -o "$OUT_FILE_NAME" -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF &>> "ignore.$OUT_FILE_NAME"
|
||||
:- [pos:train].
|
||||
:- ['../../../../examples/School/sch32'].
|
||||
:- use_module(library(clpbn/learning/em)).
|
||||
:- use_module(library(clpbn/bp)).
|
||||
[$1].
|
||||
graph(L),
|
||||
clpbn:set_clpbn_flag(em_solver,$2),
|
||||
$extra_flag1, $extra_flag2,
|
||||
em(L,0.01,10,_,Lik),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~11+ Lik = ~3f, ',[Lik]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> "$OUT_FILE_NAME"
|
||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver missing5 $1 missing5 $3 $4 $5
|
||||
run_solver missing10 $1 missing10 $3 $4 $5
|
||||
#run_solver missing20 $1 missing20 $3 $4 $5
|
||||
#run_solver missing30 $1 missing30 $3 $4 $5
|
||||
#run_solver missing40 $1 missing40 $3 $4 $5
|
||||
#run_solver missing50 $1 missing50 $3 $4 $5
|
||||
}
|
||||
|
||||
|
||||
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||
#run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
|
||||
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||
exit
|
||||
|
||||
|
||||
run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||
run_all_graphs bp "hve(min_weight) " ve min_weight
|
||||
run_all_graphs bp "hve(min_fill) " ve min_fill
|
||||
run_all_graphs bp "hve(w_min_fill) " ve weighted_min_fill
|
||||
run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
|
||||
run_all_graphs bp "bp(max_residual) " bp max_residual
|
||||
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||
run_all_graphs bp "cbp(max_residual) " cbp max_residual
|
||||
run_all_graphs gibbs "gibbs "
|
||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver ve" >> "$OUT_FILE_NAME"
|
||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver missing5 ve missing5 $3 $4 $5
|
||||
run_solver missing10 ve missing10 $3 $4 $5
|
||||
run_solver missing20 ve missing20 $3 $4 $5
|
||||
run_solver missing30 ve missing30 $3 $4 $5
|
||||
run_solver missing40 ve missing40 $3 $4 $5
|
||||
#run_solver missing50 ve missing50 $3 $4 $5 #+24h!
|
||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver jt" >> "$OUT_FILE_NAME"
|
||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver missing5 jt missing5 $3 $4 $5
|
||||
run_solver missing10 jt missing10 $3 $4 $5
|
||||
run_solver missing20 jt missing20 $3 $4 $5
|
||||
#run_solver missing30 jt missing30 $3 $4 $5 #+24h!
|
||||
#run_solver missing40 jt missing40 $3 $4 $5 #+24h!
|
||||
#run_solver missing50 jt missing50 $3 $4 $5 #+24h!
|
||||
exit
|
||||
|
@@ -1,11 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
source wa.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="bp"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||
|
@@ -1,11 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
source wa.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="cbp"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||
|
@@ -1,12 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
source wa.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="fove"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
run_all_graphs "fove "
|
||||
|
||||
|
@@ -1,11 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
source wa.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="hve"
|
||||
|
||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||
|
||||
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||
|
@@ -1,27 +0,0 @@
|
||||
#!/home/tiago/bin/yap -L --
|
||||
|
||||
|
||||
:- initialization(main).
|
||||
|
||||
|
||||
main :-
|
||||
unix(argv([H])),
|
||||
generate_town(H).
|
||||
|
||||
|
||||
generate_town(N) :-
|
||||
atomic_concat(['pop_', N, '.yap'], FileName),
|
||||
open(FileName, 'write', S),
|
||||
atom_number(N, N2),
|
||||
generate_people(S, N2, 4),
|
||||
write(S, '\n'),
|
||||
close(S).
|
||||
|
||||
|
||||
generate_people(S, N, Counting) :-
|
||||
Counting > N, !.
|
||||
generate_people(S, N, Counting) :-
|
||||
format(S, 'people(p~w).~n', [Counting]),
|
||||
Counting1 is Counting + 1,
|
||||
generate_people(S, N, Counting1).
|
||||
|
@@ -1,33 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
NETWORK="'../../examples/workshop_attrs'"
|
||||
SHORTNAME="wa"
|
||||
QUERY="series(X)"
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
LOG_FILE=$SOLVER.log
|
||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||
|
||||
rm -f $LOG_FILE
|
||||
rm -f ignore.$LOG_FILE
|
||||
|
||||
cp ~/bin/yap $YAP
|
||||
|
||||
echo -n "**********************************" >> $LOG_FILE
|
||||
echo "**********************************" >> $LOG_FILE
|
||||
echo "results for solver $1" >> $LOG_FILE
|
||||
echo -n "**********************************" >> $LOG_FILE
|
||||
echo "**********************************" >> $LOG_FILE
|
||||
run_solver pop_10 $2
|
||||
#run_solver pop_1000 $2
|
||||
#run_solver pop_5000 $2
|
||||
#run_solver pop_10000 $2
|
||||
#run_solver pop_50000 $2
|
||||
#run_solver pop_100000 $2
|
||||
#run_solver pop_500000 $2
|
||||
#run_solver pop_1000000 $2
|
||||
}
|
||||
|
||||
|
@@ -1,47 +0,0 @@
|
||||
5
|
||||
|
||||
1
|
||||
0
|
||||
2
|
||||
2
|
||||
0 0.001
|
||||
1 0.999
|
||||
|
||||
1
|
||||
1
|
||||
2
|
||||
2
|
||||
0 0.002
|
||||
1 0.998
|
||||
|
||||
3
|
||||
1 0 2
|
||||
2 2 2
|
||||
8
|
||||
0 0.95
|
||||
1 0.94
|
||||
2 0.29
|
||||
3 0.001
|
||||
4 0.05
|
||||
5 0.06
|
||||
6 0.71
|
||||
7 0.999
|
||||
|
||||
2
|
||||
2 3
|
||||
2 2
|
||||
4
|
||||
0 0.9
|
||||
1 0.05
|
||||
2 0.1
|
||||
3 0.95
|
||||
|
||||
2
|
||||
2 4
|
||||
2 2
|
||||
4
|
||||
0 0.7
|
||||
1 0.01
|
||||
2 0.3
|
||||
3 0.99
|
||||
|
@@ -1,28 +0,0 @@
|
||||
MARKOV
|
||||
5
|
||||
2 2 2 2 2
|
||||
5
|
||||
1 0
|
||||
1 1
|
||||
3 2 0 1
|
||||
2 3 2
|
||||
2 4 2
|
||||
|
||||
2
|
||||
.001 .999
|
||||
|
||||
2
|
||||
.002 .998
|
||||
|
||||
8
|
||||
.95 .94 .29 .001
|
||||
.05 .06 .71 .999
|
||||
|
||||
4
|
||||
.9 .05
|
||||
.1 .95
|
||||
|
||||
4
|
||||
.7 .01
|
||||
.3 .99
|
||||
|
@@ -1,41 +0,0 @@
|
||||
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_pfl_flag(solver,ve).
|
||||
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
||||
%:- set_pfl_flag(solver,fove).
|
||||
|
||||
% :- yap_flag(write_strings, off).
|
||||
|
||||
|
||||
bayes burglary::[b1,b3] ; [0.001, 0.999] ; [].
|
||||
|
||||
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
|
||||
|
||||
bayes alarm::[a1,a2] , burglary, earthquake ; [0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ; [].
|
||||
|
||||
bayes john_calls::[j1,j2] , alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
|
||||
|
||||
bayes mary_calls::[m1,m2] , alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
|
||||
|
||||
|
||||
b_cpt([0.001, 0.999]).
|
||||
|
||||
e_cpt([0.002, 0.998]).
|
||||
|
||||
a_cpt([0.95, 0.94, 0.29, 0.001,
|
||||
0.05, 0.06, 0.71, 0.999]).
|
||||
|
||||
jc_cpt([0.9, 0.05,
|
||||
0.1, 0.95]).
|
||||
|
||||
mc_cpt([0.7, 0.01,
|
||||
0.3, 0.99]).
|
||||
|
||||
% ?- alarm(A).
|
||||
?- john_calls(J), mary_calls(m1).
|
||||
%?- john_calls(J), mary_calls(m1), alarm(a1).
|
||||
%?- john_calls(J), alarm(a1).
|
||||
|
||||
|
@@ -1,15 +0,0 @@
|
||||
# example in counting belief propagation paper
|
||||
|
||||
MARKOV
|
||||
3
|
||||
2 2 2
|
||||
2
|
||||
2 0 1
|
||||
2 2 1
|
||||
|
||||
4
|
||||
1.2 1.4 2.0 0.4
|
||||
|
||||
4
|
||||
1.2 1.4 2.0 0.4
|
||||
|
@@ -1,102 +0,0 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- clpbn_horus:set_solver(fove).
|
||||
%:- clpbn_horus:set_solver(hve).
|
||||
:- clpbn_horus:set_solver(bp).
|
||||
%:- clpbn_horus:set_solver(cbp).
|
||||
|
||||
|
||||
people(joe,nyc).
|
||||
people(p2, nyc).
|
||||
people(p3, nyc).
|
||||
|
||||
|
||||
ev(descn(p2, t)).
|
||||
ev(descn(p3, t)).
|
||||
|
||||
% :- [city_7].
|
||||
|
||||
bayes city_conservativeness(C)::[y,n] ; cons_table(C) ; [people(_,C)].
|
||||
|
||||
bayes gender(P)::[m,f] ; gender_table(P) ; [people(P,_)].
|
||||
|
||||
bayes hair_color(P)::[t,f], city_conservativeness(C) ; hair_color_table(P) ; [people(P,C)].
|
||||
|
||||
bayes car_color(P)::[t,f], hair_color(P) ; car_color_table(P); [people(P,_)].
|
||||
|
||||
bayes height(P)::[t,f], gender(P) ; height_table(P) ; [people(P,_)].
|
||||
|
||||
bayes shoe_size(P):[t,f], height(P) ; shoe_size_table(P); [people(P,_)].
|
||||
|
||||
bayes guilty(P)::[y,n] ; guilty_table(P) ; [people(P,_)].
|
||||
|
||||
bayes descn(P)::[t,f], car_color(P), hair_color(P), height(P), guilty(P) ; descn_table(P) ; [people(P,_)].
|
||||
|
||||
bayes witness(C)::[t,f], descn(Joe), descn(P2) ; wit_table ; [people(_,C), Joe=joe, P2=p2].
|
||||
|
||||
|
||||
cons_table(amsterdam, [0.2, 0.8]) :- !.
|
||||
cons_table(_, [0.8, 0.2]).
|
||||
|
||||
|
||||
gender_table(_, [0.55, 0.44]).
|
||||
|
||||
|
||||
hair_color_table(_,
|
||||
/* conservative_city */
|
||||
/* y n */
|
||||
[ 0.05, 0.1,
|
||||
0.95, 0.9 ]).
|
||||
|
||||
|
||||
car_color_table(_,
|
||||
/* t f */
|
||||
[ 0.9, 0.2,
|
||||
0.1, 0.8 ]).
|
||||
|
||||
|
||||
height_table(_,
|
||||
/* m f */
|
||||
[ 0.6, 0.4,
|
||||
0.4, 0.6 ]).
|
||||
|
||||
|
||||
shoe_size_table(_,
|
||||
/* t f */
|
||||
[ 0.9, 0.1,
|
||||
0.1, 0.9 ]).
|
||||
|
||||
|
||||
guilty_table(_, [0.23, 0.77]).
|
||||
|
||||
|
||||
descn_table(_,
|
||||
/* color, hair, height, guilt */
|
||||
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
|
||||
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
|
||||
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
|
||||
|
||||
|
||||
wit_table([0.2, 0.45, 0.24, 0.34,
|
||||
0.8, 0.55, 0.76, 0.66]).
|
||||
|
||||
|
||||
runall(G, Wrapper) :-
|
||||
findall(G, Wrapper, L),
|
||||
execute_all(L).
|
||||
|
||||
|
||||
execute_all([]).
|
||||
execute_all(G.L) :-
|
||||
call(G),
|
||||
execute_all(L).
|
||||
|
||||
|
||||
is_joe_guilty(Guilty) :-
|
||||
witness(nyc, t),
|
||||
runall(X, ev(X)),
|
||||
guilty(joe, Guilty).
|
||||
|
||||
|
||||
% ?- is_joe_guilty(Guilty)
|
||||
|
@@ -1,31 +0,0 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- clpbn_horus:set_solver(fove).
|
||||
%:- clpbn_horus:set_solver(hve).
|
||||
%:- clpbn_horus:set_solver(bp).
|
||||
%:- clpbn_horus:set_solver(cbp).
|
||||
|
||||
:- yap_flag(write_strings, off).
|
||||
|
||||
c(p1,w1).
|
||||
c(p1,w2).
|
||||
c(p1,w3).
|
||||
c(p2,w1).
|
||||
c(p2,w2).
|
||||
c(p2,w3).
|
||||
c(p3,w1).
|
||||
c(p3,w2).
|
||||
c(p3,w3).
|
||||
c(p4,w1).
|
||||
c(p4,w2).
|
||||
c(p4,w3).
|
||||
c(p5,w1).
|
||||
c(p5,w2).
|
||||
c(p5,w3).
|
||||
|
||||
markov attends(P)::[t,f] , hot(W)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P,W)].
|
||||
|
||||
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
|
||||
|
||||
% ?- series(X).
|
||||
|
@@ -1,21 +0,0 @@
|
||||
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- set_pfl_flag(solver,fove).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp).
|
||||
|
||||
|
||||
t(ann).
|
||||
t(dave).
|
||||
|
||||
% p(ann,t).
|
||||
|
||||
markov p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
|
||||
|
||||
% use standard Prolog queries: provide evidence first.
|
||||
|
||||
?- p(ann,t), p(ann,X).
|
||||
% ?- p(ann,X).
|
||||
|
@@ -1,24 +0,0 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- clpbn_horus:set_solver(fove).
|
||||
%:- clpbn_horus:set_solver(hve).
|
||||
%:- clpbn_horus:set_solver(bp).
|
||||
%:- clpbn_horus:set_solver(cbp).
|
||||
|
||||
:- yap_flag(write_strings, off).
|
||||
|
||||
|
||||
friends(P1, P2) :-
|
||||
people(P1),
|
||||
people(P2),
|
||||
P1 \= P2.
|
||||
|
||||
people @ 3.
|
||||
|
||||
markov smokes(P)::[t,f], cancer(P)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [people(P)].
|
||||
|
||||
markov friend(P1,P2)::[t,f], smokes(P1)::[t,f], smokes(P2)::[t,f] ;
|
||||
[0.5, 0.6, 0.7, 0.8, 0.5, 0.6, 0.7, 0.8] ; [friends(P1, P2)].
|
||||
|
||||
% ?- smokes(p1, t), smokes(p2, f), friend(p1, p2, X).
|
||||
|
@@ -1,27 +0,0 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- clpbn_horus:set_solver(fove).
|
||||
%:- clpbn_horus:set_solver(hve).
|
||||
:- clpbn_horus:set_solver(bp).
|
||||
%:- clpbn_horus:set_solver(cbp).
|
||||
|
||||
:- yap_flag(write_strings, off).
|
||||
|
||||
people @ 3.
|
||||
|
||||
markov attends(P)::[t,f], attr1::[t,f] ; [0.11, 0.2, 0.3, 0.4] ; [people(P)].
|
||||
|
||||
markov attends(P)::[t,f], attr2::[t,f] ; [0.1, 0.22, 0.3, 0.4] ; [people(P)].
|
||||
|
||||
markov attends(P)::[t,f], attr3::[t,f] ; [0.1, 0.2, 0.33, 0.4] ; [people(P)].
|
||||
|
||||
markov attends(P)::[t,f], attr4::[t,f] ; [0.1, 0.2, 0.3, 0.44] ; [people(P)].
|
||||
|
||||
markov attends(P)::[t,f], attr5::[t,f] ; [0.1, 0.2, 0.3, 0.45] ; [people(P)].
|
||||
|
||||
markov attends(P)::[t,f], attr6::[t,f] ; [0.1, 0.2, 0.3, 0.46] ; [people(P)].
|
||||
|
||||
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.87] ; [people(P)].
|
||||
|
||||
% ?- series(X).
|
||||
|
@@ -41,20 +41,30 @@ do_network([], _, _, _) :- !.
|
||||
do_network(QueryVars, EVars, Keys, Factors) :-
|
||||
retractall(currently_defined(_)),
|
||||
retractall(f(_,_,_,_)),
|
||||
writeln(keys:Keys),
|
||||
run_through_factors(QueryVars),
|
||||
run_through_factors(EVars),
|
||||
findall(K, currently_defined(K), Keys),
|
||||
writeln(keys2:Keys),
|
||||
ground_all_keys(QueryVars, Keys),
|
||||
ground_all_keys(EVars, Keys),
|
||||
findall(f(FType,FId,FKeys,FCPT), f(FType,FId,FKeys,FCPT), Factors).
|
||||
|
||||
match([], _Keys).
|
||||
match([V|GVars], Keys) :-
|
||||
clpbn:get_atts(V,[key(GKey)]), !,
|
||||
member(GKey, Keys), ground(GKey),
|
||||
match(GVars, Keys).
|
||||
match([_V|GVars], Keys) :-
|
||||
match(GVars, Keys).
|
||||
run_through_factors([]).
|
||||
run_through_factors([Var|_QueryVars]) :-
|
||||
clpbn:get_atts(Var,[key(K)]),
|
||||
find_factors(K),
|
||||
fail.
|
||||
run_through_factors([_|QueryVars]) :-
|
||||
run_through_factors(QueryVars).
|
||||
|
||||
|
||||
ground_all_keys([], _).
|
||||
ground_all_keys([V|GVars], AllKeys) :-
|
||||
clpbn:get_atts(V,[key(Key)]),
|
||||
\+ ground(Key), !,
|
||||
member(Key, AllKeys),
|
||||
ground_all_keys(GVars, AllKeys).
|
||||
ground_all_keys([_V|GVars], AllKeys) :-
|
||||
ground_all_keys(GVars, AllKeys).
|
||||
|
||||
|
||||
%
|
||||
@@ -99,6 +109,7 @@ keys([Var|QueryVars], [Key|QueryKeys]) :-
|
||||
initialize_evidence([]).
|
||||
initialize_evidence([V|EVars]) :-
|
||||
clpbn:get_atts(V, [key(K)]),
|
||||
ground(K),
|
||||
assert(currently_defined(K)),
|
||||
initialize_evidence(EVars).
|
||||
|
||||
@@ -106,7 +117,7 @@ initialize_evidence([V|EVars]) :-
|
||||
% gets key K, and collects factors that define it
|
||||
find_factors(K) :-
|
||||
\+ currently_defined(K),
|
||||
assert(currently_defined(K)),
|
||||
( ground(K) -> assert(currently_defined(K)) ; true),
|
||||
defined_in_factor(K, ParFactor),
|
||||
add_factor(ParFactor, Ks),
|
||||
member(K1, Ks),
|
||||
|
@@ -1,61 +1,63 @@
|
||||
|
||||
/*******************************************************
|
||||
|
||||
Interface with C++
|
||||
Horus Interface
|
||||
|
||||
********************************************************/
|
||||
|
||||
:- module(clpbn_horus,
|
||||
[set_solver/1,
|
||||
create_lifted_network/3,
|
||||
create_ground_network/4,
|
||||
set_parfactors_params/2,
|
||||
set_factors_params/2,
|
||||
run_lifted_solver/3,
|
||||
run_ground_solver/3,
|
||||
set_vars_information/2,
|
||||
set_horus_flag/2,
|
||||
free_parfactors/1,
|
||||
free_ground_network/1
|
||||
set_horus_flag/1,
|
||||
cpp_create_lifted_network/3,
|
||||
cpp_create_ground_network/4,
|
||||
cpp_set_parfactors_params/2,
|
||||
cpp_set_factors_params/2,
|
||||
cpp_run_lifted_solver/3,
|
||||
cpp_run_ground_solver/3,
|
||||
cpp_set_vars_information/2,
|
||||
cpp_set_horus_flag/2,
|
||||
cpp_free_parfactors/1,
|
||||
cpp_free_ground_network/1
|
||||
]).
|
||||
|
||||
:- use_module(library(clpbn),
|
||||
[set_clpbn_flag/2]).
|
||||
|
||||
|
||||
patch_things_up :-
|
||||
assert_static(clpbn_horus:set_horus_flag(_,_)).
|
||||
assert_static(clpbn_horus:cpp_set_horus_flag(_,_)).
|
||||
|
||||
|
||||
warning :-
|
||||
format(user_error,"Horus library not installed: cannot use bp, fove~n.",[]).
|
||||
|
||||
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up) -> true ; warning.
|
||||
|
||||
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up)
|
||||
-> true ; warning.
|
||||
|
||||
|
||||
set_solver(ve) :- pfl:set_pfl_flag(solver,ve).
|
||||
set_solver(jt) :- pfl:set_pfl_flag(solver,jt).
|
||||
set_solver(gibbs) :- pfl:set_pfl_flag(solver,gibbs).
|
||||
set_solver(fove) :- pfl:set_pfl_flag(solver,fove).
|
||||
set_solver(hve) :- pfl:set_pfl_flag(solver,bp), set_horus_flag(inf_alg, ve).
|
||||
set_solver(bp) :- pfl:set_pfl_flag(solver,bp), set_horus_flag(inf_alg, bp).
|
||||
set_solver(cbp) :- pfl:set_pfl_flag(solver,bp), set_horus_flag(inf_alg, cbp).
|
||||
set_solver(ve) :- set_clpbn_flag(solver,ve).
|
||||
set_solver(jt) :- set_clpbn_flag(solver,jt).
|
||||
set_solver(gibbs) :- set_clpbn_flag(solver,gibbs).
|
||||
set_solver(fove) :- set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, fove).
|
||||
set_solver(lbp) :- set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lbp).
|
||||
set_solver(hve) :- set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, ve).
|
||||
set_solver(bp) :- set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, bp).
|
||||
set_solver(cbp) :- set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, cbp).
|
||||
set_solver(S) :- throw(error('unknow solver ', S)).
|
||||
|
||||
|
||||
%:- set_horus_flag(inf_alg, ve).
|
||||
%:- set_horus_flag(inf_alg, bp).
|
||||
%: -set_horus_flag(inf_alg, cbp).
|
||||
set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).
|
||||
|
||||
:- set_horus_flag(schedule, seq_fixed).
|
||||
%:- set_horus_flag(schedule, seq_random).
|
||||
%:- set_horus_flag(schedule, parallel).
|
||||
%:- set_horus_flag(schedule, max_residual).
|
||||
|
||||
:- set_horus_flag(accuracy, 0.0001).
|
||||
:- cpp_set_horus_flag(schedule, seq_fixed).
|
||||
%:- cpp_set_horus_flag(schedule, seq_random).
|
||||
%:- cpp_set_horus_flag(schedule, parallel).
|
||||
%:- cpp_set_horus_flag(schedule, max_residual).
|
||||
|
||||
:- set_horus_flag(max_iter, 1000).
|
||||
:- cpp_set_horus_flag(accuracy, 0.0001).
|
||||
|
||||
:- set_horus_flag(order_factor_variables, false).
|
||||
%:- set_horus_flag(order_factor_variables, true).
|
||||
:- cpp_set_horus_flag(max_iter, 1000).
|
||||
|
||||
:- set_horus_flag(use_logarithms, false).
|
||||
% :- set_horus_flag(use_logarithms, true).
|
||||
:- cpp_set_horus_flag(use_logarithms, false).
|
||||
% :- cpp_set_horus_flag(use_logarithms, true).
|
||||
|
||||
|
@@ -1,19 +1,27 @@
|
||||
|
||||
/*******************************************************
|
||||
|
||||
Belief Propagation and Variable Elimination Interface
|
||||
Interface to Horus Ground Solvers. Used by:
|
||||
- Variable Elimination
|
||||
- Belief Propagation
|
||||
- Counting Belief Propagation
|
||||
|
||||
********************************************************/
|
||||
|
||||
:- module(clpbn_bp,
|
||||
[bp/3,
|
||||
check_if_bp_done/1,
|
||||
init_bp_solver/4,
|
||||
run_bp_solver/3,
|
||||
call_bp_ground/6,
|
||||
finalize_bp_solver/1
|
||||
:- module(clpbn_horus_ground,
|
||||
[call_horus_ground_solver/6,
|
||||
check_if_horus_ground_solver_done/1,
|
||||
init_horus_ground_solver/4,
|
||||
run_horus_ground_solver/3,
|
||||
finalize_horus_ground_solver/1
|
||||
]).
|
||||
|
||||
:- use_module(horus,
|
||||
[cpp_create_ground_network/4,
|
||||
cpp_set_factors_params/2,
|
||||
cpp_run_ground_solver/3,
|
||||
cpp_set_vars_information/2,
|
||||
cpp_free_ground_network/1
|
||||
]).
|
||||
|
||||
:- use_module(library('clpbn/dists'),
|
||||
[dist/4,
|
||||
@@ -22,25 +30,20 @@
|
||||
get_dist_params/2
|
||||
]).
|
||||
|
||||
|
||||
:- use_module(library('clpbn/display'),
|
||||
[clpbn_bind_vals/3]).
|
||||
|
||||
|
||||
:- use_module(library('clpbn/aggregates'),
|
||||
[check_for_agg_vars/2]).
|
||||
|
||||
|
||||
:- use_module(library(charsio),
|
||||
[term_to_atom/2]).
|
||||
|
||||
|
||||
:- use_module(library(pfl),
|
||||
[skolem/2,
|
||||
get_pfl_parameters/2
|
||||
]).
|
||||
|
||||
|
||||
:- use_module(library(lists)).
|
||||
|
||||
:- use_module(library(atts)).
|
||||
@@ -48,45 +51,36 @@
|
||||
:- use_module(library(bhash)).
|
||||
|
||||
|
||||
:- use_module(horus,
|
||||
[create_ground_network/4,
|
||||
set_factors_params/2,
|
||||
run_ground_solver/3,
|
||||
set_vars_information/2,
|
||||
free_ground_network/1
|
||||
]).
|
||||
|
||||
|
||||
call_bp_ground(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
||||
writeln(here:Factors),
|
||||
call_horus_ground_solver(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
||||
b_hash_new(Hash0),
|
||||
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
||||
get_factors_type(Factors, Type),
|
||||
evidence_to_ids(Evidence, Hash, EvidenceIds),
|
||||
factors_to_ids(Factors, Hash, FactorIds),
|
||||
writeln(type:Type), writeln(''),
|
||||
writeln(allKeys:AllKeys), writeln(''),
|
||||
writeln(factors:Factors), writeln(''),
|
||||
writeln(factorIds:FactorIds), writeln(''),
|
||||
writeln(evidence:Evidence), writeln(''),
|
||||
writeln(evidenceIds:EvidenceIds), writeln(''),
|
||||
create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
||||
%writeln(type:Type), writeln(''),
|
||||
%writeln(allKeys:AllKeys), writeln(''),
|
||||
%sort(AllKeys,SKeys),writeln(allKeys:SKeys), writeln(''),
|
||||
%writeln(factors:Factors), writeln(''),
|
||||
%writeln(factorIds:FactorIds), writeln(''),
|
||||
%writeln(evidence:Evidence), writeln(''),
|
||||
%writeln(evidenceIds:EvidenceIds), writeln(''),
|
||||
cpp_create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
||||
%get_vars_information(AllKeys, StatesNames),
|
||||
%set_vars_information(AllKeys, StatesNames),
|
||||
%terms_to_atoms(AllKeys, KeysAtoms),
|
||||
%cpp_set_vars_information(KeysAtoms, StatesNames),
|
||||
run_solver(ground(Network,Hash), QueryKeys, Solutions),
|
||||
writeln(answer:Solutions),
|
||||
clpbn_bind_vals([QueryVars], Solutions, Output),
|
||||
free_ground_network(Network).
|
||||
cpp_free_ground_network(Network).
|
||||
|
||||
|
||||
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||
%get_dists_parameters(DistIds, DistsParams),
|
||||
%set_factors_params(Network, DistsParams),
|
||||
%cpp_set_factors_params(Network, DistsParams),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
writeln(queryKeys:QueryKeys), writeln(''),
|
||||
writeln(queryIds:QueryIds), writeln(''),
|
||||
%writeln(queryKeys:QueryKeys), writeln(''),
|
||||
%writeln(queryIds:QueryIds), writeln(''),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
run_ground_solver(Network, [QueryIds], Solutions).
|
||||
cpp_run_ground_solver(Network, [QueryIds], Solutions).
|
||||
|
||||
|
||||
keys_to_ids([], _, Hash, Hash).
|
||||
@@ -114,7 +108,7 @@ factors_to_ids([f(_, DistId, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|
|
||||
|
||||
|
||||
get_ranges([],[]).
|
||||
get_ranges(K.Ks, Range.Rs) :- !,
|
||||
get_ranges(K.Ks, Range.Rs) :- !,
|
||||
skolem(K,Domain),
|
||||
length(Domain,Range),
|
||||
get_ranges(Ks, Rs).
|
||||
@@ -132,31 +126,40 @@ get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
|
||||
get_vars_information(QueryKeys, StatesNames).
|
||||
|
||||
|
||||
finalize_bp_solver(bp(Network, _)) :-
|
||||
free_ground_network(Network).
|
||||
terms_to_atoms([], []).
|
||||
terms_to_atoms(K.Ks, Atom.As) :-
|
||||
term_to_atom(K,Atom),
|
||||
terms_to_atoms(Ks,As).
|
||||
|
||||
|
||||
bp([[]],_,_) :- !.
|
||||
bp([QueryVars], AllVars, Output) :-
|
||||
init_bp_solver(_, AllVars, _, Network),
|
||||
run_bp_solver([QueryVars], LPs, Network),
|
||||
finalize_bp_solver(Network),
|
||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||
finalize_horus_ground_solver(bp(Network, _)) :-
|
||||
cpp_free_ground_network(Network).
|
||||
|
||||
|
||||
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||
%check_for_agg_vars(AllVars0, AllVars),
|
||||
get_vars_info(AllVars0, VarsInfo, DistIds0),
|
||||
sort(DistIds0, DistIds),
|
||||
create_ground_network(VarsInfo, BayesNet),
|
||||
true.
|
||||
init_horus_ground_solver(_, _AllVars0, _, bp(_BayesNet, _DistIds)) :- !.
|
||||
|
||||
run_horus_ground_solver(_QueryVars, _Solutions, bp(_Network, _DistIds)) :- !.
|
||||
|
||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
set_factors_params(Network, DistsParams),
|
||||
vars_to_ids(QueryVars, QueryVarsIds),
|
||||
run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||
%bp([[]],_,_) :- !.
|
||||
%bp([QueryVars], AllVars, Output) :-
|
||||
% init_horus_ground_solver(_, AllVars, _, Network),
|
||||
% run_horus_ground_solver([QueryVars], LPs, Network),
|
||||
% finalize_horus_ground_solver(Network),
|
||||
% clpbn_bind_vals([QueryVars], LPs, Output).
|
||||
%
|
||||
%init_horus_ground_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||
% %check_for_agg_vars(AllVars0, AllVars),
|
||||
% get_vars_info(AllVars0, VarsInfo, DistIds0),
|
||||
% sort(DistIds0, DistIds),
|
||||
% cpp_create_ground_network(VarsInfo, BayesNet),
|
||||
% true.
|
||||
%
|
||||
%
|
||||
%run_horus_ground_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||
% get_dists_parameters(DistIds, DistsParams),
|
||||
% cpp_set_factors_params(Network, DistsParams),
|
||||
% vars_to_ids(QueryVars, QueryVarsIds),
|
||||
% cpp_run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||
|
||||
|
||||
get_dists_parameters([],[]).
|
@@ -1,27 +1,31 @@
|
||||
|
||||
/*******************************************************
|
||||
|
||||
First Order Variable Elimination Interface
|
||||
Interface to Horus Lifted Solvers. Used by:
|
||||
- Lifted Variable Elimination
|
||||
|
||||
********************************************************/
|
||||
|
||||
:- module(clpbn_fove,
|
||||
[fove/3,
|
||||
check_if_fove_done/1,
|
||||
init_fove_solver/4,
|
||||
run_fove_solver/3,
|
||||
finalize_fove_solver/1
|
||||
:- module(clpbn_horus_lifted,
|
||||
[call_horus_lifted_solver/3,
|
||||
check_if_horus_lifted_solver_done/1,
|
||||
init_horus_lifted_solver/4,
|
||||
run_horus_lifted_solver/3,
|
||||
finalize_horus_lifted_solver/1
|
||||
]).
|
||||
|
||||
:- use_module(horus,
|
||||
[cpp_create_lifted_network/3,
|
||||
cpp_set_parfactors_params/2,
|
||||
cpp_run_lifted_solver/3,
|
||||
cpp_free_parfactors/1
|
||||
]).
|
||||
|
||||
:- use_module(library('clpbn/display'),
|
||||
[clpbn_bind_vals/3]).
|
||||
|
||||
|
||||
:- use_module(library('clpbn/dists'),
|
||||
[get_dist_params/2]).
|
||||
|
||||
|
||||
:- use_module(library(pfl),
|
||||
[factor/6,
|
||||
skolem/2,
|
||||
@@ -29,30 +33,22 @@
|
||||
]).
|
||||
|
||||
|
||||
:- use_module(horus,
|
||||
[create_lifted_network/3,
|
||||
set_parfactors_params/2,
|
||||
run_lifted_solver/3,
|
||||
free_parfactors/1
|
||||
]).
|
||||
|
||||
|
||||
fove([[]], _, _) :- !.
|
||||
fove([QueryVars], AllVars, Output) :-
|
||||
init_fove_solver(_, AllVars, _, ParfactorList),
|
||||
run_fove_solver([QueryVars], LPs, ParfactorList),
|
||||
finalize_fove_solver(ParfactorList),
|
||||
call_horus_lifted_solver([[]], _, _) :- !.
|
||||
call_horus_lifted_solver([QueryVars], AllVars, Output) :-
|
||||
init_horus_lifted_solver(_, AllVars, _, ParfactorList),
|
||||
run_horus_lifted_solver([QueryVars], LPs, ParfactorList),
|
||||
finalize_horus_lifted_solver(ParfactorList),
|
||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||
|
||||
|
||||
init_fove_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
|
||||
init_horus_lifted_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
|
||||
get_parfactors(Parfactors),
|
||||
get_dist_ids(Parfactors, DistIds0),
|
||||
sort(DistIds0, DistIds),
|
||||
get_observed_vars(AllAttVars, ObservedVars),
|
||||
writeln(parfactors:Parfactors:'\n'),
|
||||
writeln(evidence:ObservedVars:'\n'),
|
||||
create_lifted_network(Parfactors,ObservedVars,ParfactorList).
|
||||
%writeln(parfactors:Parfactors:'\n'),
|
||||
%writeln(evidence:ObservedVars:'\n'),
|
||||
cpp_create_lifted_network(Parfactors,ObservedVars,ParfactorList).
|
||||
|
||||
|
||||
:- table get_parfactors/1.
|
||||
@@ -138,15 +134,15 @@ get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
||||
get_dists_parameters(Ids, DistsInfo).
|
||||
|
||||
|
||||
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
|
||||
run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
|
||||
get_query_vars(QueryVarsAtts, QueryVars),
|
||||
writeln(queryVars:QueryVars), writeln(''),
|
||||
%writeln(queryVars:QueryVars), writeln(''),
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
writeln(dists:DistsParams), writeln(''),
|
||||
set_parfactors_params(ParfactorList, DistsParams),
|
||||
run_lifted_solver(ParfactorList, QueryVars, Solutions).
|
||||
%writeln(dists:DistsParams), writeln(''),
|
||||
cpp_set_parfactors_params(ParfactorList, DistsParams),
|
||||
cpp_run_lifted_solver(ParfactorList, QueryVars, Solutions).
|
||||
|
||||
|
||||
finalize_fove_solver(fove(ParfactorList, _)) :-
|
||||
free_parfactors(ParfactorList).
|
||||
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
|
||||
cpp_free_parfactors(ParfactorList).
|
||||
|
Reference in New Issue
Block a user