new version of belief propagation solver.
This commit is contained in:
parent
a16a7d5b1c
commit
69e5fed10f
149
packages/CLPBN/clpbn/bp/BPNodeInfo.cpp
Executable file
149
packages/CLPBN/clpbn/bp/BPNodeInfo.cpp
Executable file
@ -0,0 +1,149 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "BPNodeInfo.h"
|
||||||
|
#include "BPSolver.h"
|
||||||
|
|
||||||
|
BPNodeInfo::BPNodeInfo (BayesNode* node)
|
||||||
|
{
|
||||||
|
node_ = node;
|
||||||
|
ds_ = node->getDomainSize();
|
||||||
|
piValsCalc_ = false;
|
||||||
|
ldValsCalc_ = false;
|
||||||
|
nPiMsgsRcv_ = 0;
|
||||||
|
nLdMsgsRcv_ = 0;
|
||||||
|
piVals_.resize (ds_, 1);
|
||||||
|
ldVals_.resize (ds_, 1);
|
||||||
|
const BnNodeSet& childs = node->getChilds();
|
||||||
|
for (unsigned i = 0; i < childs.size(); i++) {
|
||||||
|
cmsgs_.insert (make_pair (childs[i], false));
|
||||||
|
}
|
||||||
|
const BnNodeSet& parents = node->getParents();
|
||||||
|
for (unsigned i = 0; i < parents.size(); i++) {
|
||||||
|
pmsgs_.insert (make_pair (parents[i], false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
BPNodeInfo::getBeliefs (void) const
|
||||||
|
{
|
||||||
|
double sum = 0.0;
|
||||||
|
ParamSet beliefs (ds_);
|
||||||
|
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||||
|
double prod = piVals_[xi] * ldVals_[xi];
|
||||||
|
beliefs[xi] = prod;
|
||||||
|
sum += prod;
|
||||||
|
}
|
||||||
|
assert (sum);
|
||||||
|
//normalize the beliefs
|
||||||
|
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||||
|
beliefs[xi] /= sum;
|
||||||
|
}
|
||||||
|
return beliefs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BPNodeInfo::readyToSendPiMsgTo (const BayesNode* child) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < inChildLinks_.size(); i++) {
|
||||||
|
if (inChildLinks_[i]->getSource() != child
|
||||||
|
&& !inChildLinks_[i]->messageWasSended()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BPNodeInfo::readyToSendLambdaMsgTo (const BayesNode* parent) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < inParentLinks_.size(); i++) {
|
||||||
|
if (inParentLinks_[i]->getSource() != parent
|
||||||
|
&& !inParentLinks_[i]->messageWasSended()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
BPNodeInfo::getPiValue (unsigned idx) const
|
||||||
|
{
|
||||||
|
assert (idx >=0 && idx < ds_);
|
||||||
|
return piVals_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BPNodeInfo::setPiValue (unsigned idx, Param value)
|
||||||
|
{
|
||||||
|
assert (idx >=0 && idx < ds_);
|
||||||
|
piVals_[idx] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
BPNodeInfo::getLambdaValue (unsigned idx) const
|
||||||
|
{
|
||||||
|
assert (idx >=0 && idx < ds_);
|
||||||
|
return ldVals_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BPNodeInfo::setLambdaValue (unsigned idx, Param value)
|
||||||
|
{
|
||||||
|
assert (idx >=0 && idx < ds_);
|
||||||
|
ldVals_[idx] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
BPNodeInfo::getBeliefChange (void)
|
||||||
|
{
|
||||||
|
double change = 0.0;
|
||||||
|
if (oldBeliefs_.size() == 0) {
|
||||||
|
oldBeliefs_ = getBeliefs();
|
||||||
|
change = 9999999999.0;
|
||||||
|
} else {
|
||||||
|
ParamSet currentBeliefs = getBeliefs();
|
||||||
|
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||||
|
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
|
||||||
|
}
|
||||||
|
oldBeliefs_ = currentBeliefs;
|
||||||
|
}
|
||||||
|
return change;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BPNodeInfo::receivedBottomInfluence (void) const
|
||||||
|
{
|
||||||
|
// if all lambda values are equal, then neither
|
||||||
|
// this node neither its descendents have evidence,
|
||||||
|
// we can use this to don't send lambda messages his parents
|
||||||
|
bool childInfluenced = false;
|
||||||
|
for (unsigned xi = 1; xi < ds_; xi++) {
|
||||||
|
if (ldVals_[xi] != ldVals_[0]) {
|
||||||
|
childInfluenced = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return childInfluenced;
|
||||||
|
}
|
||||||
|
|
82
packages/CLPBN/clpbn/bp/BPNodeInfo.h
Executable file
82
packages/CLPBN/clpbn/bp/BPNodeInfo.h
Executable file
@ -0,0 +1,82 @@
|
|||||||
|
#ifndef BP_BP_NODE_H
|
||||||
|
#define BP_BP_NODE_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "BPSolver.h"
|
||||||
|
#include "BayesNode.h"
|
||||||
|
#include "Shared.h"
|
||||||
|
|
||||||
|
//class Edge;
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
class BPNodeInfo
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BPNodeInfo (int);
|
||||||
|
BPNodeInfo (BayesNode*);
|
||||||
|
|
||||||
|
ParamSet getBeliefs (void) const;
|
||||||
|
double getPiValue (unsigned) const;
|
||||||
|
void setPiValue (unsigned, Param);
|
||||||
|
double getLambdaValue (unsigned) const;
|
||||||
|
void setLambdaValue (unsigned, Param);
|
||||||
|
double getBeliefChange (void);
|
||||||
|
bool receivedBottomInfluence (void) const;
|
||||||
|
|
||||||
|
ParamSet& getPiValues (void) { return piVals_; }
|
||||||
|
ParamSet& getLambdaValues (void) { return ldVals_; }
|
||||||
|
bool arePiValuesCalculated (void) { return piValsCalc_; }
|
||||||
|
bool areLambdaValuesCalculated (void) { return ldValsCalc_; }
|
||||||
|
void markPiValuesAsCalculated (void) { piValsCalc_ = true; }
|
||||||
|
void markLambdaValuesAsCalculated (void) { ldValsCalc_ = true; }
|
||||||
|
void incNumPiMsgsRcv (void) { nPiMsgsRcv_ ++; }
|
||||||
|
void incNumLambdaMsgsRcv (void) { nLdMsgsRcv_ ++; }
|
||||||
|
|
||||||
|
bool receivedAllPiMessages (void)
|
||||||
|
{
|
||||||
|
return node_->getParents().size() == nPiMsgsRcv_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool receivedAllLambdaMessages (void)
|
||||||
|
{
|
||||||
|
return node_->getChilds().size() == nLdMsgsRcv_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool readyToSendPiMsgTo (const BayesNode*) const ;
|
||||||
|
bool readyToSendLambdaMsgTo (const BayesNode*) const;
|
||||||
|
|
||||||
|
CEdgeSet getIncomingParentLinks (void) { return inParentLinks_; }
|
||||||
|
CEdgeSet getIncomingChildLinks (void) { return inChildLinks_; }
|
||||||
|
CEdgeSet getOutcomingParentLinks (void) { return outParentLinks_; }
|
||||||
|
CEdgeSet getOutcomingChildLinks (void) { return outChildLinks_; }
|
||||||
|
|
||||||
|
void addIncomingParentLink (Edge* l) { inParentLinks_.push_back (l); }
|
||||||
|
void addIncomingChildLink (Edge* l) { inChildLinks_.push_back (l); }
|
||||||
|
void addOutcomingParentLink (Edge* l) { outParentLinks_.push_back (l); }
|
||||||
|
void addOutcomingChildLink (Edge* l) { outChildLinks_.push_back (l); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_COPY_AND_ASSIGN (BPNodeInfo);
|
||||||
|
|
||||||
|
ParamSet piVals_; // pi values
|
||||||
|
ParamSet ldVals_; // lambda values
|
||||||
|
ParamSet oldBeliefs_;
|
||||||
|
unsigned nPiMsgsRcv_;
|
||||||
|
unsigned nLdMsgsRcv_;
|
||||||
|
bool piValsCalc_;
|
||||||
|
bool ldValsCalc_;
|
||||||
|
EdgeSet inParentLinks_;
|
||||||
|
EdgeSet inChildLinks_;
|
||||||
|
EdgeSet outParentLinks_;
|
||||||
|
EdgeSet outChildLinks_;
|
||||||
|
unsigned ds_;
|
||||||
|
const BayesNode* node_;
|
||||||
|
map<const BayesNode*, bool> pmsgs_;
|
||||||
|
map<const BayesNode*, bool> cmsgs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif //BP_BP_NODE_H
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
@ -1,259 +1,106 @@
|
|||||||
#ifndef BP_BPSOLVER_H
|
#ifndef BP_BP_SOLVER_H
|
||||||
#define BP_BPSOLVER_H
|
#define BP_BP_SOLVER_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "Solver.h"
|
#include "Solver.h"
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "BpNode.h"
|
#include "BPNodeInfo.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class BPSolver;
|
class BPNodeInfo;
|
||||||
|
|
||||||
static const string PI = "pi" ;
|
static const string PI = "pi" ;
|
||||||
static const string LD = "ld" ;
|
static const string LD = "ld" ;
|
||||||
|
|
||||||
|
|
||||||
enum MessageType {PI_MSG, LAMBDA_MSG};
|
enum MessageType {PI_MSG, LAMBDA_MSG};
|
||||||
|
enum JointCalcType {CHAIN_RULE, JUNCTION_NODE};
|
||||||
|
|
||||||
class BPSolver;
|
class Edge
|
||||||
struct Edge
|
|
||||||
{
|
{
|
||||||
Edge (BayesNode* s, BayesNode* d, MessageType t)
|
public:
|
||||||
{
|
Edge (BayesNode* s, BayesNode* d, MessageType t)
|
||||||
source = s;
|
{
|
||||||
destination = d;
|
source_ = s;
|
||||||
type = t;
|
destin_ = d;
|
||||||
}
|
type_ = t;
|
||||||
string getId (void) const
|
if (type_ == PI_MSG) {
|
||||||
{
|
currMsg_.resize (s->getDomainSize(), 1);
|
||||||
stringstream ss;
|
nextMsg_.resize (s->getDomainSize(), 1);
|
||||||
type == PI_MSG ? ss << PI : ss << LD;
|
} else {
|
||||||
ss << source->getVarId() << "." << destination->getVarId();
|
currMsg_.resize (d->getDomainSize(), 1);
|
||||||
return ss.str();
|
nextMsg_.resize (d->getDomainSize(), 1);
|
||||||
}
|
|
||||||
string toString (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
type == PI_MSG ? ss << PI << "(" : ss << LD << "(" ;
|
|
||||||
ss << source->getLabel() << " --> " ;
|
|
||||||
ss << destination->getLabel();
|
|
||||||
ss << ")" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
BayesNode* source;
|
|
||||||
BayesNode* destination;
|
|
||||||
MessageType type;
|
|
||||||
static BPSolver* klass;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
class BPMessage
|
|
||||||
{
|
|
||||||
BPMessage (BayesNode* parent, BayesNode* child)
|
|
||||||
{
|
|
||||||
parent_ = parent;
|
|
||||||
child_ = child;
|
|
||||||
currPiMsg_.resize (child->getDomainSize(), 1);
|
|
||||||
currLdMsg_.resize (parent->getDomainSize(), 1);
|
|
||||||
nextLdMsg_.resize (parent->getDomainSize(), 1);
|
|
||||||
nextPiMsg_.resize (child->getDomainSize(), 1);
|
|
||||||
piResidual_ = 1.0;
|
|
||||||
ldResidual_ = 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
Param getPiMessageValue (int idx) const
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < child->getDomainSize());
|
|
||||||
return currPiMsg_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
Param getLambdaMessageValue (int idx) const
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < parent->getDomainSize());
|
|
||||||
return currLdMsg_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
const ParamSet& getPiMessage (void) const
|
|
||||||
{
|
|
||||||
return currPiMsg_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const ParamSet& getLambdaMessage (void) const
|
|
||||||
{
|
|
||||||
return currLdMsg_;
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet& piNextMessageReference (void)
|
|
||||||
{
|
|
||||||
return nextPiMsg_;
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet& lambdaNextMessageReference (const BayesNode* source)
|
|
||||||
{
|
|
||||||
return nextLdMsg_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updatePiMessage (void)
|
|
||||||
{
|
|
||||||
currPiMsg_ = nextPiMsg_;
|
|
||||||
Util::normalize (currPiMsg_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateLambdaMessage (void)
|
|
||||||
{
|
|
||||||
currLdMsg_ = nextLdMsg_;
|
|
||||||
Util::normalize (currLdMsg_);
|
|
||||||
}
|
|
||||||
|
|
||||||
double getPiResidual (void)
|
|
||||||
{
|
|
||||||
return piResidual_;
|
|
||||||
}
|
|
||||||
|
|
||||||
double getLambdaResidual (void)
|
|
||||||
{
|
|
||||||
return ldResidual_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updatePiResidual (void)
|
|
||||||
{
|
|
||||||
piResidual_ = Util::getL1dist (currPiMsg_, nextPiMsg_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateLambdaResidual (void)
|
|
||||||
{
|
|
||||||
ldResidual_ = Util::getL1dist (currLdMsg_, nextLdMsg_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void clearPiResidual (void)
|
|
||||||
{
|
|
||||||
piResidual_ = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void clearLambdaResidual (void)
|
|
||||||
{
|
|
||||||
ldResidual_ = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* parent_;
|
|
||||||
BayesNode* child_;
|
|
||||||
ParamSet currPiMsg_; // current pi messages
|
|
||||||
ParamSet currLdMsg_; // current lambda messages
|
|
||||||
ParamSet nextPiMsg_;
|
|
||||||
ParamSet nextLdMsg_;
|
|
||||||
Param piResidual_;
|
|
||||||
Param ldResidual_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class NodeInfo
|
|
||||||
{
|
|
||||||
NodeInfo (BayesNode* node)
|
|
||||||
{
|
|
||||||
node_ = node;
|
|
||||||
piVals_.resize (node->getDomainSize(), 1);
|
|
||||||
ldVals_.resize (node->getDomainSize(), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet getBeliefs (void) const
|
|
||||||
{
|
|
||||||
double sum = 0.0;
|
|
||||||
ParamSet beliefs (node_->getDomainSize());
|
|
||||||
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
|
|
||||||
double prod = piVals_[xi] * ldVals_[xi];
|
|
||||||
beliefs[xi] = prod;
|
|
||||||
sum += prod;
|
|
||||||
}
|
|
||||||
assert (sum);
|
|
||||||
//normalize the beliefs
|
|
||||||
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
|
|
||||||
beliefs[xi] /= sum;
|
|
||||||
}
|
|
||||||
return beliefs;
|
|
||||||
}
|
|
||||||
|
|
||||||
double getPiValue (int idx) const
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < node_->getDomainSize());
|
|
||||||
return piVals_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
void setPiValue (int idx, double value)
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < node_->getDomainSize());
|
|
||||||
piVals_[idx] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
double getLambdaValue (int idx) const
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < node_->getDomainSize());
|
|
||||||
return ldVals_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
void setLambdaValue (int idx, double value)
|
|
||||||
{
|
|
||||||
assert (idx >=0 && idx < node_->getDomainSize());
|
|
||||||
ldVals_[idx] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet& getPiValues (void)
|
|
||||||
{
|
|
||||||
return piVals_;
|
|
||||||
}
|
|
||||||
|
|
||||||
ParamSet& getLambdaValues (void)
|
|
||||||
{
|
|
||||||
return ldVals_;
|
|
||||||
}
|
|
||||||
|
|
||||||
double getBeliefChange (void)
|
|
||||||
{
|
|
||||||
double change = 0.0;
|
|
||||||
if (oldBeliefs_.size() == 0) {
|
|
||||||
oldBeliefs_ = getBeliefs();
|
|
||||||
change = MAX_CHANGE_;
|
|
||||||
} else {
|
|
||||||
ParamSet currentBeliefs = getBeliefs();
|
|
||||||
for (int xi = 0; xi < node_->getDomainSize(); xi++) {
|
|
||||||
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
|
|
||||||
}
|
}
|
||||||
oldBeliefs_ = currentBeliefs;
|
msgSended_ = false;
|
||||||
|
residual_ = 0.0;
|
||||||
}
|
}
|
||||||
return change;
|
|
||||||
}
|
//void setMessage (ParamSet msg)
|
||||||
|
//{
|
||||||
|
// Util::normalize (msg);
|
||||||
|
// residual_ = Util::getMaxNorm (currMsg_, msg);
|
||||||
|
// currMsg_ = msg;
|
||||||
|
//}
|
||||||
|
|
||||||
bool hasReceivedChildInfluence (void) const
|
void setNextMessage (CParamSet msg)
|
||||||
{
|
{
|
||||||
// if all lambda values are equal, then neither
|
nextMsg_ = msg;
|
||||||
// this node neither its descendents have evidence,
|
Util::normalize (nextMsg_);
|
||||||
// we can use this to don't send lambda messages his parents
|
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||||
bool childInfluenced = false;
|
}
|
||||||
for (int xi = 1; xi < node_->getDomainSize(); xi++) {
|
|
||||||
if (ldVals_[xi] != ldVals_[0]) {
|
void updateMessage (void)
|
||||||
childInfluenced = true;
|
{
|
||||||
break;
|
currMsg_ = nextMsg_;
|
||||||
|
if (DL >= 3) {
|
||||||
|
cout << "updating " << toString() << endl;
|
||||||
}
|
}
|
||||||
|
msgSended_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updateResidual (void)
|
||||||
|
{
|
||||||
|
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||||
}
|
}
|
||||||
return childInfluenced;
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* node_;
|
string toString (void) const
|
||||||
ParamSet piVals_; // pi values
|
{
|
||||||
ParamSet ldVals_; // lambda values
|
stringstream ss;
|
||||||
ParamSet oldBeliefs_;
|
if (type_ == PI_MSG) {
|
||||||
|
ss << PI;
|
||||||
|
} else if (type_ == LAMBDA_MSG) {
|
||||||
|
ss << LD;
|
||||||
|
} else {
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
ss << "(" << source_->getLabel();
|
||||||
|
ss << " --> " << destin_->getLabel() << ")" ;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
BayesNode* getSource (void) const { return source_; }
|
||||||
|
BayesNode* getDestination (void) const { return destin_; }
|
||||||
|
MessageType getMessageType (void) const { return type_; }
|
||||||
|
CParamSet getMessage (void) const { return currMsg_; }
|
||||||
|
bool messageWasSended (void) const { return msgSended_; }
|
||||||
|
double getResidual (void) const { return residual_; }
|
||||||
|
void clearResidual (void) { residual_ = 0.0; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
BayesNode* source_;
|
||||||
|
BayesNode* destin_;
|
||||||
|
MessageType type_;
|
||||||
|
ParamSet currMsg_;
|
||||||
|
ParamSet nextMsg_;
|
||||||
|
bool msgSended_;
|
||||||
|
double residual_;
|
||||||
};
|
};
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
bool compareResidual (const Edge&, const Edge&);
|
|
||||||
|
|
||||||
class BPSolver : public Solver
|
class BPSolver : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -261,190 +108,85 @@ class BPSolver : public Solver
|
|||||||
~BPSolver (void);
|
~BPSolver (void);
|
||||||
|
|
||||||
void runSolver (void);
|
void runSolver (void);
|
||||||
ParamSet getPosterioriOf (const Variable* var) const;
|
ParamSet getPosterioriOf (Vid) const;
|
||||||
ParamSet getJointDistribution (const NodeSet&) const;
|
ParamSet getJointDistributionOf (const VidSet&);
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BPSolver);
|
DISALLOW_COPY_AND_ASSIGN (BPSolver);
|
||||||
|
|
||||||
void initializeSolver (void);
|
void initializeSolver (void);
|
||||||
void incorporateEvidence (BayesNode*);
|
|
||||||
void runPolyTreeSolver (void);
|
void runPolyTreeSolver (void);
|
||||||
void polyTreePiMessage (BayesNode*, BayesNode*);
|
void runLoopySolver (void);
|
||||||
void polyTreeLambdaMessage (BayesNode*, BayesNode*);
|
|
||||||
void runGenericSolver (void);
|
|
||||||
void maxResidualSchedule (void);
|
void maxResidualSchedule (void);
|
||||||
bool converged (void) const;
|
bool converged (void) const;
|
||||||
void updatePiValues (BayesNode*);
|
void updatePiValues (BayesNode*);
|
||||||
void updateLambdaValues (BayesNode*);
|
void updateLambdaValues (BayesNode*);
|
||||||
void calculateNextPiMessage (BayesNode*, BayesNode*);
|
ParamSet calculateNextLambdaMessage (Edge* edge);
|
||||||
void calculateNextLambdaMessage (BayesNode*, BayesNode*);
|
ParamSet calculateNextPiMessage (Edge* edge);
|
||||||
|
ParamSet getJointByJunctionNode (const VidSet&) const;
|
||||||
|
ParamSet getJointByChainRule (const VidSet&) const;
|
||||||
void printMessageStatusOf (const BayesNode*) const;
|
void printMessageStatusOf (const BayesNode*) const;
|
||||||
void printAllMessageStatus (void) const;
|
void printAllMessageStatus (void) const;
|
||||||
// inlines
|
|
||||||
void updatePiMessage (BayesNode*, BayesNode*);
|
ParamSet getMessage (Edge* edge)
|
||||||
void updateLambdaMessage (BayesNode*, BayesNode*);
|
{
|
||||||
void calculateNextMessage (const Edge&);
|
if (DL >= 3) {
|
||||||
void updateMessage (const Edge&);
|
cout << " calculating " << edge->toString() << endl;
|
||||||
void updateValues (const Edge&);
|
}
|
||||||
double getResidual (const Edge&) const;
|
if (edge->getMessageType() == PI_MSG) {
|
||||||
void updateResidual (const Edge&);
|
return calculateNextPiMessage (edge);
|
||||||
void clearResidual (const Edge&);
|
} else if (edge->getMessageType() == LAMBDA_MSG) {
|
||||||
BpNode* M (const BayesNode*) const;
|
return calculateNextLambdaMessage (edge);
|
||||||
friend bool compareResidual (const Edge&, const Edge&);
|
} else {
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
return ParamSet();
|
||||||
|
}
|
||||||
|
|
||||||
|
void updateValues (Edge* edge)
|
||||||
|
{
|
||||||
|
if (!edge->getDestination()->hasEvidence()) {
|
||||||
|
if (edge->getMessageType() == PI_MSG) {
|
||||||
|
updatePiValues (edge->getDestination());
|
||||||
|
} else if (edge->getMessageType() == LAMBDA_MSG) {
|
||||||
|
updateLambdaValues (edge->getDestination());
|
||||||
|
} else {
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BPNodeInfo* M (const BayesNode* node) const
|
||||||
|
{
|
||||||
|
assert (node);
|
||||||
|
assert (node == bn_->getBayesNode (node->getVarId()));
|
||||||
|
assert (node->getIndex() < nodesI_.size());
|
||||||
|
return nodesI_[node->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
const BayesNet* bn_;
|
const BayesNet* bn_;
|
||||||
vector<BpNode*> msgs_;
|
vector<BPNodeInfo*> nodesI_;
|
||||||
Schedule schedule_;
|
unsigned nIter_;
|
||||||
int nIter_;
|
vector<Edge*> links_;
|
||||||
int maxIter_;
|
bool useAlwaysLoopySolver_;
|
||||||
double accuracy_;
|
JointCalcType jointCalcType_;
|
||||||
vector<Edge> updateOrder_;
|
|
||||||
bool forceGenericSolver_;
|
|
||||||
|
|
||||||
struct compare
|
struct compare
|
||||||
{
|
{
|
||||||
inline bool operator() (const Edge& e1, const Edge& e2)
|
inline bool operator() (const Edge* e1, const Edge* e2)
|
||||||
{
|
{
|
||||||
return compareResidual (e1, e2);
|
return e1->getResidual() > e2->getResidual();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef multiset<Edge, compare> SortedOrder;
|
typedef multiset<Edge*, compare> SortedOrder;
|
||||||
SortedOrder sortedOrder_;
|
SortedOrder sortedOrder_;
|
||||||
|
|
||||||
typedef unordered_map<string, SortedOrder::iterator> EdgeMap;
|
typedef map<Edge*, SortedOrder::iterator> EdgeMap;
|
||||||
EdgeMap edgeMap_;
|
EdgeMap edgeMap_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#endif //BP_BP_SOLVER_H
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updatePiMessage (BayesNode* source, BayesNode* destination)
|
|
||||||
{
|
|
||||||
M(source)->updatePiMessage(destination);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updateLambdaMessage (BayesNode* source, BayesNode* destination)
|
|
||||||
{
|
|
||||||
M(destination)->updateLambdaMessage(source);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::calculateNextMessage (const Edge& e)
|
|
||||||
{
|
|
||||||
if (DL >= 1) {
|
|
||||||
cout << "calculating " << e.toString() << endl;
|
|
||||||
}
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
calculateNextPiMessage (e.source, e.destination);
|
|
||||||
} else {
|
|
||||||
calculateNextLambdaMessage (e.source, e.destination);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updateMessage (const Edge& e)
|
|
||||||
{
|
|
||||||
if (DL >= 1) {
|
|
||||||
cout << "updating " << e.toString() << endl;
|
|
||||||
}
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
M(e.source)->updatePiMessage(e.destination);
|
|
||||||
} else {
|
|
||||||
M(e.destination)->updateLambdaMessage(e.source);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updateValues (const Edge& e)
|
|
||||||
{
|
|
||||||
if (!e.destination->hasEvidence()) {
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
updatePiValues (e.destination);
|
|
||||||
} else {
|
|
||||||
updateLambdaValues (e.destination);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
BPSolver::getResidual (const Edge& e) const
|
|
||||||
{
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
return M(e.source)->getPiResidual(e.destination);
|
|
||||||
} else {
|
|
||||||
return M(e.destination)->getLambdaResidual(e.source);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::updateResidual (const Edge& e)
|
|
||||||
{
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
M(e.source)->updatePiResidual(e.destination);
|
|
||||||
} else {
|
|
||||||
M(e.destination)->updateLambdaResidual(e.source);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BPSolver::clearResidual (const Edge& e)
|
|
||||||
{
|
|
||||||
if (e.type == PI_MSG) {
|
|
||||||
M(e.source)->clearPiResidual(e.destination);
|
|
||||||
} else {
|
|
||||||
M(e.destination)->clearLambdaResidual(e.source);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline bool
|
|
||||||
compareResidual (const Edge& e1, const Edge& e2)
|
|
||||||
{
|
|
||||||
double residual1;
|
|
||||||
double residual2;
|
|
||||||
if (e1.type == PI_MSG) {
|
|
||||||
residual1 = Edge::klass->M(e1.source)->getPiResidual(e1.destination);
|
|
||||||
} else {
|
|
||||||
residual1 = Edge::klass->M(e1.destination)->getLambdaResidual(e1.source);
|
|
||||||
}
|
|
||||||
if (e2.type == PI_MSG) {
|
|
||||||
residual2 = Edge::klass->M(e2.source)->getPiResidual(e2.destination);
|
|
||||||
} else {
|
|
||||||
residual2 = Edge::klass->M(e2.destination)->getLambdaResidual(e2.source);
|
|
||||||
}
|
|
||||||
return residual1 > residual2;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline BpNode*
|
|
||||||
BPSolver::M (const BayesNode* node) const
|
|
||||||
{
|
|
||||||
assert (node);
|
|
||||||
assert (node == bn_->getNode (node->getVarId()));
|
|
||||||
assert (node->getIndex() < msgs_.size());
|
|
||||||
return msgs_[node->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
@ -1,30 +1,24 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <cassert>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#include "xmlParser/xmlParser.h"
|
#include "xmlParser/xmlParser.h"
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
|
|
||||||
|
|
||||||
BayesNet::BayesNet (void)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNet::BayesNet (const char* fileName)
|
BayesNet::BayesNet (const char* fileName)
|
||||||
{
|
{
|
||||||
map<string, Domain> domains;
|
map<string, Domain> domains;
|
||||||
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
||||||
// only the first network is parsed, others are ignored
|
// only the first network is parsed, others are ignored
|
||||||
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
||||||
int nVars = xNode.nChildNode ("VARIABLE");
|
unsigned nVars = xNode.nChildNode ("VARIABLE");
|
||||||
for (int i = 0; i < nVars; i++) {
|
for (unsigned i = 0; i < nVars; i++) {
|
||||||
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
||||||
string type = var.getAttribute ("TYPE");
|
string type = var.getAttribute ("TYPE");
|
||||||
if (type != "nature") {
|
if (type != "nature") {
|
||||||
@ -32,9 +26,9 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
Domain domain;
|
Domain domain;
|
||||||
string label = var.getChildNode("NAME").getText();
|
string varLabel = var.getChildNode("NAME").getText();
|
||||||
int domainSize = var.nChildNode ("OUTCOME");
|
unsigned dsize = var.nChildNode ("OUTCOME");
|
||||||
for (int j = 0; j < domainSize; j++) {
|
for (unsigned j = 0; j < dsize; j++) {
|
||||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << j + 1;
|
ss << j + 1;
|
||||||
@ -43,37 +37,37 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
domain.push_back (var.getChildNode("OUTCOME", j).getText());
|
domain.push_back (var.getChildNode("OUTCOME", j).getText());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
domains.insert (make_pair (label, domain));
|
domains.insert (make_pair (varLabel, domain));
|
||||||
}
|
}
|
||||||
|
|
||||||
int nDefs = xNode.nChildNode ("DEFINITION");
|
unsigned nDefs = xNode.nChildNode ("DEFINITION");
|
||||||
if (nVars != nDefs) {
|
if (nVars != nDefs) {
|
||||||
cerr << "error: different number of variables and definitions";
|
cerr << "error: different number of variables and definitions" << endl;
|
||||||
cerr << endl;
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
queue<int> indexes;
|
queue<unsigned> indexes;
|
||||||
for (int i = 0; i < nDefs; i++) {
|
for (unsigned i = 0; i < nDefs; i++) {
|
||||||
indexes.push (i);
|
indexes.push (i);
|
||||||
}
|
}
|
||||||
|
|
||||||
while (!indexes.empty()) {
|
while (!indexes.empty()) {
|
||||||
int index = indexes.front();
|
unsigned index = indexes.front();
|
||||||
indexes.pop();
|
indexes.pop();
|
||||||
XMLNode def = xNode.getChildNode ("DEFINITION", index);
|
XMLNode def = xNode.getChildNode ("DEFINITION", index);
|
||||||
string label = def.getChildNode("FOR").getText();
|
string varLabel = def.getChildNode("FOR").getText();
|
||||||
map<string, Domain>::const_iterator iter;
|
map<string, Domain>::const_iterator iter;
|
||||||
iter = domains.find (label);
|
iter = domains.find (varLabel);
|
||||||
if (iter == domains.end()) {
|
if (iter == domains.end()) {
|
||||||
cerr << "error: unknow variable `" << label << "'" << endl;
|
cerr << "error: unknow variable `" << varLabel << "'" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
bool processItLatter = false;
|
bool processItLatter = false;
|
||||||
NodeSet parents;
|
BnNodeSet parents;
|
||||||
int nParams = iter->second.size();
|
unsigned nParams = iter->second.size();
|
||||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
||||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
||||||
BayesNode* parentNode = getNode (parentLabel);
|
BayesNode* parentNode = getBayesNode (parentLabel);
|
||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
nParams *= parentNode->getDomainSize();
|
nParams *= parentNode->getDomainSize();
|
||||||
parents.push_back (parentNode);
|
parents.push_back (parentNode);
|
||||||
@ -95,7 +89,7 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!processItLatter) {
|
if (!processItLatter) {
|
||||||
int count = 0;
|
unsigned count = 0;
|
||||||
ParamSet params (nParams);
|
ParamSet params (nParams);
|
||||||
stringstream s (def.getChildNode("TABLE").getText());
|
stringstream s (def.getChildNode("TABLE").getText());
|
||||||
while (!s.eof() && count < nParams) {
|
while (!s.eof() && count < nParams) {
|
||||||
@ -104,11 +98,11 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
}
|
}
|
||||||
if (count != nParams) {
|
if (count != nParams) {
|
||||||
cerr << "error: invalid number of parameters " ;
|
cerr << "error: invalid number of parameters " ;
|
||||||
cerr << "for variable `" << label << "'" << endl;
|
cerr << "for variable `" << varLabel << "'" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
params = reorderParameters (params, iter->second.size());
|
params = reorderParameters (params, iter->second.size());
|
||||||
addNode (label, iter->second, parents, params);
|
addNode (varLabel, iter->second, parents, params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
setIndexes();
|
setIndexes();
|
||||||
@ -118,7 +112,6 @@ BayesNet::BayesNet (const char* fileName)
|
|||||||
|
|
||||||
BayesNet::~BayesNet (void)
|
BayesNet::~BayesNet (void)
|
||||||
{
|
{
|
||||||
Statistics::writeStats();
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
delete nodes_[i];
|
delete nodes_[i];
|
||||||
}
|
}
|
||||||
@ -127,25 +120,25 @@ BayesNet::~BayesNet (void)
|
|||||||
|
|
||||||
|
|
||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::addNode (unsigned varId)
|
BayesNet::addNode (Vid vid)
|
||||||
{
|
{
|
||||||
indexMap_.insert (make_pair (varId, nodes_.size()));
|
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||||
nodes_.push_back (new BayesNode (varId));
|
nodes_.push_back (new BayesNode (vid));
|
||||||
return nodes_.back();
|
return nodes_.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::addNode (unsigned varId,
|
BayesNet::addNode (Vid vid,
|
||||||
unsigned dsize,
|
unsigned dsize,
|
||||||
int evidence,
|
int evidence,
|
||||||
NodeSet& parents,
|
BnNodeSet& parents,
|
||||||
Distribution* dist)
|
Distribution* dist)
|
||||||
{
|
{
|
||||||
indexMap_.insert (make_pair (varId, nodes_.size()));
|
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||||
nodes_.push_back (new BayesNode (
|
nodes_.push_back (new BayesNode (
|
||||||
varId, dsize, evidence, parents, dist));
|
vid, dsize, evidence, parents, dist));
|
||||||
return nodes_.back();
|
return nodes_.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -154,7 +147,7 @@ BayesNet::addNode (unsigned varId,
|
|||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::addNode (string label,
|
BayesNet::addNode (string label,
|
||||||
Domain domain,
|
Domain domain,
|
||||||
NodeSet& parents,
|
BnNodeSet& parents,
|
||||||
ParamSet& params)
|
ParamSet& params)
|
||||||
{
|
{
|
||||||
indexMap_.insert (make_pair (nodes_.size(), nodes_.size()));
|
indexMap_.insert (make_pair (nodes_.size(), nodes_.size()));
|
||||||
@ -169,9 +162,9 @@ BayesNet::addNode (string label,
|
|||||||
|
|
||||||
|
|
||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::getNode (unsigned varId) const
|
BayesNet::getBayesNode (Vid vid) const
|
||||||
{
|
{
|
||||||
IndexMap::const_iterator it = indexMap_.find(varId);
|
IndexMap::const_iterator it = indexMap_.find (vid);
|
||||||
if (it == indexMap_.end()) {
|
if (it == indexMap_.end()) {
|
||||||
return 0;
|
return 0;
|
||||||
} else {
|
} else {
|
||||||
@ -182,7 +175,7 @@ BayesNet::getNode (unsigned varId) const
|
|||||||
|
|
||||||
|
|
||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::getNode (string label) const
|
BayesNet::getBayesNode (string label) const
|
||||||
{
|
{
|
||||||
BayesNode* node = 0;
|
BayesNode* node = 0;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
@ -196,6 +189,15 @@ BayesNet::getNode (string label) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Variable*
|
||||||
|
BayesNet::getVariable (Vid vid) const
|
||||||
|
{
|
||||||
|
return getBayesNode (vid);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::addDistribution (Distribution* dist)
|
BayesNet::addDistribution (Distribution* dist)
|
||||||
{
|
{
|
||||||
@ -219,15 +221,15 @@ BayesNet::getDistribution (unsigned distId) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
const NodeSet&
|
const BnNodeSet&
|
||||||
BayesNet::getNodes (void) const
|
BayesNet::getBayesNodes (void) const
|
||||||
{
|
{
|
||||||
return nodes_;
|
return nodes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
unsigned
|
||||||
BayesNet::getNumberOfNodes (void) const
|
BayesNet::getNumberOfNodes (void) const
|
||||||
{
|
{
|
||||||
return nodes_.size();
|
return nodes_.size();
|
||||||
@ -235,10 +237,10 @@ BayesNet::getNumberOfNodes (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
NodeSet
|
BnNodeSet
|
||||||
BayesNet::getRootNodes (void) const
|
BayesNet::getRootNodes (void) const
|
||||||
{
|
{
|
||||||
NodeSet roots;
|
BnNodeSet roots;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
if (nodes_[i]->isRoot()) {
|
if (nodes_[i]->isRoot()) {
|
||||||
roots.push_back (nodes_[i]);
|
roots.push_back (nodes_[i]);
|
||||||
@ -249,10 +251,10 @@ BayesNet::getRootNodes (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
NodeSet
|
BnNodeSet
|
||||||
BayesNet::getLeafNodes (void) const
|
BayesNet::getLeafNodes (void) const
|
||||||
{
|
{
|
||||||
NodeSet leafs;
|
BnNodeSet leafs;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
if (nodes_[i]->isLeaf()) {
|
if (nodes_[i]->isLeaf()) {
|
||||||
leafs.push_back (nodes_[i]);
|
leafs.push_back (nodes_[i]);
|
||||||
@ -276,30 +278,32 @@ BayesNet::getVariables (void) const
|
|||||||
|
|
||||||
|
|
||||||
BayesNet*
|
BayesNet*
|
||||||
BayesNet::pruneNetwork (BayesNode* queryNode) const
|
BayesNet::getMinimalRequesiteNetwork (Vid vid) const
|
||||||
{
|
{
|
||||||
NodeSet queryNodes;
|
return getMinimalRequesiteNetwork (VidSet() = {vid});
|
||||||
queryNodes.push_back (queryNode);
|
|
||||||
return pruneNetwork (queryNodes);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNet*
|
BayesNet*
|
||||||
BayesNet::pruneNetwork (const NodeSet& interestedVars) const
|
BayesNet::getMinimalRequesiteNetwork (const VidSet& queryVids) const
|
||||||
{
|
{
|
||||||
/*
|
BnNodeSet queryVars;
|
||||||
cout << "interested vars: " ;
|
for (unsigned i = 0; i < queryVids.size(); i++) {
|
||||||
for (unsigned i = 0; i < interestedVars.size(); i++) {
|
assert (getBayesNode (queryVids[i]));
|
||||||
cout << interestedVars[i]->getLabel() << " " ;
|
queryVars.push_back (getBayesNode (queryVids[i]));
|
||||||
}
|
}
|
||||||
cout << endl;
|
// cout << "query vars: " ;
|
||||||
*/
|
// for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||||
|
// cout << queryVars[i]->getLabel() << " " ;
|
||||||
|
// }
|
||||||
|
// cout << endl;
|
||||||
|
|
||||||
vector<StateInfo*> states (nodes_.size(), 0);
|
vector<StateInfo*> states (nodes_.size(), 0);
|
||||||
|
|
||||||
Scheduling scheduling;
|
Scheduling scheduling;
|
||||||
for (NodeSet::const_iterator it = interestedVars.begin();
|
for (BnNodeSet::const_iterator it = queryVars.begin();
|
||||||
it != interestedVars.end(); it++) {
|
it != queryVars.end(); it++) {
|
||||||
scheduling.push (ScheduleInfo (*it, false, true));
|
scheduling.push (ScheduleInfo (*it, false, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -378,18 +382,18 @@ BayesNet::constructGraph (BayesNet* bn,
|
|||||||
states[i]->markedOnTop;
|
states[i]->markedOnTop;
|
||||||
}
|
}
|
||||||
if (isRequired) {
|
if (isRequired) {
|
||||||
NodeSet parents;
|
BnNodeSet parents;
|
||||||
if (states[i]->markedOnTop) {
|
if (states[i]->markedOnTop) {
|
||||||
const NodeSet& ps = nodes_[i]->getParents();
|
const BnNodeSet& ps = nodes_[i]->getParents();
|
||||||
for (unsigned j = 0; j < ps.size(); j++) {
|
for (unsigned j = 0; j < ps.size(); j++) {
|
||||||
BayesNode* parent = bn->getNode (ps[j]->getVarId());
|
BayesNode* parent = bn->getBayesNode (ps[j]->getVarId());
|
||||||
if (!parent) {
|
if (!parent) {
|
||||||
parent = bn->addNode (ps[j]->getVarId());
|
parent = bn->addNode (ps[j]->getVarId());
|
||||||
}
|
}
|
||||||
parents.push_back (parent);
|
parents.push_back (parent);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
BayesNode* node = bn->getNode (nodes_[i]->getVarId());
|
BayesNode* node = bn->getBayesNode (nodes_[i]->getVarId());
|
||||||
if (node) {
|
if (node) {
|
||||||
node->setData (nodes_[i]->getDomainSize(),
|
node->setData (nodes_[i]->getDomainSize(),
|
||||||
nodes_[i]->getEvidence(), parents,
|
nodes_[i]->getEvidence(), parents,
|
||||||
@ -411,65 +415,6 @@ BayesNet::constructGraph (BayesNet* bn,
|
|||||||
bn->setIndexes();
|
bn->setIndexes();
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
void
|
|
||||||
BayesNet::constructGraph (BayesNet* bn,
|
|
||||||
const vector<StateInfo*>& states) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (states[i]) {
|
|
||||||
if (nodes_[i]->hasEvidence() && states[i]->visited) {
|
|
||||||
NodeSet parents;
|
|
||||||
if (states[i]->markedOnTop) {
|
|
||||||
const NodeSet& ps = nodes_[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < ps.size(); j++) {
|
|
||||||
BayesNode* parent = bn->getNode (ps[j]->getVarId());
|
|
||||||
if (parent == 0) {
|
|
||||||
parent = bn->addNode (ps[j]->getVarId());
|
|
||||||
}
|
|
||||||
parents.push_back (parent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* n = bn->getNode (nodes_[i]->getVarId());
|
|
||||||
if (n) {
|
|
||||||
n->setData (nodes_[i]->getDomainSize(),
|
|
||||||
nodes_[i]->getEvidence(), parents,
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
} else {
|
|
||||||
bn->addNode (nodes_[i]->getVarId(),
|
|
||||||
nodes_[i]->getDomainSize(),
|
|
||||||
nodes_[i]->getEvidence(), parents,
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if (states[i]->markedOnTop) {
|
|
||||||
NodeSet parents;
|
|
||||||
const NodeSet& ps = nodes_[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < ps.size(); j++) {
|
|
||||||
BayesNode* parent = bn->getNode (ps[j]->getVarId());
|
|
||||||
if (parent == 0) {
|
|
||||||
parent = bn->addNode (ps[j]->getVarId());
|
|
||||||
}
|
|
||||||
parents.push_back (parent);
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* n = bn->getNode (nodes_[i]->getVarId());
|
|
||||||
if (n) {
|
|
||||||
n->setData (nodes_[i]->getDomainSize(),
|
|
||||||
nodes_[i]->getEvidence(), parents,
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
} else {
|
|
||||||
bn->addNode (nodes_[i]->getVarId(),
|
|
||||||
nodes_[i]->getDomainSize(),
|
|
||||||
nodes_[i]->getEvidence(), parents,
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}*/
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
@ -480,70 +425,6 @@ BayesNet::isSingleConnected (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<DomainConf>
|
|
||||||
BayesNet::getDomainConfigurationsOf (const NodeSet& nodes)
|
|
||||||
{
|
|
||||||
int nConfs = 1;
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
nConfs *= nodes[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<DomainConf> confs (nConfs);
|
|
||||||
for (int i = 0; i < nConfs; i++) {
|
|
||||||
confs[i].resize (nodes.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
int nReps = 1;
|
|
||||||
for (int i = nodes.size() - 1; i >= 0; i--) {
|
|
||||||
int index = 0;
|
|
||||||
while (index < nConfs) {
|
|
||||||
for (int j = 0; j < nodes[i]->getDomainSize(); j++) {
|
|
||||||
for (int r = 0; r < nReps; r++) {
|
|
||||||
confs[index][i] = j;
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nReps *= nodes[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
return confs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<string>
|
|
||||||
BayesNet::getInstantiations (const NodeSet& parents_)
|
|
||||||
{
|
|
||||||
int nParents = parents_.size();
|
|
||||||
int rowSize = 1;
|
|
||||||
for (unsigned i = 0; i < parents_.size(); i++) {
|
|
||||||
rowSize *= parents_[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
int nReps = 1;
|
|
||||||
vector<string> headers (rowSize);
|
|
||||||
for (int i = nParents - 1; i >= 0; i--) {
|
|
||||||
Domain domain = parents_[i]->getDomain();
|
|
||||||
int index = 0;
|
|
||||||
while (index < rowSize) {
|
|
||||||
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
|
|
||||||
for (int r = 0; r < nReps; r++) {
|
|
||||||
if (headers[index] != "") {
|
|
||||||
headers[index] = domain[j] + "," + headers[index];
|
|
||||||
} else {
|
|
||||||
headers[index] = domain[j];
|
|
||||||
}
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nReps *= parents_[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
return headers;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::setIndexes (void)
|
BayesNet::setIndexes (void)
|
||||||
{
|
{
|
||||||
@ -565,7 +446,7 @@ BayesNet::freeDistributions (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::printNetwork (void) const
|
BayesNet::printGraphicalModel (void) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
cout << *nodes_[i];
|
cout << *nodes_[i];
|
||||||
@ -575,32 +456,11 @@ BayesNet::printNetwork (void) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::printNetworkToFile (const char* fileName) const
|
BayesNet::exportToDotFormat (const char* fileName,
|
||||||
|
bool showNeighborless,
|
||||||
|
CVidSet& highlightVids) const
|
||||||
{
|
{
|
||||||
string s = "../../" ;
|
ofstream out (fileName);
|
||||||
s += fileName;
|
|
||||||
ofstream out (s.c_str());
|
|
||||||
if (!out.is_open()) {
|
|
||||||
cerr << "error: cannot open file to write at " ;
|
|
||||||
cerr << "BayesNet::printToFile()" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
out << *nodes_[i];
|
|
||||||
}
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::exportToDotFile (const char* fileName,
|
|
||||||
bool showNeighborless,
|
|
||||||
const NodeSet& highlightNodes) const
|
|
||||||
{
|
|
||||||
string s = "../../" ;
|
|
||||||
s+= fileName;
|
|
||||||
ofstream out (s.c_str());
|
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "BayesNet::exportToDotFile()" << endl;
|
cerr << "BayesNet::exportToDotFile()" << endl;
|
||||||
@ -608,13 +468,6 @@ BayesNet::exportToDotFile (const char* fileName,
|
|||||||
}
|
}
|
||||||
|
|
||||||
out << "digraph \"" << fileName << "\" {" << endl;
|
out << "digraph \"" << fileName << "\" {" << endl;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
const NodeSet& childs = nodes_[i]->getChilds();
|
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
|
||||||
out << '"' << nodes_[i]->getLabel() << '"' << " -> " ;
|
|
||||||
out << '"' << childs[j]->getLabel() << '"' << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
if (showNeighborless || nodes_[i]->hasNeighbors()) {
|
if (showNeighborless || nodes_[i]->hasNeighbors()) {
|
||||||
@ -627,9 +480,24 @@ BayesNet::exportToDotFile (const char* fileName,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < highlightNodes.size(); i++) {
|
for (unsigned i = 0; i < highlightVids.size(); i++) {
|
||||||
out << '"' << highlightNodes[i]->getLabel() << '"' ;
|
BayesNode* node = getBayesNode (highlightVids[i]);
|
||||||
out << " [shape=box]" << endl;
|
if (node) {
|
||||||
|
out << '"' << node->getLabel() << '"' ;
|
||||||
|
// out << " [shape=polygon, sides=6]" << endl;
|
||||||
|
out << " [shape=box3d]" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "error: invalid variable id: " << highlightVids[i] << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
|
const BnNodeSet& childs = nodes_[i]->getChilds();
|
||||||
|
for (unsigned j = 0; j < childs.size(); j++) {
|
||||||
|
out << '"' << nodes_[i]->getLabel() << '"' << " -> " ;
|
||||||
|
out << '"' << childs[j]->getLabel() << '"' << endl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "}" << endl;
|
out << "}" << endl;
|
||||||
@ -639,11 +507,9 @@ BayesNet::exportToDotFile (const char* fileName,
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::exportToBifFile (const char* fileName) const
|
BayesNet::exportToBifFormat (const char* fileName) const
|
||||||
{
|
{
|
||||||
string s = "../../" ;
|
ofstream out (fileName);
|
||||||
s += fileName;
|
|
||||||
ofstream out (s.c_str());
|
|
||||||
if(!out.is_open()) {
|
if(!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "BayesNet::exportToBifFile()" << endl;
|
cerr << "BayesNet::exportToBifFile()" << endl;
|
||||||
@ -666,7 +532,7 @@ BayesNet::exportToBifFile (const char* fileName) const
|
|||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
out << "<DEFINITION>" << endl;
|
out << "<DEFINITION>" << endl;
|
||||||
out << "\t<FOR>" << nodes_[i]->getLabel() << "</FOR>" << endl;
|
out << "\t<FOR>" << nodes_[i]->getLabel() << "</FOR>" << endl;
|
||||||
const NodeSet& parents = nodes_[i]->getParents();
|
const BnNodeSet& parents = nodes_[i]->getParents();
|
||||||
for (unsigned j = 0; j < parents.size(); j++) {
|
for (unsigned j = 0; j < parents.size(); j++) {
|
||||||
out << "\t<GIVEN>" << parents[j]->getLabel();
|
out << "\t<GIVEN>" << parents[j]->getLabel();
|
||||||
out << "</GIVEN>" << endl;
|
out << "</GIVEN>" << endl;
|
||||||
@ -682,7 +548,7 @@ BayesNet::exportToBifFile (const char* fileName) const
|
|||||||
}
|
}
|
||||||
out << "</NETWORK>" << endl;
|
out << "</NETWORK>" << endl;
|
||||||
out << "</BIF>" << endl << endl;
|
out << "</BIF>" << endl << endl;
|
||||||
out.close();
|
out.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -731,8 +597,8 @@ vector<int>
|
|||||||
BayesNet::getAdjacentNodes (int v) const
|
BayesNet::getAdjacentNodes (int v) const
|
||||||
{
|
{
|
||||||
vector<int> adjacencies;
|
vector<int> adjacencies;
|
||||||
const NodeSet& parents = nodes_[v]->getParents();
|
const BnNodeSet& parents = nodes_[v]->getParents();
|
||||||
const NodeSet& childs = nodes_[v]->getChilds();
|
const BnNodeSet& childs = nodes_[v]->getChilds();
|
||||||
for (unsigned i = 0; i < parents.size(); i++) {
|
for (unsigned i = 0; i < parents.size(); i++) {
|
||||||
adjacencies.push_back (parents[i]->getIndex());
|
adjacencies.push_back (parents[i]->getIndex());
|
||||||
}
|
}
|
||||||
@ -745,8 +611,8 @@ BayesNet::getAdjacentNodes (int v) const
|
|||||||
|
|
||||||
|
|
||||||
ParamSet
|
ParamSet
|
||||||
BayesNet::reorderParameters (const ParamSet& params,
|
BayesNet::reorderParameters (CParamSet params,
|
||||||
int domainSize) const
|
unsigned domainSize) const
|
||||||
{
|
{
|
||||||
// the interchange format for bayesian networks keeps the probabilities
|
// the interchange format for bayesian networks keeps the probabilities
|
||||||
// in the following order:
|
// in the following order:
|
||||||
@ -773,15 +639,15 @@ BayesNet::reorderParameters (const ParamSet& params,
|
|||||||
|
|
||||||
|
|
||||||
ParamSet
|
ParamSet
|
||||||
BayesNet::revertParameterReorder (const ParamSet& params,
|
BayesNet::revertParameterReorder (CParamSet params,
|
||||||
int domainSize) const
|
unsigned domainSize) const
|
||||||
{
|
{
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
unsigned rowSize = params.size() / domainSize;
|
unsigned rowSize = params.size() / domainSize;
|
||||||
ParamSet reordered;
|
ParamSet reordered;
|
||||||
while (reordered.size() < params.size()) {
|
while (reordered.size() < params.size()) {
|
||||||
unsigned idx = count;
|
unsigned idx = count;
|
||||||
for (int i = 0; i < domainSize; i++) {
|
for (unsigned i = 0; i < domainSize; i++) {
|
||||||
reordered.push_back (params[idx]);
|
reordered.push_back (params[idx]);
|
||||||
idx += rowSize;
|
idx += rowSize;
|
||||||
}
|
}
|
||||||
|
@ -4,8 +4,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "GraphicalModel.h"
|
||||||
@ -46,42 +44,42 @@ struct StateInfo
|
|||||||
|
|
||||||
typedef vector<Distribution*> DistSet;
|
typedef vector<Distribution*> DistSet;
|
||||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
||||||
typedef unordered_map<unsigned, unsigned> Histogram;
|
typedef map<unsigned, unsigned> Histogram;
|
||||||
typedef unordered_map<unsigned, double> Times;
|
typedef map<unsigned, double> Times;
|
||||||
|
|
||||||
|
|
||||||
class BayesNet : public GraphicalModel
|
class BayesNet : public GraphicalModel
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BayesNet (void);
|
BayesNet (void) {};
|
||||||
BayesNet (const char*);
|
BayesNet (const char*);
|
||||||
~BayesNet (void);
|
~BayesNet (void);
|
||||||
|
|
||||||
BayesNode* addNode (unsigned);
|
BayesNode* addNode (unsigned);
|
||||||
BayesNode* addNode (unsigned, unsigned, int, NodeSet&, Distribution*);
|
BayesNode* addNode (unsigned, unsigned, int, BnNodeSet&,
|
||||||
BayesNode* addNode (string, Domain, NodeSet&, ParamSet&);
|
Distribution*);
|
||||||
BayesNode* getNode (unsigned) const;
|
BayesNode* addNode (string, Domain, BnNodeSet&, ParamSet&);
|
||||||
BayesNode* getNode (string) const;
|
BayesNode* getBayesNode (Vid) const;
|
||||||
|
BayesNode* getBayesNode (string) const;
|
||||||
|
Variable* getVariable (Vid) const;
|
||||||
void addDistribution (Distribution*);
|
void addDistribution (Distribution*);
|
||||||
Distribution* getDistribution (unsigned) const;
|
Distribution* getDistribution (unsigned) const;
|
||||||
const NodeSet& getNodes (void) const;
|
const BnNodeSet& getBayesNodes (void) const;
|
||||||
int getNumberOfNodes (void) const;
|
unsigned getNumberOfNodes (void) const;
|
||||||
NodeSet getRootNodes (void) const;
|
BnNodeSet getRootNodes (void) const;
|
||||||
NodeSet getLeafNodes (void) const;
|
BnNodeSet getLeafNodes (void) const;
|
||||||
VarSet getVariables (void) const;
|
VarSet getVariables (void) const;
|
||||||
BayesNet* pruneNetwork (BayesNode*) const;
|
BayesNet* getMinimalRequesiteNetwork (Vid) const;
|
||||||
BayesNet* pruneNetwork (const NodeSet& queryNodes) const;
|
BayesNet* getMinimalRequesiteNetwork (const VidSet&) const;
|
||||||
void constructGraph (BayesNet*, const vector<StateInfo*>&) const;
|
void constructGraph (BayesNet*,
|
||||||
|
const vector<StateInfo*>&) const;
|
||||||
bool isSingleConnected (void) const;
|
bool isSingleConnected (void) const;
|
||||||
static vector<DomainConf> getDomainConfigurationsOf (const NodeSet&);
|
|
||||||
static vector<string> getInstantiations (const NodeSet& nodes);
|
|
||||||
void setIndexes (void);
|
void setIndexes (void);
|
||||||
void freeDistributions (void);
|
void freeDistributions (void);
|
||||||
void printNetwork (void) const;
|
void printGraphicalModel (void) const;
|
||||||
void printNetworkToFile (const char*) const;
|
void exportToDotFormat (const char*, bool = true,
|
||||||
void exportToDotFile (const char*, bool = true,
|
CVidSet = VidSet()) const;
|
||||||
const NodeSet& = NodeSet()) const;
|
void exportToBifFormat (const char*) const;
|
||||||
void exportToBifFile (const char*) const;
|
|
||||||
|
|
||||||
static Histogram histogram_;
|
static Histogram histogram_;
|
||||||
static Times times_;
|
static Times times_;
|
||||||
@ -93,12 +91,12 @@ class BayesNet : public GraphicalModel
|
|||||||
bool containsUndirectedCycle (int, int,
|
bool containsUndirectedCycle (int, int,
|
||||||
vector<bool>&)const;
|
vector<bool>&)const;
|
||||||
vector<int> getAdjacentNodes (int) const ;
|
vector<int> getAdjacentNodes (int) const ;
|
||||||
ParamSet reorderParameters (const ParamSet&, int) const;
|
ParamSet reorderParameters (CParamSet, unsigned) const;
|
||||||
ParamSet revertParameterReorder (const ParamSet&, int) const;
|
ParamSet revertParameterReorder (CParamSet, unsigned) const;
|
||||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
void scheduleParents (const BayesNode*, Scheduling&) const;
|
||||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
||||||
|
|
||||||
NodeSet nodes_;
|
BnNodeSet nodes_;
|
||||||
DistSet dists_;
|
DistSet dists_;
|
||||||
IndexMap indexMap_;
|
IndexMap indexMap_;
|
||||||
};
|
};
|
||||||
@ -108,8 +106,8 @@ class BayesNet : public GraphicalModel
|
|||||||
inline void
|
inline void
|
||||||
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
||||||
{
|
{
|
||||||
const NodeSet& ps = n->getParents();
|
const BnNodeSet& ps = n->getParents();
|
||||||
for (NodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
||||||
sch.push (ScheduleInfo (*it, false, true));
|
sch.push (ScheduleInfo (*it, false, true));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -119,11 +117,11 @@ BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
|||||||
inline void
|
inline void
|
||||||
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
||||||
{
|
{
|
||||||
const NodeSet& cs = n->getChilds();
|
const BnNodeSet& cs = n->getChilds();
|
||||||
for (NodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
||||||
sch.push (ScheduleInfo (*it, true, false));
|
sch.push (ScheduleInfo (*it, true, false));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif //BP_BAYES_NET_H
|
||||||
|
|
||||||
|
@ -1,26 +1,21 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <cassert>
|
|
||||||
#include <cstdlib>
|
|
||||||
|
|
||||||
#include "BayesNode.h"
|
#include "BayesNode.h"
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (unsigned varId) : Variable (varId)
|
BayesNode::BayesNode (Vid vid,
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (unsigned varId,
|
|
||||||
unsigned dsize,
|
unsigned dsize,
|
||||||
int evidence,
|
int evidence,
|
||||||
const NodeSet& parents,
|
const BnNodeSet& parents,
|
||||||
Distribution* dist) : Variable(varId, dsize, evidence)
|
Distribution* dist) : Variable (vid, dsize, evidence)
|
||||||
{
|
{
|
||||||
parents_ = parents;
|
parents_ = parents;
|
||||||
dist_ = dist;
|
dist_ = dist;
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||||
parents[i]->addChild (this);
|
parents[i]->addChild (this);
|
||||||
}
|
}
|
||||||
@ -28,15 +23,15 @@ BayesNode::BayesNode (unsigned varId,
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (unsigned varId,
|
BayesNode::BayesNode (Vid vid,
|
||||||
string label,
|
string label,
|
||||||
const Domain& domain,
|
const Domain& domain,
|
||||||
const NodeSet& parents,
|
const BnNodeSet& parents,
|
||||||
Distribution* dist) : Variable(varId, domain)
|
Distribution* dist) : Variable (vid, domain,
|
||||||
|
NO_EVIDENCE, label)
|
||||||
{
|
{
|
||||||
label_ = new string (label);
|
parents_ = parents;
|
||||||
parents_ = parents;
|
dist_ = dist;
|
||||||
dist_ = dist;
|
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||||
parents[i]->addChild (this);
|
parents[i]->addChild (this);
|
||||||
}
|
}
|
||||||
@ -47,11 +42,11 @@ BayesNode::BayesNode (unsigned varId,
|
|||||||
void
|
void
|
||||||
BayesNode::setData (unsigned dsize,
|
BayesNode::setData (unsigned dsize,
|
||||||
int evidence,
|
int evidence,
|
||||||
const NodeSet& parents,
|
const BnNodeSet& parents,
|
||||||
Distribution* dist)
|
Distribution* dist)
|
||||||
{
|
{
|
||||||
setDomainSize (dsize);
|
setDomainSize (dsize);
|
||||||
evidence_ = evidence;
|
setEvidence (evidence);
|
||||||
parents_ = parents;
|
parents_ = parents;
|
||||||
dist_ = dist;
|
dist_ = dist;
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||||
@ -135,19 +130,18 @@ BayesNode::getCptEntries (void)
|
|||||||
{
|
{
|
||||||
if (dist_->entries.size() == 0) {
|
if (dist_->entries.size() == 0) {
|
||||||
unsigned rowSize = getRowSize();
|
unsigned rowSize = getRowSize();
|
||||||
unsigned nParents = parents_.size();
|
vector<DConf> confs (rowSize);
|
||||||
vector<DomainConf> confs (rowSize);
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < rowSize; i++) {
|
for (unsigned i = 0; i < rowSize; i++) {
|
||||||
confs[i].resize (nParents);
|
confs[i].resize (parents_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
int nReps = 1;
|
unsigned nReps = 1;
|
||||||
for (int i = nParents - 1; i >= 0; i--) {
|
for (int i = parents_.size() - 1; i >= 0; i--) {
|
||||||
unsigned index = 0;
|
unsigned index = 0;
|
||||||
while (index < rowSize) {
|
while (index < rowSize) {
|
||||||
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
|
for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||||
for (int r = 0; r < nReps; r++) {
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
confs[index][i] = j;
|
confs[index][i] = j;
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
@ -184,7 +178,7 @@ BayesNode::cptEntryToString (const CptEntry& entry) const
|
|||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "p(" ;
|
ss << "p(" ;
|
||||||
const DomainConf& conf = entry.getParentConfigurations();
|
const DConf& conf = entry.getDomainConfiguration();
|
||||||
int row = entry.getParameterIndex() / getRowSize();
|
int row = entry.getParameterIndex() / getRowSize();
|
||||||
ss << getDomain()[row];
|
ss << getDomain()[row];
|
||||||
if (parents_.size() > 0) {
|
if (parents_.size() > 0) {
|
||||||
@ -207,7 +201,7 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const
|
|||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "p(" ;
|
ss << "p(" ;
|
||||||
const DomainConf& conf = entry.getParentConfigurations();
|
const DConf& conf = entry.getDomainConfiguration();
|
||||||
ss << getDomain()[row];
|
ss << getDomain()[row];
|
||||||
if (parents_.size() > 0) {
|
if (parents_.size() > 0) {
|
||||||
ss << "|" ;
|
ss << "|" ;
|
||||||
@ -227,16 +221,16 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const
|
|||||||
vector<string>
|
vector<string>
|
||||||
BayesNode::getDomainHeaders (void) const
|
BayesNode::getDomainHeaders (void) const
|
||||||
{
|
{
|
||||||
int nParents = parents_.size();
|
unsigned nParents = parents_.size();
|
||||||
int rowSize = getRowSize();
|
unsigned rowSize = getRowSize();
|
||||||
int nReps = 1;
|
unsigned nReps = 1;
|
||||||
vector<string> headers (rowSize);
|
vector<string> headers (rowSize);
|
||||||
for (int i = nParents - 1; i >= 0; i--) {
|
for (int i = nParents - 1; i >= 0; i--) {
|
||||||
Domain domain = parents_[i]->getDomain();
|
Domain domain = parents_[i]->getDomain();
|
||||||
int index = 0;
|
unsigned index = 0;
|
||||||
while (index < rowSize) {
|
while (index < rowSize) {
|
||||||
for (int j = 0; j < parents_[i]->getDomainSize(); j++) {
|
for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||||
for (int r = 0; r < nReps; r++) {
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
if (headers[index] != "") {
|
if (headers[index] != "") {
|
||||||
headers[index] = domain[j] + "," + headers[index];
|
headers[index] = domain[j] + "," + headers[index];
|
||||||
} else {
|
} else {
|
||||||
@ -270,7 +264,7 @@ operator << (ostream& o, const BayesNode& node)
|
|||||||
o << endl;
|
o << endl;
|
||||||
|
|
||||||
o << "Parents: " ;
|
o << "Parents: " ;
|
||||||
const NodeSet& parents = node.getParents();
|
const BnNodeSet& parents = node.getParents();
|
||||||
if (parents.size() != 0) {
|
if (parents.size() != 0) {
|
||||||
for (unsigned int i = 0; i < parents.size() - 1; i++) {
|
for (unsigned int i = 0; i < parents.size() - 1; i++) {
|
||||||
o << parents[i]->getLabel() << ", " ;
|
o << parents[i]->getLabel() << ", " ;
|
||||||
@ -280,7 +274,7 @@ operator << (ostream& o, const BayesNode& node)
|
|||||||
o << endl;
|
o << endl;
|
||||||
|
|
||||||
o << "Childs: " ;
|
o << "Childs: " ;
|
||||||
const NodeSet& childs = node.getChilds();
|
const BnNodeSet& childs = node.getChilds();
|
||||||
if (childs.size() != 0) {
|
if (childs.size() != 0) {
|
||||||
for (unsigned int i = 0; i < childs.size() - 1; i++) {
|
for (unsigned int i = 0; i < childs.size() - 1; i++) {
|
||||||
o << childs[i]->getLabel() << ", " ;
|
o << childs[i]->getLabel() << ", " ;
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
#ifndef BP_BAYESNODE_H
|
#ifndef BP_BAYES_NODE_H
|
||||||
#define BP_BAYESNODE_H
|
#define BP_BAYES_NODE_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "Variable.h"
|
#include "Variable.h"
|
||||||
#include "CptEntry.h"
|
#include "CptEntry.h"
|
||||||
@ -16,11 +14,12 @@ using namespace std;
|
|||||||
class BayesNode : public Variable
|
class BayesNode : public Variable
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BayesNode (unsigned);
|
BayesNode (Vid vid) : Variable (vid) {}
|
||||||
BayesNode (unsigned, unsigned, int, const NodeSet&, Distribution*);
|
BayesNode (Vid, unsigned, int, const BnNodeSet&, Distribution*);
|
||||||
BayesNode (unsigned, string, const Domain&, const NodeSet&, Distribution*);
|
BayesNode (Vid, string, const Domain&, const BnNodeSet&, Distribution*);
|
||||||
|
|
||||||
void setData (unsigned, int, const NodeSet&, Distribution*);
|
void setData (unsigned, int, const BnNodeSet&,
|
||||||
|
Distribution*);
|
||||||
void addChild (BayesNode*);
|
void addChild (BayesNode*);
|
||||||
Distribution* getDistribution (void);
|
Distribution* getDistribution (void);
|
||||||
const ParamSet& getParameters (void);
|
const ParamSet& getParameters (void);
|
||||||
@ -34,11 +33,21 @@ class BayesNode : public Variable
|
|||||||
int getIndexOfParent (const BayesNode*) const;
|
int getIndexOfParent (const BayesNode*) const;
|
||||||
string cptEntryToString (const CptEntry&) const;
|
string cptEntryToString (const CptEntry&) const;
|
||||||
string cptEntryToString (int, const CptEntry&) const;
|
string cptEntryToString (int, const CptEntry&) const;
|
||||||
// inlines
|
|
||||||
const NodeSet& getParents (void) const;
|
const BnNodeSet& getParents (void) const { return parents_; }
|
||||||
const NodeSet& getChilds (void) const;
|
const BnNodeSet& getChilds (void) const { return childs_; }
|
||||||
double getProbability (int, const CptEntry& entry);
|
|
||||||
unsigned getRowSize (void) const;
|
unsigned getRowSize (void) const
|
||||||
|
{
|
||||||
|
return dist_->params.size() / getDomainSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
double getProbability (int row, const CptEntry& entry)
|
||||||
|
{
|
||||||
|
int col = entry.getParameterIndex();
|
||||||
|
int idx = (row * getRowSize()) + col;
|
||||||
|
return dist_->params[idx];
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
||||||
@ -46,46 +55,12 @@ class BayesNode : public Variable
|
|||||||
Domain getDomainHeaders (void) const;
|
Domain getDomainHeaders (void) const;
|
||||||
friend ostream& operator << (ostream&, const BayesNode&);
|
friend ostream& operator << (ostream&, const BayesNode&);
|
||||||
|
|
||||||
NodeSet parents_;
|
BnNodeSet parents_;
|
||||||
NodeSet childs_;
|
BnNodeSet childs_;
|
||||||
Distribution* dist_;
|
Distribution* dist_;
|
||||||
};
|
};
|
||||||
|
|
||||||
ostream& operator << (ostream&, const BayesNode&);
|
ostream& operator << (ostream&, const BayesNode&);
|
||||||
|
|
||||||
|
#endif //BP_BAYES_NODE_H
|
||||||
|
|
||||||
inline const NodeSet&
|
|
||||||
BayesNode::getParents (void) const
|
|
||||||
{
|
|
||||||
return parents_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline const NodeSet&
|
|
||||||
BayesNode::getChilds (void) const
|
|
||||||
{
|
|
||||||
return childs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
BayesNode::getProbability (int row, const CptEntry& entry)
|
|
||||||
{
|
|
||||||
int col = entry.getParameterIndex();
|
|
||||||
int idx = (row * getRowSize()) + col;
|
|
||||||
return dist_->params[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline unsigned
|
|
||||||
BayesNode::getRowSize (void) const
|
|
||||||
{
|
|
||||||
return dist_->params.size() / getDomainSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
198
packages/CLPBN/clpbn/bp/CountingBP.cpp
Normal file
198
packages/CLPBN/clpbn/bp/CountingBP.cpp
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
#include "CountingBP.h"
|
||||||
|
|
||||||
|
|
||||||
|
CountingBP::~CountingBP (void)
|
||||||
|
{
|
||||||
|
delete lfg_;
|
||||||
|
delete fg_;
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
links_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
CountingBP::getPosterioriOf (Vid vid) const
|
||||||
|
{
|
||||||
|
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||||
|
ParamSet probs;
|
||||||
|
|
||||||
|
if (var->hasEvidence()) {
|
||||||
|
probs.resize (var->getDomainSize(), 0.0);
|
||||||
|
probs[var->getEvidence()] = 1.0;
|
||||||
|
} else {
|
||||||
|
probs.resize (var->getDomainSize(), 1.0);
|
||||||
|
CLinkSet links = varsI_[var->getIndex()]->getLinks();
|
||||||
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
|
ParamSet msg = links[i]->getMessage();
|
||||||
|
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
|
||||||
|
Util::pow (msg, l->getNumberOfEdges());
|
||||||
|
for (unsigned j = 0; j < msg.size(); j++) {
|
||||||
|
probs[j] *= msg[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Util::normalize (probs);
|
||||||
|
}
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBP::initializeSolver (void)
|
||||||
|
{
|
||||||
|
lfg_ = new LiftedFG (*fg_);
|
||||||
|
unsigned nUncVars = fg_->getFgVarNodes().size();
|
||||||
|
unsigned nUncFactors = fg_->getFactors().size();
|
||||||
|
CFgVarSet vars = fg_->getFgVarNodes();
|
||||||
|
unsigned nNeighborLessVars = 0;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
CFactorSet factors = vars[i]->getFactors();
|
||||||
|
if (factors.size() == 1 && factors[0]->getFgVarNodes().size() == 1) {
|
||||||
|
nNeighborLessVars ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// cout << "UNCOMPRESSED FACTOR GRAPH" << endl;
|
||||||
|
// fg_->printGraphicalModel();
|
||||||
|
fg_->exportToDotFormat ("uncompress.dot");
|
||||||
|
|
||||||
|
FactorGraph *temp;
|
||||||
|
temp = fg_;
|
||||||
|
fg_ = lfg_->getCompressedFactorGraph();
|
||||||
|
unsigned nCompVars = fg_->getFgVarNodes().size();
|
||||||
|
unsigned nCompFactors = fg_->getFactors().size();
|
||||||
|
|
||||||
|
Statistics::updateCompressingStats (nUncVars,
|
||||||
|
nUncFactors,
|
||||||
|
nCompVars,
|
||||||
|
nCompFactors,
|
||||||
|
nNeighborLessVars);
|
||||||
|
|
||||||
|
cout << "COMPRESSED FACTOR GRAPH" << endl;
|
||||||
|
fg_->printGraphicalModel();
|
||||||
|
//fg_->exportToDotFormat ("compress.dot");
|
||||||
|
|
||||||
|
SPSolver::initializeSolver();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBP::createLinks (void)
|
||||||
|
{
|
||||||
|
const FactorClusterSet fcs = lfg_->getFactorClusters();
|
||||||
|
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||||
|
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
||||||
|
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||||
|
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
||||||
|
links_.push_back (
|
||||||
|
new CountingBPLink (fcs[i]->getRepresentativeFactor(),
|
||||||
|
vcs[j]->getRepresentativeVariable(), c));
|
||||||
|
//cout << (links_.back())->toString() << " edge count =" << c << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBP::deleteJunction (Factor* f, FgVarNode*)
|
||||||
|
{
|
||||||
|
f->freeDistribution();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CountingBP::maxResidualSchedule (void)
|
||||||
|
{
|
||||||
|
if (nIter_ == 1) {
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||||
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
|
if (DL >= 2 && DL < 5) {
|
||||||
|
cout << "calculating " << links_[i]->toString() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
|
if (DL >= 2) {
|
||||||
|
cout << endl << "current residuals:" << endl;
|
||||||
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
it != sortedOrder_.end(); it ++) {
|
||||||
|
cout << " " << setw (30) << left << (*it)->toString();
|
||||||
|
cout << "residual = " << (*it)->getResidual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
Link* link = *it;
|
||||||
|
if (DL >= 2) {
|
||||||
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
|
}
|
||||||
|
if (link->getResidual() < SolverOptions::accuracy) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
link->updateMessage();
|
||||||
|
link->clearResidual();
|
||||||
|
sortedOrder_.erase (it);
|
||||||
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
|
// update the messages that depend on message source --> destin
|
||||||
|
CFactorSet factorNeighbors = link->getVariable()->getFactors();
|
||||||
|
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||||
|
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->getVariable() != link->getVariable()) { //FIXMEFIXME
|
||||||
|
if (DL >= 2 && DL < 5) {
|
||||||
|
cout << " calculating " << links[j]->toString() << endl;
|
||||||
|
}
|
||||||
|
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||||
|
LinkMap::iterator iter = linkMap_.find (links[j]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
CountingBP::getVar2FactorMsg (const Link* link) const
|
||||||
|
{
|
||||||
|
const FgVarNode* src = link->getVariable();
|
||||||
|
const Factor* dest = link->getFactor();
|
||||||
|
ParamSet msg;
|
||||||
|
if (src->hasEvidence()) {
|
||||||
|
cout << "has evidence" << endl;
|
||||||
|
msg.resize (src->getDomainSize(), 0.0);
|
||||||
|
msg[src->getEvidence()] = link->getMessage()[src->getEvidence()];
|
||||||
|
cout << "-> " << link->getVariable()->getLabel() << " " << link->getFactor()->getLabel() << endl;
|
||||||
|
cout << "-> p2s " << Util::parametersToString (msg) << endl;
|
||||||
|
} else {
|
||||||
|
msg = link->getMessage();
|
||||||
|
}
|
||||||
|
const CountingBPLink* l = static_cast<const CountingBPLink*> (link);
|
||||||
|
Util::pow (msg, l->getNumberOfEdges() - 1);
|
||||||
|
CLinkSet links = varsI_[src->getIndex()]->getLinks();
|
||||||
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
|
if (links[i]->getFactor() != dest) {
|
||||||
|
ParamSet msgFromFactor = links[i]->getMessage();
|
||||||
|
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
|
||||||
|
Util::pow (msgFromFactor, l->getNumberOfEdges());
|
||||||
|
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
|
||||||
|
msg[j] *= msgFromFactor[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
|
}
|
||||||
|
|
45
packages/CLPBN/clpbn/bp/CountingBP.h
Normal file
45
packages/CLPBN/clpbn/bp/CountingBP.h
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
#ifndef BP_COUNTING_BP_H
|
||||||
|
#define BP_COUNTING_BP_H
|
||||||
|
|
||||||
|
#include "SPSolver.h"
|
||||||
|
#include "LiftedFG.h"
|
||||||
|
|
||||||
|
class Factor;
|
||||||
|
class FgVarNode;
|
||||||
|
|
||||||
|
class CountingBPLink : public Link
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CountingBPLink (Factor* f, FgVarNode* v, unsigned c) : Link (f, v)
|
||||||
|
{
|
||||||
|
edgeCount_ = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
unsigned edgeCount_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class CountingBP : public SPSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CountingBP (FactorGraph& fg) : SPSolver (fg) { }
|
||||||
|
~CountingBP (void);
|
||||||
|
|
||||||
|
ParamSet getPosterioriOf (Vid) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void initializeSolver (void);
|
||||||
|
void createLinks (void);
|
||||||
|
void deleteJunction (Factor*, FgVarNode*);
|
||||||
|
|
||||||
|
void maxResidualSchedule (void);
|
||||||
|
ParamSet getVar2FactorMsg (const Link*) const;
|
||||||
|
|
||||||
|
LiftedFG* lfg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // BP_COUNTING_BP_H
|
||||||
|
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef BP_CPTENTRY_H
|
#ifndef BP_CPT_ENTRY_H
|
||||||
#define BP_CPTENTRY_H
|
#define BP_CPT_ENTRY_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -10,62 +10,34 @@ using namespace std;
|
|||||||
class CptEntry
|
class CptEntry
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CptEntry (unsigned, const vector<unsigned>&);
|
CptEntry (unsigned index, const DConf& conf)
|
||||||
|
{
|
||||||
|
index_ = index;
|
||||||
|
conf_ = conf;
|
||||||
|
}
|
||||||
|
|
||||||
unsigned getParameterIndex (void) const;
|
unsigned getParameterIndex (void) const { return index_; }
|
||||||
const vector<unsigned>& getParentConfigurations (void) const;
|
const DConf& getDomainConfiguration (void) const { return conf_; }
|
||||||
bool matchConstraints (const DomainConstr&) const;
|
|
||||||
bool matchConstraints (const vector<DomainConstr>&) const;
|
bool matchConstraints (const DConstraint& constr) const
|
||||||
|
{
|
||||||
|
return conf_[constr.first] == constr.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool matchConstraints (const vector<DConstraint>& constrs) const
|
||||||
|
{
|
||||||
|
for (unsigned j = 0; j < constrs.size(); j++) {
|
||||||
|
if (conf_[constrs[j].first] != constrs[j].second) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned index_;
|
unsigned index_;
|
||||||
vector<unsigned> confs_;
|
DConf conf_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#endif //BP_CPT_ENTRY_H
|
||||||
|
|
||||||
|
|
||||||
inline
|
|
||||||
CptEntry::CptEntry (unsigned index, const vector<unsigned>& confs)
|
|
||||||
{
|
|
||||||
index_ = index;
|
|
||||||
confs_ = confs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline unsigned
|
|
||||||
CptEntry::getParameterIndex (void) const
|
|
||||||
{
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline const vector<unsigned>&
|
|
||||||
CptEntry::getParentConfigurations (void) const
|
|
||||||
{
|
|
||||||
return confs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline bool
|
|
||||||
CptEntry::matchConstraints (const DomainConstr& constr) const
|
|
||||||
{
|
|
||||||
return confs_[constr.first] == constr.second;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline bool
|
|
||||||
CptEntry::matchConstraints (const vector<DomainConstr>& constrs) const
|
|
||||||
{
|
|
||||||
for (unsigned j = 0; j < constrs.size(); j++) {
|
|
||||||
if (confs_[constrs[j].first] != constrs[j].second) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
#define BP_DISTRIBUTION_H
|
#define BP_DISTRIBUTION_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
|
#include "CptEntry.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -11,16 +11,18 @@ using namespace std;
|
|||||||
struct Distribution
|
struct Distribution
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Distribution (unsigned id)
|
Distribution (unsigned id, bool shared = false)
|
||||||
{
|
{
|
||||||
this->id = id;
|
this->id = id;
|
||||||
this->params = params;
|
this->params = params;
|
||||||
|
this->shared = shared;
|
||||||
}
|
}
|
||||||
|
|
||||||
Distribution (const ParamSet& params)
|
Distribution (const ParamSet& params, bool shared = false)
|
||||||
{
|
{
|
||||||
this->id = -1;
|
this->id = -1;
|
||||||
this->params = params;
|
this->params = params;
|
||||||
|
this->shared = shared;
|
||||||
}
|
}
|
||||||
|
|
||||||
void updateParameters (const ParamSet& params)
|
void updateParameters (const ParamSet& params)
|
||||||
@ -31,10 +33,11 @@ struct Distribution
|
|||||||
unsigned id;
|
unsigned id;
|
||||||
ParamSet params;
|
ParamSet params;
|
||||||
vector<CptEntry> entries;
|
vector<CptEntry> entries;
|
||||||
|
bool shared;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif //BP_DISTRIBUTION_H
|
||||||
|
|
||||||
|
@ -1,37 +1,37 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "FgVarNode.h"
|
#include "FgVarNode.h"
|
||||||
|
|
||||||
|
|
||||||
int Factor::indexCount_ = 0;
|
Factor::Factor (const Factor& g)
|
||||||
|
{
|
||||||
Factor::Factor (FgVarNode* var) {
|
copyFactor (g);
|
||||||
vs_.push_back (var);
|
|
||||||
int nParams = var->getDomainSize();
|
|
||||||
// create a uniform distribution
|
|
||||||
double val = 1.0 / nParams;
|
|
||||||
ps_ = ParamSet (nParams, val);
|
|
||||||
id_ = indexCount_;
|
|
||||||
indexCount_ ++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const FgVarSet& vars) {
|
Factor::Factor (FgVarNode* var)
|
||||||
vs_ = vars;
|
{
|
||||||
|
Factor (FgVarSet() = {var});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Factor::Factor (const FgVarSet& vars)
|
||||||
|
{
|
||||||
|
vars_ = vars;
|
||||||
int nParams = 1;
|
int nParams = 1;
|
||||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
nParams *= vs_[i]->getDomainSize();
|
nParams *= vars_[i]->getDomainSize();
|
||||||
}
|
}
|
||||||
// create a uniform distribution
|
// create a uniform distribution
|
||||||
double val = 1.0 / nParams;
|
double val = 1.0 / nParams;
|
||||||
ps_ = ParamSet (nParams, val);
|
dist_ = new Distribution (ParamSet (nParams, val));
|
||||||
id_ = indexCount_;
|
|
||||||
indexCount_ ++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -39,10 +39,17 @@ Factor::Factor (const FgVarSet& vars) {
|
|||||||
Factor::Factor (FgVarNode* var,
|
Factor::Factor (FgVarNode* var,
|
||||||
const ParamSet& params)
|
const ParamSet& params)
|
||||||
{
|
{
|
||||||
vs_.push_back (var);
|
vars_.push_back (var);
|
||||||
ps_ = params;
|
dist_ = new Distribution (params);
|
||||||
id_ = indexCount_;
|
}
|
||||||
indexCount_ ++;
|
|
||||||
|
|
||||||
|
|
||||||
|
Factor::Factor (FgVarSet& vars,
|
||||||
|
Distribution* dist)
|
||||||
|
{
|
||||||
|
vars_ = vars;
|
||||||
|
dist_ = dist;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -50,42 +57,8 @@ Factor::Factor (FgVarNode* var,
|
|||||||
Factor::Factor (const FgVarSet& vars,
|
Factor::Factor (const FgVarSet& vars,
|
||||||
const ParamSet& params)
|
const ParamSet& params)
|
||||||
{
|
{
|
||||||
vs_ = vars;
|
vars_ = vars;
|
||||||
ps_ = params;
|
dist_ = new Distribution (params);
|
||||||
id_ = indexCount_;
|
|
||||||
indexCount_ ++;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const FgVarSet&
|
|
||||||
Factor::getFgVarNodes (void) const
|
|
||||||
{
|
|
||||||
return vs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgVarSet&
|
|
||||||
Factor::getFgVarNodes (void)
|
|
||||||
{
|
|
||||||
return vs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const ParamSet&
|
|
||||||
Factor::getParameters (void) const
|
|
||||||
{
|
|
||||||
return ps_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ParamSet&
|
|
||||||
Factor::getParameters (void)
|
|
||||||
{
|
|
||||||
return ps_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -93,75 +66,95 @@ Factor::getParameters (void)
|
|||||||
void
|
void
|
||||||
Factor::setParameters (const ParamSet& params)
|
Factor::setParameters (const ParamSet& params)
|
||||||
{
|
{
|
||||||
//cout << "ps size: " << ps_.size() << endl;
|
assert (dist_->params.size() == params.size());
|
||||||
//cout << "params size: " << params.size() << endl;
|
dist_->updateParameters (params);
|
||||||
assert (ps_.size() == params.size());
|
|
||||||
ps_ = params;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor&
|
void
|
||||||
Factor::operator= (const Factor& g)
|
Factor::copyFactor (const Factor& g)
|
||||||
{
|
{
|
||||||
FgVarSet vars = g.getFgVarNodes();
|
vars_ = g.getFgVarNodes();
|
||||||
ParamSet params = g.getParameters();
|
dist_ = new Distribution (g.getDistribution()->params);
|
||||||
return *this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor&
|
void
|
||||||
Factor::operator*= (const Factor& g)
|
Factor::multiplyByFactor (const Factor& g, const vector<CptEntry>* entries)
|
||||||
{
|
{
|
||||||
FgVarSet gVs = g.getFgVarNodes();
|
if (vars_.size() == 0) {
|
||||||
|
copyFactor (g);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const FgVarSet& gVs = g.getFgVarNodes();
|
||||||
const ParamSet& gPs = g.getParameters();
|
const ParamSet& gPs = g.getParameters();
|
||||||
|
|
||||||
bool hasCommonVars = false;
|
bool factorsAreEqual = true;
|
||||||
vector<int> varIndexes;
|
if (gVs.size() == vars_.size()) {
|
||||||
for (unsigned i = 0; i < gVs.size(); i++) {
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
int idx = getIndexOf (gVs[i]);
|
if (gVs[i] != vars_[i]) {
|
||||||
if (idx == -1) {
|
factorsAreEqual = false;
|
||||||
insertVariable (gVs[i]);
|
break;
|
||||||
varIndexes.push_back (vs_.size() - 1);
|
|
||||||
} else {
|
|
||||||
hasCommonVars = true;
|
|
||||||
varIndexes.push_back (idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (hasCommonVars) {
|
|
||||||
vector<int> offsets (gVs.size());
|
|
||||||
offsets[gVs.size() - 1] = 1;
|
|
||||||
for (int i = gVs.size() - 2; i >= 0; i--) {
|
|
||||||
offsets[i] = offsets[i + 1] * gVs[i + 1]->getDomainSize();
|
|
||||||
}
|
|
||||||
vector<CptEntry> entries = getCptEntries();
|
|
||||||
for (unsigned i = 0; i < entries.size(); i++) {
|
|
||||||
int idx = 0;
|
|
||||||
const DomainConf conf = entries[i].getParentConfigurations();
|
|
||||||
for (unsigned j = 0; j < varIndexes.size(); j++) {
|
|
||||||
idx += offsets[j] * conf[varIndexes[j]];
|
|
||||||
}
|
}
|
||||||
//cout << "ps_[" << i << "] = " << ps_[i] << " * " ;
|
|
||||||
//cout << gPs[idx] << " , idx = " << idx << endl;
|
|
||||||
ps_[i] = ps_[i] * gPs[idx];
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// if the originally factors doesn't have common factors.
|
factorsAreEqual = false;
|
||||||
// we don't have to make domain comparations
|
}
|
||||||
unsigned idx = 0;
|
|
||||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
if (factorsAreEqual) {
|
||||||
//cout << "ps_[" << i << "] = " << ps_[i] << " * " ;
|
// optimization: if the factors contain the same set of variables,
|
||||||
//cout << gPs[idx] << " , idx = " << idx << endl;
|
// we can do 1 to 1 operations on the parameteres
|
||||||
ps_[i] = ps_[i] * gPs[idx];
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
idx ++;
|
dist_->params[i] *= gPs[i];
|
||||||
if (idx >= gPs.size()) {
|
}
|
||||||
idx = 0;
|
} else {
|
||||||
|
bool hasCommonVars = false;
|
||||||
|
vector<unsigned> gVsIndexes;
|
||||||
|
for (unsigned i = 0; i < gVs.size(); i++) {
|
||||||
|
int idx = getIndexOf (gVs[i]);
|
||||||
|
if (idx == -1) {
|
||||||
|
insertVariable (gVs[i]);
|
||||||
|
gVsIndexes.push_back (vars_.size() - 1);
|
||||||
|
} else {
|
||||||
|
hasCommonVars = true;
|
||||||
|
gVsIndexes.push_back (idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (hasCommonVars) {
|
||||||
|
vector<unsigned> gVsOffsets (gVs.size());
|
||||||
|
gVsOffsets[gVs.size() - 1] = 1;
|
||||||
|
for (int i = gVs.size() - 2; i >= 0; i--) {
|
||||||
|
gVsOffsets[i] = gVsOffsets[i + 1] * gVs[i + 1]->getDomainSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (entries == 0) {
|
||||||
|
entries = &getCptEntries();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < entries->size(); i++) {
|
||||||
|
unsigned idx = 0;
|
||||||
|
const DConf& conf = (*entries)[i].getDomainConfiguration();
|
||||||
|
for (unsigned j = 0; j < gVsIndexes.size(); j++) {
|
||||||
|
idx += gVsOffsets[j] * conf[ gVsIndexes[j] ];
|
||||||
|
}
|
||||||
|
dist_->params[i] = dist_->params[i] * gPs[idx];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// optimization: if the original factors doesn't have common variables,
|
||||||
|
// we don't need to marry the states of the common variables
|
||||||
|
unsigned count = 0;
|
||||||
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
|
dist_->params[i] *= gPs[count];
|
||||||
|
count ++;
|
||||||
|
if (count >= gPs.size()) {
|
||||||
|
count = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return *this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -169,81 +162,109 @@ Factor::operator*= (const Factor& g)
|
|||||||
void
|
void
|
||||||
Factor::insertVariable (FgVarNode* var)
|
Factor::insertVariable (FgVarNode* var)
|
||||||
{
|
{
|
||||||
int c = 0;
|
assert (getIndexOf (var) == -1);
|
||||||
ParamSet newPs (ps_.size() * var->getDomainSize());
|
ParamSet newPs;
|
||||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
newPs.reserve (dist_->params.size() * var->getDomainSize());
|
||||||
for (int j = 0; j < var->getDomainSize(); j++) {
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
newPs[c] = ps_[i];
|
for (unsigned j = 0; j < var->getDomainSize(); j++) {
|
||||||
c ++;
|
newPs.push_back (dist_->params[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
vs_.push_back (var);
|
vars_.push_back (var);
|
||||||
ps_ = newPs;
|
dist_->updateParameters (newPs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::marginalizeVariable (const FgVarNode* var) {
|
Factor::removeVariable (const FgVarNode* var)
|
||||||
int varIndex = getIndexOf (var);
|
|
||||||
marginalizeVariable (varIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::marginalizeVariable (unsigned varIndex)
|
|
||||||
{
|
{
|
||||||
assert (varIndex >= 0 && varIndex < vs_.size());
|
int varIndex = getIndexOf (var);
|
||||||
int distOffset = 1;
|
assert (varIndex >= 0 && varIndex < (int)vars_.size());
|
||||||
int leftVarOffset = 1;
|
|
||||||
for (unsigned i = vs_.size() - 1; i > varIndex; i--) {
|
// number of parameters separating a different state of `var',
|
||||||
distOffset *= vs_[i]->getDomainSize();
|
// with the states of the remaining variables fixed
|
||||||
leftVarOffset *= vs_[i]->getDomainSize();
|
unsigned varOffset = 1;
|
||||||
}
|
|
||||||
leftVarOffset *= vs_[varIndex]->getDomainSize();
|
// number of parameters separating a different state of the variable
|
||||||
|
// on the left of `var', with the states of the remaining vars fixed
|
||||||
|
unsigned leftVarOffset = 1;
|
||||||
|
|
||||||
|
for (int i = vars_.size() - 1; i > varIndex; i--) {
|
||||||
|
varOffset *= vars_[i]->getDomainSize();
|
||||||
|
leftVarOffset *= vars_[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
leftVarOffset *= vars_[varIndex]->getDomainSize();
|
||||||
|
|
||||||
|
unsigned offset = 0;
|
||||||
|
unsigned count1 = 0;
|
||||||
|
unsigned count2 = 0;
|
||||||
|
unsigned newPsSize = dist_->params.size() / vars_[varIndex]->getDomainSize();
|
||||||
|
|
||||||
int ds = vs_[varIndex]->getDomainSize();
|
|
||||||
int count = 0;
|
|
||||||
int offset = 0;
|
|
||||||
int startIndex = 0;
|
|
||||||
int currDomainIdx = 0;
|
|
||||||
unsigned newPsSize = ps_.size() / ds;
|
|
||||||
ParamSet newPs;
|
ParamSet newPs;
|
||||||
newPs.reserve (newPsSize);
|
newPs.reserve (newPsSize);
|
||||||
|
|
||||||
stringstream ss;
|
// stringstream ss;
|
||||||
ss << "marginalizing " << vs_[varIndex]->getLabel();
|
// ss << "marginalizing " << vars_[varIndex]->getLabel();
|
||||||
ss << " from factor " << getLabel() << endl;
|
// ss << " from factor " << getLabel() << endl;
|
||||||
while (newPs.size() < newPsSize) {
|
while (newPs.size() < newPsSize) {
|
||||||
ss << " sum = ";
|
// ss << " sum = ";
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
for (int j = 0; j < ds; j++) {
|
for (unsigned i = 0; i < vars_[varIndex]->getDomainSize(); i++) {
|
||||||
if (j != 0) ss << " + ";
|
// if (i != 0) ss << " + ";
|
||||||
ss << ps_[offset];
|
// ss << dist_->params[offset];
|
||||||
sum = sum + ps_[offset];
|
sum += dist_->params[offset];
|
||||||
offset = offset + distOffset;
|
offset += varOffset;
|
||||||
}
|
}
|
||||||
newPs.push_back (sum);
|
newPs.push_back (sum);
|
||||||
count ++;
|
count1 ++;
|
||||||
if (varIndex == vs_.size() - 1) {
|
if (varIndex == (int)vars_.size() - 1) {
|
||||||
offset = count * ds;
|
offset = count1 * vars_[varIndex]->getDomainSize();
|
||||||
} else {
|
} else {
|
||||||
offset = offset - distOffset + 1;
|
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||||
if ((offset % leftVarOffset) == 0) {
|
count1 = 0;
|
||||||
currDomainIdx ++;
|
count2 ++;
|
||||||
startIndex = leftVarOffset * currDomainIdx;
|
|
||||||
offset = startIndex;
|
|
||||||
count = 0;
|
|
||||||
} else {
|
|
||||||
offset = startIndex + count;
|
|
||||||
}
|
}
|
||||||
|
offset = (leftVarOffset * count2) + count1;
|
||||||
}
|
}
|
||||||
ss << " = " << sum << endl;
|
// ss << " = " << sum << endl;
|
||||||
}
|
}
|
||||||
//cout << ss.str() << endl;
|
// cout << ss.str() << endl;
|
||||||
ps_ = newPs;
|
vars_.erase (vars_.begin() + varIndex);
|
||||||
vs_.erase (vs_.begin() + varIndex);
|
dist_->updateParameters (newPs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
const vector<CptEntry>&
|
||||||
|
Factor::getCptEntries (void) const
|
||||||
|
{
|
||||||
|
if (dist_->entries.size() == 0) {
|
||||||
|
vector<DConf> confs (dist_->params.size());
|
||||||
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
|
confs[i].resize (vars_.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned nReps = 1;
|
||||||
|
for (int i = vars_.size() - 1; i >= 0; i--) {
|
||||||
|
unsigned index = 0;
|
||||||
|
while (index < dist_->params.size()) {
|
||||||
|
for (unsigned j = 0; j < vars_[i]->getDomainSize(); j++) {
|
||||||
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
|
confs[index][i] = j;
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nReps *= vars_[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
dist_->entries.clear();
|
||||||
|
dist_->entries.reserve (dist_->params.size());
|
||||||
|
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||||
|
dist_->entries.push_back (CptEntry (i, confs[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dist_->entries;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -252,11 +273,10 @@ string
|
|||||||
Factor::getLabel (void) const
|
Factor::getLabel (void) const
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "f(" ;
|
ss << "Φ(" ;
|
||||||
// ss << "Φ(" ;
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
if (i != 0) ss << "," ;
|
||||||
if (i != 0) ss << ", " ;
|
ss << vars_[i]->getLabel();
|
||||||
ss << "v" << vs_[i]->getVarId();
|
|
||||||
}
|
}
|
||||||
ss << ")" ;
|
ss << ")" ;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
@ -264,62 +284,24 @@ Factor::getLabel (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
void
|
||||||
Factor::toString (void) const
|
Factor::printFactor (void)
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "vars: " ;
|
ss << getLabel() << endl;
|
||||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
ss << "--------------------" << endl;
|
||||||
if (i != 0) ss << ", " ;
|
VarSet vs;
|
||||||
ss << "v" << vs_[i]->getVarId();
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
|
vs.push_back (vars_[i]);
|
||||||
}
|
}
|
||||||
ss << endl;
|
vector<string> domainConfs = Util::getInstantiations (vs);
|
||||||
vector<CptEntry> entries = getCptEntries();
|
const vector<CptEntry>& entries = getCptEntries();
|
||||||
for (unsigned i = 0; i < entries.size(); i++) {
|
for (unsigned i = 0; i < entries.size(); i++) {
|
||||||
ss << "Φ(" ;
|
ss << "Φ(" << domainConfs[i] << ")" ;
|
||||||
char s = 'a' ;
|
unsigned idx = entries[i].getParameterIndex();
|
||||||
const DomainConf& conf = entries[i].getParentConfigurations();
|
ss << " = " << dist_->params[idx] << endl;
|
||||||
for (unsigned j = 0; j < conf.size(); j++) {
|
|
||||||
if (j != 0) ss << "," ;
|
|
||||||
ss << s << conf[j] + 1;
|
|
||||||
s++;
|
|
||||||
}
|
|
||||||
ss << ") = " << ps_[entries[i].getParameterIndex()] << endl;
|
|
||||||
}
|
}
|
||||||
return ss.str();
|
cout << ss.str();
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<CptEntry>
|
|
||||||
Factor::getCptEntries (void) const
|
|
||||||
{
|
|
||||||
vector<DomainConf> confs (ps_.size());
|
|
||||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
|
||||||
confs[i].resize (vs_.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
int nReps = 1;
|
|
||||||
for (int i = vs_.size() - 1; i >= 0; i--) {
|
|
||||||
unsigned index = 0;
|
|
||||||
while (index < ps_.size()) {
|
|
||||||
for (int j = 0; j < vs_[i]->getDomainSize(); j++) {
|
|
||||||
for (int r = 0; r < nReps; r++) {
|
|
||||||
confs[index][i] = j;
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nReps *= vs_[i]->getDomainSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<CptEntry> entries;
|
|
||||||
for (unsigned i = 0; i < ps_.size(); i++) {
|
|
||||||
for (unsigned j = 0; j < vs_.size(); j++) {
|
|
||||||
}
|
|
||||||
entries.push_back (CptEntry (i, confs[i]));
|
|
||||||
}
|
|
||||||
return entries;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -327,20 +309,11 @@ Factor::getCptEntries (void) const
|
|||||||
int
|
int
|
||||||
Factor::getIndexOf (const FgVarNode* var) const
|
Factor::getIndexOf (const FgVarNode* var) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < vs_.size(); i++) {
|
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||||
if (vs_[i] == var) {
|
if (vars_[i] == var) {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor operator* (const Factor& f, const Factor& g)
|
|
||||||
{
|
|
||||||
Factor r = f;
|
|
||||||
r *= g;
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -3,43 +3,46 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "Distribution.h"
|
||||||
#include "CptEntry.h"
|
#include "CptEntry.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class FgVarNode;
|
class FgVarNode;
|
||||||
|
class Distribution;
|
||||||
|
|
||||||
class Factor
|
class Factor
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
Factor (void) { }
|
||||||
|
Factor (const Factor&);
|
||||||
Factor (FgVarNode*);
|
Factor (FgVarNode*);
|
||||||
Factor (const FgVarSet&);
|
Factor (CFgVarSet);
|
||||||
Factor (FgVarNode*, const ParamSet&);
|
Factor (FgVarNode*, const ParamSet&);
|
||||||
Factor (const FgVarSet&, const ParamSet&);
|
Factor (FgVarSet&, Distribution*);
|
||||||
|
Factor (CFgVarSet, CParamSet);
|
||||||
|
|
||||||
const FgVarSet& getFgVarNodes (void) const;
|
void setParameters (CParamSet);
|
||||||
FgVarSet& getFgVarNodes (void);
|
void copyFactor (const Factor& f);
|
||||||
const ParamSet& getParameters (void) const;
|
void multiplyByFactor (const Factor& f, const vector<CptEntry>* = 0);
|
||||||
ParamSet& getParameters (void);
|
void insertVariable (FgVarNode* index);
|
||||||
void setParameters (const ParamSet&);
|
void removeVariable (const FgVarNode* var);
|
||||||
Factor& operator= (const Factor& f);
|
const vector<CptEntry>& getCptEntries (void) const;
|
||||||
Factor& operator*= (const Factor& f);
|
string getLabel (void) const;
|
||||||
void insertVariable (FgVarNode* index);
|
void printFactor (void);
|
||||||
void marginalizeVariable (const FgVarNode* var);
|
|
||||||
void marginalizeVariable (unsigned);
|
CFgVarSet getFgVarNodes (void) const { return vars_; }
|
||||||
string getLabel (void) const;
|
CParamSet getParameters (void) const { return dist_->params; }
|
||||||
string toString (void) const;
|
Distribution* getDistribution (void) const { return dist_; }
|
||||||
|
unsigned getIndex (void) const { return index_; }
|
||||||
|
void setIndex (unsigned index) { index_ = index; }
|
||||||
|
void freeDistribution (void) { delete dist_; dist_ = 0;}
|
||||||
|
int getIndexOf (const FgVarNode*) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<CptEntry> getCptEntries() const;
|
FgVarSet vars_;
|
||||||
int getIndexOf (const FgVarNode*) const;
|
Distribution* dist_;
|
||||||
|
unsigned index_;
|
||||||
FgVarSet vs_;
|
|
||||||
ParamSet ps_;
|
|
||||||
int id_;
|
|
||||||
static int indexCount_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Factor operator* (const Factor&, const Factor&);
|
#endif //BP_FACTOR_H
|
||||||
|
|
||||||
#endif
|
|
||||||
|
@ -1,23 +1,26 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
|
||||||
#include <cstdlib>
|
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "FgVarNode.h"
|
#include "FgVarNode.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
|
#include "BayesNet.h"
|
||||||
|
|
||||||
|
|
||||||
FactorGraph::FactorGraph (const char* fileName)
|
FactorGraph::FactorGraph (const char* fileName)
|
||||||
{
|
{
|
||||||
string line;
|
|
||||||
ifstream is (fileName);
|
ifstream is (fileName);
|
||||||
if (!is.is_open()) {
|
if (!is.is_open()) {
|
||||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string line;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||||
getline (is, line);
|
getline (is, line);
|
||||||
if (line != "MARKOV") {
|
if (line != "MARKOV") {
|
||||||
@ -39,7 +42,7 @@ FactorGraph::FactorGraph (const char* fileName)
|
|||||||
|
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||||
for (int i = 0; i < nVars; i++) {
|
for (int i = 0; i < nVars; i++) {
|
||||||
varNodes_.push_back (new FgVarNode (i, domainSizes[i]));
|
addVariable (new FgVarNode (i, domainSizes[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
int nFactors;
|
int nFactors;
|
||||||
@ -50,11 +53,11 @@ FactorGraph::FactorGraph (const char* fileName)
|
|||||||
is >> nFactorVars;
|
is >> nFactorVars;
|
||||||
FgVarSet factorVars;
|
FgVarSet factorVars;
|
||||||
for (int j = 0; j < nFactorVars; j++) {
|
for (int j = 0; j < nFactorVars; j++) {
|
||||||
int varId;
|
int vid;
|
||||||
is >> varId;
|
is >> vid;
|
||||||
FgVarNode* var = getVariableById (varId);
|
FgVarNode* var = getFgVarNode (vid);
|
||||||
if (var == 0) {
|
if (!var) {
|
||||||
cerr << "error: invalid variable identifier (" << varId << ")" << endl;
|
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
factorVars.push_back (var);
|
factorVars.push_back (var);
|
||||||
@ -87,6 +90,33 @@ FactorGraph::FactorGraph (const char* fileName)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph::FactorGraph (const BayesNet& bn)
|
||||||
|
{
|
||||||
|
const BnNodeSet& nodes = bn.getBayesNodes();
|
||||||
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
|
FgVarNode* varNode = new FgVarNode (nodes[i]);
|
||||||
|
varNode->setIndex (i);
|
||||||
|
addVariable (varNode);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
|
const BnNodeSet& parents = nodes[i]->getParents();
|
||||||
|
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
|
||||||
|
FgVarSet factorVars = { varNodes_[nodes[i]->getIndex()] };
|
||||||
|
for (unsigned j = 0; j < parents.size(); j++) {
|
||||||
|
factorVars.push_back (varNodes_[parents[j]->getIndex()]);
|
||||||
|
}
|
||||||
|
Factor* f = new Factor (factorVars, nodes[i]->getDistribution());
|
||||||
|
factors_.push_back (f);
|
||||||
|
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||||
|
factorVars[j]->addFactor (f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph::~FactorGraph (void)
|
FactorGraph::~FactorGraph (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
@ -99,18 +129,67 @@ FactorGraph::~FactorGraph (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgVarSet
|
void
|
||||||
FactorGraph::getFgVarNodes (void) const
|
FactorGraph::addVariable (FgVarNode* varNode)
|
||||||
{
|
{
|
||||||
return varNodes_;
|
varNodes_.push_back (varNode);
|
||||||
|
varNode->setIndex (varNodes_.size() - 1);
|
||||||
|
indexMap_.insert (make_pair (varNode->getVarId(), varNodes_.size() - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<Factor*>
|
void
|
||||||
FactorGraph::getFactors (void) const
|
FactorGraph::removeVariable (const FgVarNode* var)
|
||||||
{
|
{
|
||||||
return factors_;
|
if (varNodes_[varNodes_.size() - 1] == var) {
|
||||||
|
varNodes_.pop_back();
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
|
if (varNodes_[i] == var) {
|
||||||
|
varNodes_.erase (varNodes_.begin() + i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
|
indexMap_.erase (indexMap_.find (var->getVarId()));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addFactor (Factor* f)
|
||||||
|
{
|
||||||
|
factors_.push_back (f);
|
||||||
|
const FgVarSet& factorVars = f->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||||
|
factorVars[i]->addFactor (f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::removeFactor (const Factor* f)
|
||||||
|
{
|
||||||
|
const FgVarSet& factorVars = f->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||||
|
if (factorVars[i]) {
|
||||||
|
factorVars[i]->removeFactor (f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (factors_[factors_.size() - 1] == f) {
|
||||||
|
factors_.pop_back();
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
if (factors_[i] == f) {
|
||||||
|
factors_.erase (factors_.begin() + i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -127,47 +206,142 @@ FactorGraph::getVariables (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgVarNode*
|
Variable*
|
||||||
FactorGraph::getVariableById (unsigned id) const
|
FactorGraph::getVariable (Vid vid) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
return getFgVarNode (vid);
|
||||||
if (varNodes_[i]->getVarId() == id) {
|
|
||||||
return varNodes_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgVarNode*
|
|
||||||
FactorGraph::getVariableByLabel (string label) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << "v" << varNodes_[i]->getVarId();
|
|
||||||
if (ss.str() == label) {
|
|
||||||
return varNodes_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::printFactorGraph (void) const
|
FactorGraph::setIndexes (void)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
|
varNodes_[i]->setIndex (i);
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
factors_[i]->setIndex (i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::freeDistributions (void)
|
||||||
|
{
|
||||||
|
set<Distribution*> dists;
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
dists.insert (factors_[i]->getDistribution());
|
||||||
|
}
|
||||||
|
for (set<Distribution*>::iterator it = dists.begin();
|
||||||
|
it != dists.end(); it++) {
|
||||||
|
delete *it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::printGraphicalModel (void) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
cout << "variable number " << varNodes_[i]->getIndex() << endl;
|
cout << "variable number " << varNodes_[i]->getIndex() << endl;
|
||||||
cout << "Id = " << varNodes_[i]->getVarId() << endl;
|
cout << "Id = " << varNodes_[i]->getVarId() << endl;
|
||||||
|
cout << "Label = " << varNodes_[i]->getLabel() << endl;
|
||||||
cout << "Domain size = " << varNodes_[i]->getDomainSize() << endl;
|
cout << "Domain size = " << varNodes_[i]->getDomainSize() << endl;
|
||||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||||
cout << endl;
|
cout << "Factors = " ;
|
||||||
|
for (unsigned j = 0; j < varNodes_[i]->getFactors().size(); j++) {
|
||||||
|
cout << varNodes_[i]->getFactors()[j]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
cout << endl << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
|
||||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
cout << factors_[i]->toString() << endl;
|
factors_[i]->printFactor();
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::exportToDotFormat (const char* fileName) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "error: cannot open file to write at " ;
|
||||||
|
cerr << "FactorGraph::exportToDotFile()" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "graph \"" << fileName << "\" {" << endl;
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
|
if (varNodes_[i]->hasEvidence()) {
|
||||||
|
out << '"' << varNodes_[i]->getLabel() << '"' ;
|
||||||
|
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
out << '"' << factors_[i]->getLabel() << '"' ;
|
||||||
|
out << " [label=\"" << factors_[i]->getLabel() << "\\n(";
|
||||||
|
out << factors_[i]->getDistribution()->id << ")" << "\"" ;
|
||||||
|
out << ", shape=box]" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
CFgVarSet myVars = factors_[i]->getFgVarNodes();
|
||||||
|
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||||
|
out << '"' << factors_[i]->getLabel() << '"' ;
|
||||||
|
out << " -- " ;
|
||||||
|
out << '"' << myVars[j]->getLabel() << '"' << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "}" << endl;
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "error: cannot open file to write at " ;
|
||||||
|
cerr << "FactorGraph::exportToUaiFormat()" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "MARKOV" << endl;
|
||||||
|
out << varNodes_.size() << endl;
|
||||||
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
|
out << varNodes_[i]->getDomainSize() << " " ;
|
||||||
|
}
|
||||||
|
out << endl;
|
||||||
|
|
||||||
|
out << factors_.size() << endl;
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
CFgVarSet factorVars = factors_[i]->getFgVarNodes();
|
||||||
|
out << factorVars.size();
|
||||||
|
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||||
|
out << " " << factorVars[j]->getIndex();
|
||||||
|
}
|
||||||
|
out << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
CParamSet params = factors_[i]->getParameters();
|
||||||
|
out << endl << params.size() << endl << " " ;
|
||||||
|
for (unsigned j = 0; j < params.size(); j++) {
|
||||||
|
out << params[j] << " " ;
|
||||||
|
}
|
||||||
|
out << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
#ifndef BP_FACTORGRAPH_H
|
#ifndef BP_FACTOR_GRAPH_H
|
||||||
#define BP_FACTORGRAPH_H
|
#define BP_FACTOR_GRAPH_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "GraphicalModel.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
@ -11,25 +10,48 @@ using namespace std;
|
|||||||
|
|
||||||
class FgVarNode;
|
class FgVarNode;
|
||||||
class Factor;
|
class Factor;
|
||||||
|
class BayesNet;
|
||||||
|
|
||||||
class FactorGraph : public GraphicalModel
|
class FactorGraph : public GraphicalModel
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FactorGraph (const char* fileName);
|
FactorGraph (void) {};
|
||||||
|
FactorGraph (const char*);
|
||||||
|
FactorGraph (const BayesNet&);
|
||||||
~FactorGraph (void);
|
~FactorGraph (void);
|
||||||
|
|
||||||
FgVarSet getFgVarNodes (void) const;
|
void addVariable (FgVarNode*);
|
||||||
vector<Factor*> getFactors (void) const;
|
void removeVariable (const FgVarNode*);
|
||||||
|
void addFactor (Factor*);
|
||||||
|
void removeFactor (const Factor*);
|
||||||
VarSet getVariables (void) const;
|
VarSet getVariables (void) const;
|
||||||
FgVarNode* getVariableById (unsigned) const;
|
Variable* getVariable (unsigned) const;
|
||||||
FgVarNode* getVariableByLabel (string) const;
|
void setIndexes (void);
|
||||||
void printFactorGraph (void) const;
|
void freeDistributions (void);
|
||||||
|
void printGraphicalModel (void) const;
|
||||||
|
void exportToDotFormat (const char*) const;
|
||||||
|
void exportToUaiFormat (const char*) const;
|
||||||
|
|
||||||
|
const FgVarSet& getFgVarNodes (void) const { return varNodes_; }
|
||||||
|
const FactorSet& getFactors (void) const { return factors_; }
|
||||||
|
|
||||||
|
FgVarNode* getFgVarNode (Vid vid) const
|
||||||
|
{
|
||||||
|
IndexMap::const_iterator it = indexMap_.find (vid);
|
||||||
|
if (it == indexMap_.end()) {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
return varNodes_[it->second];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||||
|
|
||||||
FgVarSet varNodes_;
|
FgVarSet varNodes_;
|
||||||
vector<Factor*> factors_;
|
FactorSet factors_;
|
||||||
|
IndexMap indexMap_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif // BP_FACTOR_GRAPH_H
|
||||||
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
#ifndef BP_VARIABLE_H
|
#ifndef BP_FG_VAR_NODE_H
|
||||||
#define BP_VARIABLE_H
|
#define BP_FG_VAR_NODE_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "Variable.h"
|
#include "Variable.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
@ -14,15 +13,31 @@ class Factor;
|
|||||||
class FgVarNode : public Variable
|
class FgVarNode : public Variable
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgVarNode (int varId, int dsize) : Variable (varId, dsize) { }
|
FgVarNode (unsigned vid, unsigned dsize) : Variable (vid, dsize) { }
|
||||||
|
FgVarNode (const Variable* v) : Variable (v) { }
|
||||||
|
|
||||||
void addFactor (Factor* f) { factors_.push_back (f); }
|
void addFactor (Factor* f) { factors_.push_back (f); }
|
||||||
vector<Factor*> getFactors (void) const { return factors_; }
|
CFactorSet getFactors (void) const { return factors_; }
|
||||||
|
|
||||||
|
void removeFactor (const Factor* f)
|
||||||
|
{
|
||||||
|
if (factors_[factors_.size() -1] == f) {
|
||||||
|
factors_.pop_back();
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||||
|
if (factors_[i] == f) {
|
||||||
|
factors_.erase (factors_.begin() + i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
||||||
// members
|
// members
|
||||||
vector<Factor*> factors_;
|
FactorSet factors_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // BP_VARIABLE_H
|
#endif // BP_FG_VAR_NODE_H
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef BP_GRAPHICALMODEL_H
|
#ifndef BP_GRAPHICAL_MODEL_H
|
||||||
#define BP_GRAPHICALMODEL_H
|
#define BP_GRAPHICAL_MODEL_H
|
||||||
|
|
||||||
#include "Variable.h"
|
#include "Variable.h"
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
@ -9,9 +9,10 @@ using namespace std;
|
|||||||
class GraphicalModel
|
class GraphicalModel
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
virtual VarSet getVariables (void) const = 0;
|
virtual ~GraphicalModel (void) {};
|
||||||
|
virtual Variable* getVariable (Vid) const = 0;
|
||||||
private:
|
virtual VarSet getVariables (void) const = 0;
|
||||||
|
virtual void printGraphicalModel (void) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif // BP_GRAPHICAL_MODEL_H
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "BPSolver.h"
|
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "SPSolver.h"
|
#include "SPSolver.h"
|
||||||
|
#include "BPSolver.h"
|
||||||
|
#include "CountingBP.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
void BayesianNetwork (int, const char* []);
|
void BayesianNetwork (int, const char* []);
|
||||||
void markovNetwork (int, const char* []);
|
void markovNetwork (int, const char* []);
|
||||||
|
void runSolver (Solver*, const VarSet&);
|
||||||
|
|
||||||
const string USAGE = "usage: \
|
const string USAGE = "usage: \
|
||||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||||
@ -19,14 +21,40 @@ const string USAGE = "usage: \
|
|||||||
|
|
||||||
int
|
int
|
||||||
main (int argc, const char* argv[])
|
main (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
|
/*
|
||||||
|
FactorGraph fg;
|
||||||
|
FgVarNode* varNode1 = new FgVarNode (0, 2);
|
||||||
|
FgVarNode* varNode2 = new FgVarNode (1, 2);
|
||||||
|
FgVarNode* varNode3 = new FgVarNode (2, 2);
|
||||||
|
fg.addVariable (varNode1);
|
||||||
|
fg.addVariable (varNode2);
|
||||||
|
fg.addVariable (varNode3);
|
||||||
|
Distribution* dist = new Distribution (ParamSet() = {1.2, 1.4, 2.0, 0.4});
|
||||||
|
fg.addFactor (new Factor (FgVarSet() = {varNode1, varNode2}, dist));
|
||||||
|
fg.addFactor (new Factor (FgVarSet() = {varNode3, varNode2}, dist));
|
||||||
|
//fg.printGraphicalModel();
|
||||||
|
//SPSolver sp (fg);
|
||||||
|
//sp.runSolver();
|
||||||
|
//sp.printAllPosterioris();
|
||||||
|
//ParamSet p = sp.getJointDistributionOf (VidSet() = {0, 1, 2});
|
||||||
|
//cout << Util::parametersToString (p) << endl;
|
||||||
|
CountingBP cbp (fg);
|
||||||
|
//cbp.runSolver();
|
||||||
|
//cbp.printAllPosterioris();
|
||||||
|
ParamSet p2 = cbp.getJointDistributionOf (VidSet() = {0, 1, 2});
|
||||||
|
cout << Util::parametersToString (p2) << endl;
|
||||||
|
fg.freeDistributions();
|
||||||
|
Statistics::printCompressingStats ("compressing.stats");
|
||||||
|
return 0;
|
||||||
|
*/
|
||||||
if (!argv[1]) {
|
if (!argv[1]) {
|
||||||
cerr << "error: no graphical model specified" << endl;
|
cerr << "error: no graphical model specified" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
string fileName = argv[1];
|
const string& fileName = argv[1];
|
||||||
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||||
if (extension == "xml") {
|
if (extension == "xml") {
|
||||||
BayesianNetwork (argc, argv);
|
BayesianNetwork (argc, argv);
|
||||||
} else if (extension == "uai") {
|
} else if (extension == "uai") {
|
||||||
@ -45,13 +73,13 @@ void
|
|||||||
BayesianNetwork (int argc, const char* argv[])
|
BayesianNetwork (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
BayesNet bn (argv[1]);
|
BayesNet bn (argv[1]);
|
||||||
//bn.printNetwork();
|
//bn.printGraphicalModel();
|
||||||
|
|
||||||
NodeSet queryVars;
|
VarSet queryVars;
|
||||||
for (int i = 2; i < argc; i++) {
|
for (int i = 2; i < argc; i++) {
|
||||||
string arg = argv[i];
|
const string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
BayesNode* queryVar = bn.getNode (arg);
|
BayesNode* queryVar = bn.getBayesNode (arg);
|
||||||
if (queryVar) {
|
if (queryVar) {
|
||||||
queryVars.push_back (queryVar);
|
queryVars.push_back (queryVar);
|
||||||
} else {
|
} else {
|
||||||
@ -61,9 +89,9 @@ BayesianNetwork (int argc, const char* argv[])
|
|||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
size_t pos = arg.find ('=');
|
size_t pos = arg.find ('=');
|
||||||
string label = arg.substr (0, pos);
|
const string& label = arg.substr (0, pos);
|
||||||
string state = arg.substr (pos + 1);
|
const string& state = arg.substr (pos + 1);
|
||||||
if (label.empty()) {
|
if (label.empty()) {
|
||||||
cerr << "error: missing left argument" << endl;
|
cerr << "error: missing left argument" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
@ -74,7 +102,7 @@ BayesianNetwork (int argc, const char* argv[])
|
|||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
BayesNode* node = bn.getNode (label);
|
BayesNode* node = bn.getBayesNode (label);
|
||||||
if (node) {
|
if (node) {
|
||||||
if (node->isValidState (state)) {
|
if (node->isValidState (state)) {
|
||||||
node->setEvidence (state);
|
node->setEvidence (state);
|
||||||
@ -94,19 +122,16 @@ BayesianNetwork (int argc, const char* argv[])
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BPSolver solver (bn);
|
Solver* solver;
|
||||||
if (queryVars.size() == 0) {
|
if (SolverOptions::convertBn2Fg) {
|
||||||
solver.runSolver();
|
FactorGraph* fg = new FactorGraph (bn);
|
||||||
solver.printAllPosterioris();
|
fg->printGraphicalModel();
|
||||||
} else if (queryVars.size() == 1) {
|
solver = new SPSolver (*fg);
|
||||||
solver.runSolver();
|
runSolver (solver, queryVars);
|
||||||
solver.printPosterioriOf (queryVars[0]);
|
delete fg;
|
||||||
} else {
|
} else {
|
||||||
Domain domain = BayesNet::getInstantiations(queryVars);
|
solver = new BPSolver (bn);
|
||||||
ParamSet params = solver.getJointDistribution (queryVars);
|
runSolver (solver, queryVars);
|
||||||
for (unsigned i = 0; i < params.size(); i++) {
|
|
||||||
cout << domain[i] << "\t" << params[i] << endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
bn.freeDistributions();
|
bn.freeDistributions();
|
||||||
}
|
}
|
||||||
@ -117,11 +142,11 @@ void
|
|||||||
markovNetwork (int argc, const char* argv[])
|
markovNetwork (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
FactorGraph fg (argv[1]);
|
FactorGraph fg (argv[1]);
|
||||||
//fg.printFactorGraph();
|
//fg.printGraphicalModel();
|
||||||
|
|
||||||
VarSet queryVars;
|
VarSet queryVars;
|
||||||
for (int i = 2; i < argc; i++) {
|
for (int i = 2; i < argc; i++) {
|
||||||
string arg = argv[i];
|
const string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
if (!Util::isInteger (arg)) {
|
if (!Util::isInteger (arg)) {
|
||||||
cerr << "error: `" << arg << "' " ;
|
cerr << "error: `" << arg << "' " ;
|
||||||
@ -129,16 +154,16 @@ markovNetwork (int argc, const char* argv[])
|
|||||||
cerr << endl;
|
cerr << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
unsigned varId;
|
Vid vid;
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg;
|
ss << arg;
|
||||||
ss >> varId;
|
ss >> vid;
|
||||||
Variable* queryVar = fg.getVariableById (varId);
|
Variable* queryVar = fg.getFgVarNode (vid);
|
||||||
if (queryVar) {
|
if (queryVar) {
|
||||||
queryVars.push_back (queryVar);
|
queryVars.push_back (queryVar);
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << varId << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
@ -160,11 +185,11 @@ markovNetwork (int argc, const char* argv[])
|
|||||||
cerr << endl;
|
cerr << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
unsigned varId;
|
Vid vid;
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg.substr (0, pos);
|
ss << arg.substr (0, pos);
|
||||||
ss >> varId;
|
ss >> vid;
|
||||||
Variable* var = fg.getVariableById (varId);
|
Variable* var = fg.getFgVarNode (vid);
|
||||||
if (var) {
|
if (var) {
|
||||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||||
@ -176,7 +201,6 @@ markovNetwork (int argc, const char* argv[])
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg.substr (pos + 1);
|
ss << arg.substr (pos + 1);
|
||||||
ss >> stateIndex;
|
ss >> stateIndex;
|
||||||
cout << "si: " << stateIndex << endl;
|
|
||||||
if (var->isValidStateIndex (stateIndex)) {
|
if (var->isValidStateIndex (stateIndex)) {
|
||||||
var->setEvidence (stateIndex);
|
var->setEvidence (stateIndex);
|
||||||
} else {
|
} else {
|
||||||
@ -188,27 +212,35 @@ markovNetwork (int argc, const char* argv[])
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << varId << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Solver* solver = new SPSolver (fg);
|
||||||
SPSolver solver (fg);
|
runSolver (solver, queryVars);
|
||||||
if (queryVars.size() == 0) {
|
fg.freeDistributions();
|
||||||
solver.runSolver();
|
}
|
||||||
solver.printAllPosterioris();
|
|
||||||
} else if (queryVars.size() == 1) {
|
|
||||||
solver.runSolver();
|
|
||||||
solver.printPosterioriOf (queryVars[0]);
|
void
|
||||||
} else {
|
runSolver (Solver* solver, const VarSet& queryVars)
|
||||||
assert (false); //FIXME
|
{
|
||||||
//Domain domain = BayesNet::getInstantiations(queryVars);
|
VidSet vids;
|
||||||
//ParamSet params = solver.getJointDistribution (queryVars);
|
for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||||
//for (unsigned i = 0; i < params.size(); i++) {
|
vids.push_back (queryVars[i]->getVarId());
|
||||||
// cout << domain[i] << "\t" << params[i] << endl;
|
}
|
||||||
//}
|
if (queryVars.size() == 0) {
|
||||||
}
|
solver->runSolver();
|
||||||
|
solver->printAllPosterioris();
|
||||||
|
} else if (queryVars.size() == 1) {
|
||||||
|
solver->runSolver();
|
||||||
|
solver->printPosterioriOf (vids[0]);
|
||||||
|
} else {
|
||||||
|
solver->printJointDistributionOf (vids);
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,41 +1,39 @@
|
|||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include <YapInterface.h>
|
#include <YapInterface.h>
|
||||||
|
|
||||||
#include "callgrind.h"
|
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "BayesNode.h"
|
#include "FactorGraph.h"
|
||||||
#include "BPSolver.h"
|
#include "BPSolver.h"
|
||||||
|
#include "SPSolver.h"
|
||||||
|
#include "CountingBP.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
createNetwork (void)
|
createNetwork (void)
|
||||||
{
|
{
|
||||||
Statistics::numCreatedNets ++;
|
//Statistics::numCreatedNets ++;
|
||||||
cout << "creating network number " << Statistics::numCreatedNets << endl;
|
//cout << "creating network number " << Statistics::numCreatedNets << endl;
|
||||||
if (Statistics::numCreatedNets == 1) {
|
|
||||||
//CALLGRIND_START_INSTRUMENTATION;
|
|
||||||
}
|
|
||||||
BayesNet* bn = new BayesNet();
|
|
||||||
|
|
||||||
|
BayesNet* bn = new BayesNet();
|
||||||
YAP_Term varList = YAP_ARG1;
|
YAP_Term varList = YAP_ARG1;
|
||||||
while (varList != YAP_TermNil()) {
|
while (varList != YAP_TermNil()) {
|
||||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
YAP_Term var = YAP_HeadOfTerm (varList);
|
||||||
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
Vid vid = (Vid) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
||||||
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
||||||
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
||||||
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
||||||
NodeSet parents;
|
BnNodeSet parents;
|
||||||
while (parentL != YAP_TermNil()) {
|
while (parentL != YAP_TermNil()) {
|
||||||
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
||||||
BayesNode* parent = bn->getNode (parentId);
|
BayesNode* parent = bn->getBayesNode (parentId);
|
||||||
if (!parent) {
|
if (!parent) {
|
||||||
parent = bn->addNode (parentId);
|
parent = bn->addNode (parentId);
|
||||||
}
|
}
|
||||||
@ -47,23 +45,20 @@ createNetwork (void)
|
|||||||
dist = new Distribution (distId);
|
dist = new Distribution (distId);
|
||||||
bn->addDistribution (dist);
|
bn->addDistribution (dist);
|
||||||
}
|
}
|
||||||
BayesNode* node = bn->getNode (varId);
|
BayesNode* node = bn->getBayesNode (vid);
|
||||||
if (node) {
|
if (node) {
|
||||||
node->setData (dsize, evidence, parents, dist);
|
node->setData (dsize, evidence, parents, dist);
|
||||||
} else {
|
} else {
|
||||||
bn->addNode (varId, dsize, evidence, parents, dist);
|
bn->addNode (vid, dsize, evidence, parents, dist);
|
||||||
}
|
}
|
||||||
varList = YAP_TailOfTerm (varList);
|
varList = YAP_TailOfTerm (varList);
|
||||||
}
|
}
|
||||||
bn->setIndexes();
|
bn->setIndexes();
|
||||||
|
|
||||||
if (Statistics::numCreatedNets == 1688) {
|
// if (Statistics::numCreatedNets == 1688) {
|
||||||
Statistics::writeStats();
|
// Statistics::writeStats();
|
||||||
//Statistics::writeStats();
|
// exit (0);
|
||||||
//CALLGRIND_STOP_INSTRUMENTATION;
|
// }
|
||||||
//CALLGRIND_DUMP_STATS;
|
|
||||||
//exit (0);
|
|
||||||
}
|
|
||||||
YAP_Int p = (YAP_Int) (bn);
|
YAP_Int p = (YAP_Int) (bn);
|
||||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
||||||
}
|
}
|
||||||
@ -73,20 +68,20 @@ createNetwork (void)
|
|||||||
int
|
int
|
||||||
setExtraVarsInfo (void)
|
setExtraVarsInfo (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term varsInfoL = YAP_ARG2;
|
YAP_Term varsInfoL = YAP_ARG2;
|
||||||
while (varsInfoL != YAP_TermNil()) {
|
while (varsInfoL != YAP_TermNil()) {
|
||||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
||||||
unsigned varId = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
Vid vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
||||||
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
||||||
YAP_Term domainL = YAP_ArgOfTerm (3, head);
|
YAP_Term domainL = YAP_ArgOfTerm (3, head);
|
||||||
Domain domain;
|
Domain domain;
|
||||||
while (domainL != YAP_TermNil()) {
|
while (domainL != YAP_TermNil()) {
|
||||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (domainL));
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (domainL));
|
||||||
domain.push_back ((char*) YAP_AtomName (atom));
|
domain.push_back ((char*) YAP_AtomName (atom));
|
||||||
domainL = YAP_TailOfTerm (domainL);
|
domainL = YAP_TailOfTerm (domainL);
|
||||||
}
|
}
|
||||||
BayesNode* node = bn->getNode (varId);
|
BayesNode* node = bn->getBayesNode (vid);
|
||||||
assert (node);
|
assert (node);
|
||||||
node->setLabel ((char*) YAP_AtomName (label));
|
node->setLabel ((char*) YAP_AtomName (label));
|
||||||
node->setDomain (domain);
|
node->setDomain (domain);
|
||||||
@ -100,8 +95,8 @@ setExtraVarsInfo (void)
|
|||||||
int
|
int
|
||||||
setParameters (void)
|
setParameters (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term distList = YAP_ARG2;
|
YAP_Term distList = YAP_ARG2;
|
||||||
while (distList != YAP_TermNil()) {
|
while (distList != YAP_TermNil()) {
|
||||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||||
@ -112,6 +107,11 @@ setParameters (void)
|
|||||||
paramL = YAP_TailOfTerm (paramL);
|
paramL = YAP_TailOfTerm (paramL);
|
||||||
}
|
}
|
||||||
bn->getDistribution(distId)->updateParameters(params);
|
bn->getDistribution(distId)->updateParameters(params);
|
||||||
|
if (Statistics::numCreatedNets == 4) {
|
||||||
|
cout << "dist " << distId << " parameters:" ;
|
||||||
|
cout << Util::parametersToString (params);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
distList = YAP_TailOfTerm (distList);
|
distList = YAP_TailOfTerm (distList);
|
||||||
}
|
}
|
||||||
return TRUE;
|
return TRUE;
|
||||||
@ -122,84 +122,126 @@ setParameters (void)
|
|||||||
int
|
int
|
||||||
runSolver (void)
|
runSolver (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term taskList = YAP_ARG2;
|
YAP_Term taskList = YAP_ARG2;
|
||||||
|
vector<VidSet> tasks;
|
||||||
vector<NodeSet> tasks;
|
VidSet marginalVids;
|
||||||
NodeSet marginalVars;
|
|
||||||
|
|
||||||
while (taskList != YAP_TermNil()) {
|
while (taskList != YAP_TermNil()) {
|
||||||
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
||||||
NodeSet jointVars;
|
VidSet jointVids;
|
||||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||||
while (jointList != YAP_TermNil()) {
|
while (jointList != YAP_TermNil()) {
|
||||||
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||||
assert (bn->getNode (varId));
|
assert (bn->getBayesNode (vid));
|
||||||
jointVars.push_back (bn->getNode (varId));
|
jointVids.push_back (vid);
|
||||||
jointList = YAP_TailOfTerm (jointList);
|
jointList = YAP_TailOfTerm (jointList);
|
||||||
}
|
}
|
||||||
tasks.push_back (jointVars);
|
tasks.push_back (jointVids);
|
||||||
} else {
|
} else {
|
||||||
unsigned varId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
||||||
BayesNode* node = bn->getNode (varId);
|
assert (bn->getBayesNode (vid));
|
||||||
assert (node);
|
tasks.push_back (VidSet() = {vid});
|
||||||
tasks.push_back (NodeSet() = {node});
|
marginalVids.push_back (vid);
|
||||||
marginalVars.push_back (node);
|
|
||||||
}
|
}
|
||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
/*
|
|
||||||
cout << "tasks to resolve:" << endl;
|
|
||||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
|
||||||
cout << "i" << ": " ;
|
|
||||||
if (tasks[i].size() == 1) {
|
|
||||||
cout << tasks[i][0]->getVarId() << endl;
|
|
||||||
} else {
|
|
||||||
for (unsigned j = 0; j < tasks[i].size(); j++) {
|
|
||||||
cout << tasks[i][j]->getVarId() << " " ;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
cerr << "prunning now..." << endl;
|
// cout << "inference tasks:" << endl;
|
||||||
BayesNet* prunedNet = bn->pruneNetwork (marginalVars);
|
// for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
bn->printNetworkToFile ("net.txt");
|
// cout << "i" << ": " ;
|
||||||
BPSolver solver (*prunedNet);
|
// if (tasks[i].size() == 1) {
|
||||||
cerr << "solving marginals now..." << endl;
|
// cout << tasks[i][0] << endl;
|
||||||
solver.runSolver();
|
// } else {
|
||||||
cerr << "calculating joints now ..." << endl;
|
// for (unsigned j = 0; j < tasks[i].size(); j++) {
|
||||||
|
// cout << tasks[i][j] << " " ;
|
||||||
|
// }
|
||||||
|
// cout << endl;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
Solver* solver = 0;
|
||||||
|
GraphicalModel* gm = 0;
|
||||||
|
VidSet vids;
|
||||||
|
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||||
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
|
vids.push_back (nodes[i]->getVarId());
|
||||||
|
}
|
||||||
|
if (marginalVids.size() != 0) {
|
||||||
|
bn->exportToDotFormat ("bn unbayes.dot");
|
||||||
|
BayesNet* mrn = bn->getMinimalRequesiteNetwork (marginalVids);
|
||||||
|
mrn->exportToDotFormat ("bn bayes.dot");
|
||||||
|
//BayesNet* mrn = bn->getMinimalRequesiteNetwork (vids);
|
||||||
|
if (SolverOptions::convertBn2Fg) {
|
||||||
|
gm = new FactorGraph (*mrn);
|
||||||
|
if (SolverOptions::compressFactorGraph) {
|
||||||
|
solver = new CountingBP (*static_cast<FactorGraph*> (gm));
|
||||||
|
} else {
|
||||||
|
solver = new SPSolver (*static_cast<FactorGraph*> (gm));
|
||||||
|
}
|
||||||
|
if (SolverOptions::runBayesBall) {
|
||||||
|
delete mrn;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gm = mrn;
|
||||||
|
solver = new BPSolver (*static_cast<BayesNet*> (gm));
|
||||||
|
}
|
||||||
|
solver->runSolver();
|
||||||
|
}
|
||||||
|
|
||||||
vector<ParamSet> results;
|
vector<ParamSet> results;
|
||||||
results.reserve (tasks.size());
|
results.reserve (tasks.size());
|
||||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
if (tasks[i].size() == 1) {
|
if (tasks[i].size() == 1) {
|
||||||
BayesNode* node = prunedNet->getNode (tasks[i][0]->getVarId());
|
results.push_back (solver->getPosterioriOf (tasks[i][0]));
|
||||||
results.push_back (solver.getPosterioriOf (node));
|
|
||||||
} else {
|
} else {
|
||||||
BPSolver solver2 (*bn);
|
static int count = 0;
|
||||||
cout << "calculating an join dist on: " ;
|
cout << "calculating joint... " << count ++ << endl;
|
||||||
for (unsigned j = 0; j < tasks[i].size(); j++) {
|
//if (count == 5225) {
|
||||||
cout << tasks[i][j]->getVarId() << " " ;
|
// Statistics::printCompressingStats ("compressing.stats");
|
||||||
|
//}
|
||||||
|
Solver* solver2 = 0;
|
||||||
|
GraphicalModel* gm2 = 0;
|
||||||
|
bn->exportToDotFormat ("joint.dot");
|
||||||
|
BayesNet* mrn2;
|
||||||
|
if (SolverOptions::runBayesBall) {
|
||||||
|
mrn2 = bn->getMinimalRequesiteNetwork (tasks[i]);
|
||||||
|
} else {
|
||||||
|
mrn2 = bn;
|
||||||
}
|
}
|
||||||
cout << "..." << endl;
|
if (SolverOptions::convertBn2Fg) {
|
||||||
results.push_back (solver2.getJointDistribution (tasks[i]));
|
gm2 = new FactorGraph (*mrn2);
|
||||||
|
if (SolverOptions::compressFactorGraph) {
|
||||||
|
solver2 = new CountingBP (*static_cast<FactorGraph*> (gm2));
|
||||||
|
} else {
|
||||||
|
solver2 = new SPSolver (*static_cast<FactorGraph*> (gm2));
|
||||||
|
}
|
||||||
|
if (SolverOptions::runBayesBall) {
|
||||||
|
delete mrn2;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gm2 = mrn2;
|
||||||
|
solver2 = new BPSolver (*static_cast<BayesNet*> (gm2));
|
||||||
|
}
|
||||||
|
results.push_back (solver2->getJointDistributionOf (tasks[i]));
|
||||||
|
delete solver2;
|
||||||
|
delete gm2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
delete prunedNet;
|
delete solver;
|
||||||
|
delete gm;
|
||||||
|
|
||||||
YAP_Term list = YAP_TermNil();
|
YAP_Term list = YAP_TermNil();
|
||||||
for (int i = results.size() - 1; i >= 0; i--) {
|
for (int i = results.size() - 1; i >= 0; i--) {
|
||||||
const ParamSet& beliefs = results[i];
|
const ParamSet& beliefs = results[i];
|
||||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||||
for (int j = beliefs.size() - 1; j >= 0; j--) {
|
for (int j = beliefs.size() - 1; j >= 0; j--) {
|
||||||
YAP_Int sl1 = YAP_InitSlot(list);
|
YAP_Int sl1 = YAP_InitSlot (list);
|
||||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||||
list = YAP_GetFromSlot(sl1);
|
list = YAP_GetFromSlot (sl1);
|
||||||
YAP_RecoverSlots(1);
|
YAP_RecoverSlots (1);
|
||||||
}
|
}
|
||||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||||
}
|
}
|
||||||
@ -210,8 +252,9 @@ runSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
deleteBayesNet (void)
|
freeBayesNetwork (void)
|
||||||
{
|
{
|
||||||
|
//Statistics::printCompressingStats ("../../compressing.stats");
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
bn->freeDistributions();
|
bn->freeDistributions();
|
||||||
delete bn;
|
delete bn;
|
||||||
@ -223,10 +266,10 @@ deleteBayesNet (void)
|
|||||||
extern "C" void
|
extern "C" void
|
||||||
init_predicates (void)
|
init_predicates (void)
|
||||||
{
|
{
|
||||||
YAP_UserCPredicate ("create_network", createNetwork, 2);
|
YAP_UserCPredicate ("create_network", createNetwork, 2);
|
||||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||||
YAP_UserCPredicate ("set_parameters", setParameters, 2);
|
YAP_UserCPredicate ("set_parameters", setParameters, 2);
|
||||||
YAP_UserCPredicate ("run_solver", runSolver, 3);
|
YAP_UserCPredicate ("run_solver", runSolver, 3);
|
||||||
YAP_UserCPredicate ("delete_bayes_net", deleteBayesNet, 1);
|
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
278
packages/CLPBN/clpbn/bp/LiftedFG.cpp
Normal file
278
packages/CLPBN/clpbn/bp/LiftedFG.cpp
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
|
||||||
|
#include "LiftedFG.h"
|
||||||
|
#include "FgVarNode.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "Distribution.h"
|
||||||
|
|
||||||
|
LiftedFG::LiftedFG (const FactorGraph& fg)
|
||||||
|
{
|
||||||
|
groundFg_ = &fg;
|
||||||
|
freeColor_ = 0;
|
||||||
|
|
||||||
|
const FgVarSet& varNodes = fg.getFgVarNodes();
|
||||||
|
const FactorSet& factors = fg.getFactors();
|
||||||
|
varColors_.resize (varNodes.size());
|
||||||
|
factorColors_.resize (factors.size());
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
factors[i]->setIndex (i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the initial variable colors
|
||||||
|
VarColorMap colorMap;
|
||||||
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
|
unsigned dsize = varNodes[i]->getDomainSize();
|
||||||
|
VarColorMap::iterator it = colorMap.find (dsize);
|
||||||
|
if (it == colorMap.end()) {
|
||||||
|
it = colorMap.insert (make_pair (
|
||||||
|
dsize, vector<Color> (dsize + 1,-1))).first;
|
||||||
|
}
|
||||||
|
unsigned idx;
|
||||||
|
if (varNodes[i]->hasEvidence()) {
|
||||||
|
idx = varNodes[i]->getEvidence();
|
||||||
|
} else {
|
||||||
|
idx = dsize;
|
||||||
|
}
|
||||||
|
vector<Color>& stateColors = it->second;
|
||||||
|
if (stateColors[idx] == -1) {
|
||||||
|
stateColors[idx] = getFreeColor();
|
||||||
|
}
|
||||||
|
setColor (varNodes[i], stateColors[idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the initial factor colors
|
||||||
|
DistColorMap distColors;
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
Distribution* dist = factors[i]->getDistribution();
|
||||||
|
DistColorMap::iterator it = distColors.find (dist);
|
||||||
|
if (it == distColors.end()) {
|
||||||
|
it = distColors.insert (make_pair (dist, getFreeColor())).first;
|
||||||
|
}
|
||||||
|
setColor (factors[i], it->second);
|
||||||
|
}
|
||||||
|
|
||||||
|
VarSignMap varGroups;
|
||||||
|
FactorSignMap factorGroups;
|
||||||
|
bool groupsHaveChanged = true;
|
||||||
|
unsigned nIter = 0;
|
||||||
|
while (groupsHaveChanged || nIter == 1) {
|
||||||
|
nIter ++;
|
||||||
|
if (Statistics::numCreatedNets == 4) {
|
||||||
|
cout << "--------------------------------------------" << endl;
|
||||||
|
cout << "Iteration " << nIter << endl;
|
||||||
|
cout << "--------------------------------------------" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned prevFactorGroupsSize = factorGroups.size();
|
||||||
|
factorGroups.clear();
|
||||||
|
// set a new color to the factors with the same signature
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
const string& signatureId = getSignatureId (factors[i]);
|
||||||
|
// cout << factors[i]->getLabel() << " signature: " ;
|
||||||
|
// cout<< signatureId << endl;
|
||||||
|
FactorSignMap::iterator it = factorGroups.find (signatureId);
|
||||||
|
if (it == factorGroups.end()) {
|
||||||
|
it = factorGroups.insert (make_pair (signatureId, FactorSet())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (factors[i]);
|
||||||
|
}
|
||||||
|
if (nIter > 0)
|
||||||
|
for (FactorSignMap::iterator it = factorGroups.begin();
|
||||||
|
it != factorGroups.end(); it++) {
|
||||||
|
Color newColor = getFreeColor();
|
||||||
|
FactorSet& groupMembers = it->second;
|
||||||
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set a new color to the variables with the same signature
|
||||||
|
unsigned prevVarGroupsSize = varGroups.size();
|
||||||
|
varGroups.clear();
|
||||||
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
|
const string& signatureId = getSignatureId (varNodes[i]);
|
||||||
|
VarSignMap::iterator it = varGroups.find (signatureId);
|
||||||
|
// cout << varNodes[i]->getLabel() << " signature: " ;
|
||||||
|
// cout << signatureId << endl;
|
||||||
|
if (it == varGroups.end()) {
|
||||||
|
it = varGroups.insert (make_pair (signatureId, FgVarSet())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (varNodes[i]);
|
||||||
|
}
|
||||||
|
if (nIter > 0)
|
||||||
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); it++) {
|
||||||
|
Color newColor = getFreeColor();
|
||||||
|
FgVarSet& groupMembers = it->second;
|
||||||
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//if (nIter >= 3) cout << "bigger than three: " << nIter << endl;
|
||||||
|
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||||
|
|| prevFactorGroupsSize != factorGroups.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
printGroups (varGroups, factorGroups);
|
||||||
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); it++) {
|
||||||
|
CFgVarSet vars = it->second;
|
||||||
|
VarCluster* vc = new VarCluster (vars);
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
vid2VarCluster_.insert (make_pair (vars[i]->getVarId(), vc));
|
||||||
|
}
|
||||||
|
varClusters_.push_back (vc);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (FactorSignMap::iterator it = factorGroups.begin();
|
||||||
|
it != factorGroups.end(); it++) {
|
||||||
|
VarClusterSet varClusters;
|
||||||
|
Factor* groundFactor = it->second[0];
|
||||||
|
FgVarSet groundVars = groundFactor->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < groundVars.size(); i++) {
|
||||||
|
Vid vid = groundVars[i]->getVarId();
|
||||||
|
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
||||||
|
}
|
||||||
|
factorClusters_.push_back (new FactorCluster (it->second, varClusters));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
LiftedFG::~LiftedFG (void)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
|
delete varClusters_[i];
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||||
|
delete factorClusters_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
LiftedFG::getSignatureId (FgVarNode* var) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
CFactorSet myFactors = var->getFactors();
|
||||||
|
ss << myFactors.size();
|
||||||
|
for (unsigned i = 0; i < myFactors.size(); i++) {
|
||||||
|
ss << "." << getColor (myFactors[i]);
|
||||||
|
ss << "." << myFactors[i]->getIndexOf(var);
|
||||||
|
}
|
||||||
|
ss << "." << getColor (var);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
LiftedFG::getSignatureId (Factor* factor) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
CFgVarSet myVars = factor->getFgVarNodes();
|
||||||
|
ss << myVars.size();
|
||||||
|
for (unsigned i = 0; i < myVars.size(); i++) {
|
||||||
|
ss << "." << getColor (myVars[i]);
|
||||||
|
}
|
||||||
|
ss << "." << getColor (factor);
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
LiftedFG::getCompressedFactorGraph (void)
|
||||||
|
{
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
|
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
||||||
|
FgVarNode* newVar = new FgVarNode (var);
|
||||||
|
newVar->setIndex (i);
|
||||||
|
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||||
|
fg->addVariable (newVar);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||||
|
FgVarSet myGroundVars;
|
||||||
|
const VarClusterSet& myVarClusters = factorClusters_[i]->getVarClusters();
|
||||||
|
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||||
|
myGroundVars.push_back (myVarClusters[j]->getRepresentativeVariable());
|
||||||
|
}
|
||||||
|
Factor* newFactor = new Factor (myGroundVars,
|
||||||
|
factorClusters_[i]->getGroundFactors()[0]->getDistribution());
|
||||||
|
factorClusters_[i]->setRepresentativeFactor (newFactor);
|
||||||
|
fg->addFactor (newFactor);
|
||||||
|
}
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
LiftedFG::getGroundEdgeCount (FactorCluster* fc, VarCluster* vc) const
|
||||||
|
{
|
||||||
|
CFactorSet clusterGroundFactors = fc->getGroundFactors();
|
||||||
|
FgVarNode* var = vc->getGroundFgVarNodes()[0];
|
||||||
|
unsigned count = 0;
|
||||||
|
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
|
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
CFgVarSet vars = vc->getGroundFgVarNodes();
|
||||||
|
for (unsigned i = 1; i < vars.size(); i++) {
|
||||||
|
FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
||||||
|
unsigned count2 = 0;
|
||||||
|
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
|
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
|
||||||
|
count2 ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (count != count2) { cout << "oops!" << endl; abort(); }
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
LiftedFG::printGroups (const VarSignMap& varGroups,
|
||||||
|
const FactorSignMap& factorGroups) const
|
||||||
|
{
|
||||||
|
cout << "variable groups:" << endl;
|
||||||
|
unsigned count = 0;
|
||||||
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); it++) {
|
||||||
|
const FgVarSet& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << ++count << ": " ;
|
||||||
|
//if (groupMembers.size() > 1) {
|
||||||
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
//}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
cout << "factor groups:" << endl;
|
||||||
|
count = 0;
|
||||||
|
for (FactorSignMap::const_iterator it = factorGroups.begin();
|
||||||
|
it != factorGroups.end(); it++) {
|
||||||
|
const FactorSet& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << ++count << ": " ;
|
||||||
|
//if (groupMembers.size() > 1) {
|
||||||
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
//}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
152
packages/CLPBN/clpbn/bp/LiftedFG.h
Normal file
152
packages/CLPBN/clpbn/bp/LiftedFG.h
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
#ifndef BP_LIFTED_FG_H
|
||||||
|
#define BP_LIFTED_FG_H
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "FgVarNode.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "Shared.h"
|
||||||
|
|
||||||
|
class VarCluster;
|
||||||
|
class FactorCluster;
|
||||||
|
class Distribution;
|
||||||
|
|
||||||
|
typedef long Color;
|
||||||
|
typedef vector<Color> Signature;
|
||||||
|
typedef vector<VarCluster*> VarClusterSet;
|
||||||
|
typedef vector<FactorCluster*> FactorClusterSet;
|
||||||
|
|
||||||
|
typedef map<string, FgVarSet> VarSignMap;
|
||||||
|
typedef map<string, FactorSet> FactorSignMap;
|
||||||
|
|
||||||
|
typedef map<unsigned, vector<Color> > VarColorMap;
|
||||||
|
typedef map<Distribution*, Color> DistColorMap;
|
||||||
|
|
||||||
|
typedef map<Vid, VarCluster*> Vid2VarCluster;
|
||||||
|
|
||||||
|
|
||||||
|
class VarCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
VarCluster (CFgVarSet vs)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < vs.size(); i++) {
|
||||||
|
groundVars_.push_back (vs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void addFactorCluster (FactorCluster* fc)
|
||||||
|
{
|
||||||
|
factorClusters_.push_back (fc);
|
||||||
|
}
|
||||||
|
|
||||||
|
const FactorClusterSet& getFactorClusters (void) const
|
||||||
|
{
|
||||||
|
return factorClusters_;
|
||||||
|
}
|
||||||
|
|
||||||
|
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||||
|
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
||||||
|
CFgVarSet getGroundFgVarNodes (void) const { return groundVars_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
FgVarSet groundVars_;
|
||||||
|
FactorClusterSet factorClusters_;
|
||||||
|
FgVarNode* representVar_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class FactorCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FactorCluster (CFactorSet groundFactors, const VarClusterSet& vcs)
|
||||||
|
{
|
||||||
|
groundFactors_ = groundFactors;
|
||||||
|
varClusters_ = vcs;
|
||||||
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
|
varClusters_[i]->addFactorCluster (this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const VarClusterSet& getVarClusters (void) const
|
||||||
|
{
|
||||||
|
return varClusters_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool containsGround (const Factor* f)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||||
|
if (groundFactors_[i] == f) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
Factor* getRepresentativeFactor (void) const { return representFactor_; }
|
||||||
|
void setRepresentativeFactor (Factor* f) { representFactor_ = f; }
|
||||||
|
CFactorSet getGroundFactors (void) const { return groundFactors_; }
|
||||||
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
FactorSet groundFactors_;
|
||||||
|
VarClusterSet varClusters_;
|
||||||
|
Factor* representFactor_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class LiftedFG
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
LiftedFG (const FactorGraph&);
|
||||||
|
~LiftedFG (void);
|
||||||
|
|
||||||
|
FactorGraph* getCompressedFactorGraph (void);
|
||||||
|
unsigned getGroundEdgeCount (FactorCluster*, VarCluster*) const;
|
||||||
|
void printGroups (const VarSignMap& varGroups,
|
||||||
|
const FactorSignMap& factorGroups) const;
|
||||||
|
|
||||||
|
FgVarNode* getEquivalentVariable (Vid vid)
|
||||||
|
{
|
||||||
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
|
return vc->getRepresentativeVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
const VarClusterSet& getVariableClusters (void) { return varClusters_; }
|
||||||
|
const FactorClusterSet& getFactorClusters (void) { return factorClusters_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
string getSignatureId (FgVarNode*) const;
|
||||||
|
string getSignatureId (Factor*) const;
|
||||||
|
|
||||||
|
Color getFreeColor (void) { return ++freeColor_ -1; }
|
||||||
|
Color getColor (FgVarNode* v) const { return varColors_[v->getIndex()]; }
|
||||||
|
Color getColor (Factor* f) const { return factorColors_[f->getIndex()]; }
|
||||||
|
|
||||||
|
void setColor (FgVarNode* v, Color c)
|
||||||
|
{
|
||||||
|
varColors_[v->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setColor (Factor* f, Color c)
|
||||||
|
{
|
||||||
|
factorColors_[f->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
VarCluster* getVariableCluster (Vid vid) const
|
||||||
|
{
|
||||||
|
return vid2VarCluster_.find (vid)->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
Color freeColor_;
|
||||||
|
vector<Color> varColors_;
|
||||||
|
vector<Color> factorColors_;
|
||||||
|
VarClusterSet varClusters_;
|
||||||
|
FactorClusterSet factorClusters_;
|
||||||
|
Vid2VarCluster vid2VarCluster_;
|
||||||
|
const FactorGraph* groundFg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // BP_LIFTED_FG_H
|
||||||
|
|
@ -50,28 +50,33 @@ CWD=$(PWD)
|
|||||||
HEADERS = \
|
HEADERS = \
|
||||||
$(srcdir)/GraphicalModel.h \
|
$(srcdir)/GraphicalModel.h \
|
||||||
$(srcdir)/Variable.h \
|
$(srcdir)/Variable.h \
|
||||||
|
$(srcdir)/Distribution.h \
|
||||||
$(srcdir)/BayesNet.h \
|
$(srcdir)/BayesNet.h \
|
||||||
$(srcdir)/BayesNode.h \
|
$(srcdir)/BayesNode.h \
|
||||||
$(srcdir)/Distribution.h \
|
$(srcdir)/LiftedFG.h \
|
||||||
$(srcdir)/CptEntry.h \
|
$(srcdir)/CptEntry.h \
|
||||||
$(srcdir)/FactorGraph.h \
|
$(srcdir)/FactorGraph.h \
|
||||||
$(srcdir)/FgVarNode.h \
|
$(srcdir)/FgVarNode.h \
|
||||||
$(srcdir)/Factor.h \
|
$(srcdir)/Factor.h \
|
||||||
$(srcdir)/Solver.h \
|
$(srcdir)/Solver.h \
|
||||||
$(srcdir)/BPSolver.h \
|
$(srcdir)/BPSolver.h \
|
||||||
$(srcdir)/BpNode.h \
|
$(srcdir)/BPNodeInfo.h \
|
||||||
$(srcdir)/SPSolver.h \
|
$(srcdir)/SPSolver.h \
|
||||||
|
$(srcdir)/CountingBP.h \
|
||||||
$(srcdir)/Shared.h \
|
$(srcdir)/Shared.h \
|
||||||
$(srcdir)/xmlParser/xmlParser.h
|
$(srcdir)/xmlParser/xmlParser.h
|
||||||
|
|
||||||
CPP_SOURCES = \
|
CPP_SOURCES = \
|
||||||
$(srcdir)/BayesNet.cpp \
|
$(srcdir)/BayesNet.cpp \
|
||||||
$(srcdir)/BayesNode.cpp \
|
$(srcdir)/BayesNode.cpp \
|
||||||
$(srcdir)/FactorGraph.cpp \
|
$(srcdir)/FactorGraph.cpp \
|
||||||
$(srcdir)/Factor.cpp \
|
$(srcdir)/Factor.cpp \
|
||||||
|
$(srcdir)/LiftedFG.cpp \
|
||||||
$(srcdir)/BPSolver.cpp \
|
$(srcdir)/BPSolver.cpp \
|
||||||
$(srcdir)/BpNode.cpp \
|
$(srcdir)/BPNodeInfo.cpp \
|
||||||
$(srcdir)/SPSolver.cpp \
|
$(srcdir)/SPSolver.cpp \
|
||||||
|
$(srcdir)/CountingBP.cpp \
|
||||||
|
$(srcdir)/Util.cpp \
|
||||||
$(srcdir)/HorusYap.cpp \
|
$(srcdir)/HorusYap.cpp \
|
||||||
$(srcdir)/HorusCli.cpp \
|
$(srcdir)/HorusCli.cpp \
|
||||||
$(srcdir)/xmlParser/xmlParser.cpp
|
$(srcdir)/xmlParser/xmlParser.cpp
|
||||||
@ -82,22 +87,38 @@ OBJS = \
|
|||||||
FactorGraph.o \
|
FactorGraph.o \
|
||||||
Factor.o \
|
Factor.o \
|
||||||
BPSolver.o \
|
BPSolver.o \
|
||||||
BpNode.o \
|
BPNodeInfo.o \
|
||||||
SPSolver.o \
|
SPSolver.o \
|
||||||
HorusYap.o \
|
Util.o \
|
||||||
xmlParser.o
|
LiftedFG.o \
|
||||||
|
CountingBP.o \
|
||||||
|
HorusYap.o
|
||||||
|
|
||||||
|
HCLI_OBJS = \
|
||||||
|
BayesNet.o \
|
||||||
|
BayesNode.o \
|
||||||
|
FactorGraph.o \
|
||||||
|
Factor.o \
|
||||||
|
BPSolver.o \
|
||||||
|
BPNodeInfo.o \
|
||||||
|
SPSolver.o \
|
||||||
|
Util.o \
|
||||||
|
LiftedFG.o \
|
||||||
|
CountingBP.o \
|
||||||
|
HorusCli.o \
|
||||||
|
xmlParser/xmlParser.o
|
||||||
|
|
||||||
SOBJS=horus.@SO@
|
SOBJS=horus.@SO@
|
||||||
|
|
||||||
|
|
||||||
all: $(SOBJS)
|
all: $(SOBJS) hcli
|
||||||
|
|
||||||
# default rule
|
# default rule
|
||||||
%.o : $(srcdir)/%.cpp
|
%.o : $(srcdir)/%.cpp
|
||||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||||
|
|
||||||
|
|
||||||
xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
||||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||||
|
|
||||||
|
|
||||||
@ -105,7 +126,7 @@ xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
|||||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||||
|
|
||||||
|
|
||||||
hcli: $(OBJS)
|
hcli: $(HCLI_OBJS)
|
||||||
$(CXX) -o hcli $(HCLI_OBJS)
|
$(CXX) -o hcli $(HCLI_OBJS)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,38 +1,77 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <algorithm>
|
#include <limits>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "SPSolver.h"
|
#include "SPSolver.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "FgVarNode.h"
|
#include "FgVarNode.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
|
#include "Shared.h"
|
||||||
SPSolver* Link::klass = 0;
|
|
||||||
|
|
||||||
|
|
||||||
SPSolver::SPSolver (const FactorGraph& fg) : Solver (&fg)
|
SPSolver::SPSolver (FactorGraph& fg) : Solver (&fg)
|
||||||
{
|
{
|
||||||
fg_ = &fg;
|
fg_ = &fg;
|
||||||
accuracy_ = 0.0001;
|
|
||||||
maxIter_ = 10000;
|
|
||||||
//schedule_ = S_SEQ_FIXED;
|
|
||||||
//schedule_ = S_SEQ_RANDOM;
|
|
||||||
//schedule_ = S_SEQ_PARALLEL;
|
|
||||||
schedule_ = S_MAX_RESIDUAL;
|
|
||||||
Link::klass = this;
|
|
||||||
FgVarSet vars = fg_->getFgVarNodes();
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
msgs_.push_back (new MessageBanket (vars[i]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SPSolver::~SPSolver (void)
|
SPSolver::~SPSolver (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < msgs_.size(); i++) {
|
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||||
delete msgs_[i];
|
delete varsI_[i];
|
||||||
}
|
}
|
||||||
|
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||||
|
delete factorsI_[i];
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
SPSolver::runTreeSolver (void)
|
||||||
|
{
|
||||||
|
CFactorSet factors = fg_->getFactors();
|
||||||
|
bool finish = false;
|
||||||
|
while (!finish) {
|
||||||
|
finish = true;
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
CLinkSet links = factorsI_[factors[i]->getIndex()]->getLinks();
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
if (!links[j]->messageWasSended()) {
|
||||||
|
if (readyToSendMessage(links[j])) {
|
||||||
|
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||||
|
links[j]->updateMessage();
|
||||||
|
}
|
||||||
|
finish = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
SPSolver::readyToSendMessage (const Link* link) const
|
||||||
|
{
|
||||||
|
CFgVarSet factorVars = link->getFactor()->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||||
|
if (factorVars[i] != link->getVariable()) {
|
||||||
|
CLinkSet links = varsI_[factorVars[i]->getIndex()]->getLinks();
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->getFactor() != link->getFactor() &&
|
||||||
|
!links[j]->messageWasSended()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -40,62 +79,54 @@ SPSolver::~SPSolver (void)
|
|||||||
void
|
void
|
||||||
SPSolver::runSolver (void)
|
SPSolver::runSolver (void)
|
||||||
{
|
{
|
||||||
|
initializeSolver();
|
||||||
|
runTreeSolver();
|
||||||
|
return;
|
||||||
nIter_ = 0;
|
nIter_ = 0;
|
||||||
vector<Factor*> factors = fg_->getFactors();
|
while (!converged() && nIter_ < SolverOptions::maxIter) {
|
||||||
for (unsigned i = 0; i < factors.size(); i++) {
|
|
||||||
FgVarSet neighbors = factors[i]->getFgVarNodes();
|
nIter_ ++;
|
||||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
if (DL >= 2) {
|
||||||
updateOrder_.push_back (Link (factors[i], neighbors[j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
while (!converged() && nIter_ < maxIter_) {
|
|
||||||
if (DL >= 1) {
|
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << "****************************************" ;
|
cout << "****************************************" ;
|
||||||
cout << "****************************************" ;
|
cout << "****************************************" ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << " Iteration " << nIter_ + 1 << endl;
|
cout << " Iteration " << nIter_ << endl;
|
||||||
cout << "****************************************" ;
|
cout << "****************************************" ;
|
||||||
cout << "****************************************" ;
|
cout << "****************************************" ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (schedule_) {
|
switch (SolverOptions::schedule) {
|
||||||
|
case SolverOptions::S_SEQ_RANDOM:
|
||||||
case S_SEQ_RANDOM:
|
random_shuffle (links_.begin(), links_.end());
|
||||||
random_shuffle (updateOrder_.begin(), updateOrder_.end());
|
|
||||||
// no break
|
// no break
|
||||||
|
|
||||||
case S_SEQ_FIXED:
|
case SolverOptions::S_SEQ_FIXED:
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
Link& link = updateOrder_[c];
|
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||||
calculateNextMessage (link.source, link.destination);
|
links_[i]->updateMessage();
|
||||||
updateMessage (updateOrder_[c]);
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case S_PARALLEL:
|
case SolverOptions::S_PARALLEL:
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
Link link = updateOrder_[c];
|
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||||
calculateNextMessage (link.source, link.destination);
|
|
||||||
}
|
}
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
Link link = updateOrder_[c];
|
links_[i]->updateMessage();
|
||||||
updateMessage (updateOrder_[c]);
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case S_MAX_RESIDUAL:
|
case SolverOptions::S_MAX_RESIDUAL:
|
||||||
maxResidualSchedule();
|
maxResidualSchedule();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
nIter_++;
|
|
||||||
}
|
}
|
||||||
cout << endl;
|
|
||||||
if (DL >= 1) {
|
if (DL >= 2) {
|
||||||
if (nIter_ < maxIter_) {
|
cout << endl;
|
||||||
|
if (nIter_ < SolverOptions::maxIter) {
|
||||||
cout << "Loopy Sum-Product converged in " ;
|
cout << "Loopy Sum-Product converged in " ;
|
||||||
cout << nIter_ << " iterations" << endl;
|
cout << nIter_ << " iterations" << endl;
|
||||||
} else {
|
} else {
|
||||||
@ -108,58 +139,168 @@ SPSolver::runSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
ParamSet
|
ParamSet
|
||||||
SPSolver::getPosterioriOf (const Variable* var) const
|
SPSolver::getPosterioriOf (Vid vid) const
|
||||||
{
|
{
|
||||||
assert (var);
|
assert (fg_->getFgVarNode (vid));
|
||||||
assert (var == fg_->getVariableById (var->getVarId()));
|
FgVarNode* var = fg_->getFgVarNode (vid);
|
||||||
assert (var->getIndex() < msgs_.size());
|
ParamSet probs;
|
||||||
|
|
||||||
ParamSet probs (var->getDomainSize(), 1);
|
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
for (unsigned i = 0; i < probs.size(); i++) {
|
probs.resize (var->getDomainSize(), 0.0);
|
||||||
if ((int)i != var->getEvidence()) {
|
probs[var->getEvidence()] = 1.0;
|
||||||
probs[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
probs.resize (var->getDomainSize(), 1.0);
|
||||||
MessageBanket* mb = msgs_[var->getIndex()];
|
CLinkSet links = varsI_[var->getIndex()]->getLinks();
|
||||||
const FgVarNode* varNode = fg_->getFgVarNodes()[var->getIndex()];
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
vector<Factor*> neighbors = varNode->getFactors();
|
CParamSet msg = links[i]->getMessage();
|
||||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
|
||||||
const Message& msg = mb->getMessage (neighbors[i]);
|
|
||||||
for (unsigned j = 0; j < msg.size(); j++) {
|
for (unsigned j = 0; j < msg.size(); j++) {
|
||||||
probs[j] *= msg[j];
|
probs[j] *= msg[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
Util::normalize (probs);
|
||||||
}
|
}
|
||||||
|
|
||||||
return probs;
|
return probs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
SPSolver::getJointDistributionOf (const VidSet& jointVids)
|
||||||
|
{
|
||||||
|
FgVarSet jointVars;
|
||||||
|
unsigned dsize = 1;
|
||||||
|
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||||
|
FgVarNode* varNode = fg_->getFgVarNode (jointVids[i]);
|
||||||
|
dsize *= varNode->getDomainSize();
|
||||||
|
jointVars.push_back (varNode);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned maxVid = std::numeric_limits<unsigned>::max();
|
||||||
|
FgVarNode* junctionVar = new FgVarNode (maxVid, dsize);
|
||||||
|
FgVarSet factorVars = { junctionVar };
|
||||||
|
for (unsigned i = 0; i < jointVars.size(); i++) {
|
||||||
|
factorVars.push_back (jointVars[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned nParams = dsize * dsize;
|
||||||
|
ParamSet params (nParams);
|
||||||
|
for (unsigned i = 0; i < nParams; i++) {
|
||||||
|
unsigned row = i / dsize;
|
||||||
|
unsigned col = i % dsize;
|
||||||
|
if (row == col) {
|
||||||
|
params[i] = 1;
|
||||||
|
} else {
|
||||||
|
params[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Distribution* dist = new Distribution (params, maxVid);
|
||||||
|
Factor* newFactor = new Factor (factorVars, dist);
|
||||||
|
fg_->addVariable (junctionVar);
|
||||||
|
fg_->addFactor (newFactor);
|
||||||
|
|
||||||
|
runSolver();
|
||||||
|
ParamSet results = getPosterioriOf (maxVid);
|
||||||
|
deleteJunction (newFactor, junctionVar);
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
SPSolver::initializeSolver (void)
|
||||||
|
{
|
||||||
|
fg_->setIndexes();
|
||||||
|
|
||||||
|
CFgVarSet vars = fg_->getFgVarNodes();
|
||||||
|
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||||
|
delete varsI_[i];
|
||||||
|
}
|
||||||
|
varsI_.reserve (vars.size());
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
varsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
CFactorSet factors = fg_->getFactors();
|
||||||
|
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||||
|
delete factorsI_[i];
|
||||||
|
}
|
||||||
|
factorsI_.reserve (factors.size());
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
factorsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
createLinks();
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
Factor* source = links_[i]->getFactor();
|
||||||
|
FgVarNode* dest = links_[i]->getVariable();
|
||||||
|
varsI_[dest->getIndex()]->addLink (links_[i]);
|
||||||
|
factorsI_[source->getIndex()]->addLink (links_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
SPSolver::createLinks (void)
|
||||||
|
{
|
||||||
|
CFactorSet factors = fg_->getFactors();
|
||||||
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
|
CFgVarSet neighbors = factors[i]->getFgVarNodes();
|
||||||
|
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||||
|
links_.push_back (new Link (factors[i], neighbors[j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
SPSolver::deleteJunction (Factor* f, FgVarNode* v)
|
||||||
|
{
|
||||||
|
fg_->removeFactor (f);
|
||||||
|
f->freeDistribution();
|
||||||
|
delete f;
|
||||||
|
fg_->removeVariable (v);
|
||||||
|
delete v;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
SPSolver::converged (void)
|
SPSolver::converged (void)
|
||||||
{
|
{
|
||||||
|
// this can happen if the graph is fully disconnected
|
||||||
|
if (links_.size() == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (nIter_ == 0 || nIter_ == 1) {
|
if (nIter_ == 0 || nIter_ == 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool converged = true;
|
bool converged = true;
|
||||||
for (unsigned i = 0; i < updateOrder_.size(); i++) {
|
if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) {
|
||||||
double residual = getResidual (updateOrder_[i]);
|
Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||||
if (DL >= 1) {
|
if (maxResidual < SolverOptions::accuracy) {
|
||||||
cout << updateOrder_[i].toString();
|
converged = true;
|
||||||
cout << " residual = " << residual << endl;
|
} else {
|
||||||
}
|
|
||||||
if (residual > accuracy_) {
|
|
||||||
converged = false;
|
converged = false;
|
||||||
if (DL == 0) {
|
}
|
||||||
break;
|
} else {
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
double residual = links_[i]->getResidual();
|
||||||
|
if (DL >= 2) {
|
||||||
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
}
|
}
|
||||||
}
|
if (residual > SolverOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
if (DL == 0) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return converged;
|
return converged;
|
||||||
}
|
}
|
||||||
@ -169,127 +310,161 @@ SPSolver::converged (void)
|
|||||||
void
|
void
|
||||||
SPSolver::maxResidualSchedule (void)
|
SPSolver::maxResidualSchedule (void)
|
||||||
{
|
{
|
||||||
if (nIter_ == 0) {
|
if (nIter_ == 1) {
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
Link& l = updateOrder_[c];
|
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||||
calculateNextMessage (l.source, l.destination);
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
if (DL >= 1) {
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
cout << updateOrder_[c].toString() << " residual = " ;
|
if (DL >= 2 && DL < 5) {
|
||||||
cout << getResidual (updateOrder_[c]) << endl;
|
cout << "calculating " << links_[i]->toString() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual);
|
return;
|
||||||
} else {
|
}
|
||||||
|
|
||||||
for (unsigned c = 0; c < updateOrder_.size(); c++) {
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
Link& link = updateOrder_.front();
|
if (DL >= 2) {
|
||||||
updateMessage (link);
|
cout << endl << "current residuals:" << endl;
|
||||||
resetResidual (link);
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
it != sortedOrder_.end(); it ++) {
|
||||||
|
cout << " " << setw (30) << left << (*it)->toString();
|
||||||
|
cout << "residual = " << (*it)->getResidual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// update the messages that depend on message source --> destination
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
vector<Factor*> fstLevelNeighbors = link.destination->getFactors();
|
Link* link = *it;
|
||||||
for (unsigned i = 0; i < fstLevelNeighbors.size(); i++) {
|
if (DL >= 2) {
|
||||||
if (fstLevelNeighbors[i] != link.source) {
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
FgVarSet sndLevelNeighbors;
|
}
|
||||||
sndLevelNeighbors = fstLevelNeighbors[i]->getFgVarNodes();
|
if (link->getResidual() < SolverOptions::accuracy) {
|
||||||
for (unsigned j = 0; j < sndLevelNeighbors.size(); j++) {
|
return;
|
||||||
if (sndLevelNeighbors[j] != link.destination) {
|
}
|
||||||
calculateNextMessage (fstLevelNeighbors[i], sndLevelNeighbors[j]);
|
link->updateMessage();
|
||||||
|
link->clearResidual();
|
||||||
|
sortedOrder_.erase (it);
|
||||||
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
|
// update the messages that depend on message source --> destin
|
||||||
|
CFactorSet factorNeighbors = link->getVariable()->getFactors();
|
||||||
|
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||||
|
if (factorNeighbors[i] != link->getFactor()) {
|
||||||
|
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
|
||||||
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->getVariable() != link->getVariable()) {
|
||||||
|
if (DL >= 2 && DL < 5) {
|
||||||
|
cout << " calculating " << links[j]->toString() << endl;
|
||||||
}
|
}
|
||||||
|
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||||
|
LinkMap::iterator iter = linkMap_.find (links[j]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sort (updateOrder_.begin(), updateOrder_.end(), compareResidual);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
ParamSet
|
||||||
SPSolver::updateMessage (const Link& link)
|
SPSolver::getFactor2VarMsg (const Link* link) const
|
||||||
{
|
{
|
||||||
updateMessage (link.source, link.destination);
|
const Factor* src = link->getFactor();
|
||||||
}
|
const FgVarNode* dest = link->getVariable();
|
||||||
|
CFgVarSet neighbors = src->getFgVarNodes();
|
||||||
|
CLinkSet links = factorsI_[src->getIndex()]->getLinks();
|
||||||
|
// calculate the product of messages that were sent
|
||||||
void
|
|
||||||
SPSolver::updateMessage (const Factor* src, const FgVarNode* dest)
|
|
||||||
{
|
|
||||||
msgs_[dest->getIndex()]->updateMessage (src);
|
|
||||||
/* cout << src->getLabel() << " --> " << dest->getLabel() << endl;
|
|
||||||
cout << " m: " ;
|
|
||||||
Message msg = msgs_[dest->getIndex()]->getMessage (src);
|
|
||||||
for (unsigned i = 0; i < msg.size(); i++) {
|
|
||||||
if (i != 0) cout << ", " ;
|
|
||||||
cout << msg[i];
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
SPSolver::calculateNextMessage (const Link& link)
|
|
||||||
{
|
|
||||||
calculateNextMessage (link.source, link.destination);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
SPSolver::calculateNextMessage (const Factor* src, const FgVarNode* dest)
|
|
||||||
{
|
|
||||||
FgVarSet neighbors = src->getFgVarNodes();
|
|
||||||
// calculate the product of MessageBankets sended
|
|
||||||
// to factor `src', except from var `dest'
|
// to factor `src', except from var `dest'
|
||||||
Factor result = *src;
|
Factor result (*src);
|
||||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
Factor temp;
|
||||||
if (neighbors[i] != dest) {
|
if (DL >= 5) {
|
||||||
Message msg (neighbors[i]->getDomainSize(), 1);
|
cout << "calculating " ;
|
||||||
calculateVarFactorMessage (neighbors[i], src, msg);
|
cout << src->getLabel() << " --> " << dest->getLabel();
|
||||||
result *= Factor (neighbors[i], msg);
|
cout << endl;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// marginalize all vars except `dest'
|
|
||||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||||
if (neighbors[i] != dest) {
|
if (links[i]->getVariable() != dest) {
|
||||||
result.marginalizeVariable (neighbors[i]);
|
if (DL >= 5) {
|
||||||
}
|
cout << " message from " << links[i]->getVariable()->getLabel();
|
||||||
}
|
cout << ": " ;
|
||||||
msgs_[dest->getIndex()]->setNextMessage (src, result.getParameters());
|
ParamSet p = getVar2FactorMsg (links[i]);
|
||||||
}
|
cout << endl;
|
||||||
|
Factor temp2 (links[i]->getVariable(), p);
|
||||||
|
temp.multiplyByFactor (temp2);
|
||||||
|
temp2.freeDistribution();
|
||||||
void
|
|
||||||
SPSolver::calculateVarFactorMessage (const FgVarNode* src,
|
|
||||||
const Factor* dest,
|
|
||||||
Message& placeholder) const
|
|
||||||
{
|
|
||||||
assert (src->getDomainSize() == (int)placeholder.size());
|
|
||||||
if (src->hasEvidence()) {
|
|
||||||
for (unsigned i = 0; i < placeholder.size(); i++) {
|
|
||||||
if ((int)i != src->getEvidence()) {
|
|
||||||
placeholder[i] = 0.0;
|
|
||||||
} else {
|
} else {
|
||||||
placeholder[i] = 1.0;
|
Factor temp2 (links[i]->getVariable(), getVar2FactorMsg (links[i]));
|
||||||
}
|
temp.multiplyByFactor (temp2);
|
||||||
}
|
temp2.freeDistribution();
|
||||||
|
|
||||||
} else {
|
|
||||||
|
|
||||||
MessageBanket* mb = msgs_[src->getIndex()];
|
|
||||||
vector<Factor*> neighbors = src->getFactors();
|
|
||||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
|
||||||
if (neighbors[i] != dest) {
|
|
||||||
const Message& fromFactor = mb->getMessage (neighbors[i]);
|
|
||||||
for (unsigned j = 0; j < fromFactor.size(); j++) {
|
|
||||||
placeholder[j] *= fromFactor[j];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (links.size() >= 2) {
|
||||||
|
result.multiplyByFactor (temp, &(src->getCptEntries()));
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << " message product: " ;
|
||||||
|
cout << Util::parametersToString (temp.getParameters()) << endl;
|
||||||
|
cout << " factor product: " ;
|
||||||
|
cout << Util::parametersToString (src->getParameters());
|
||||||
|
cout << " x " ;
|
||||||
|
cout << Util::parametersToString (temp.getParameters());
|
||||||
|
cout << " = " ;
|
||||||
|
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||||
|
}
|
||||||
|
temp.freeDistribution();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
|
if (links[i]->getVariable() != dest) {
|
||||||
|
result.removeVariable (links[i]->getVariable());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << " final message: " ;
|
||||||
|
cout << Util::parametersToString (result.getParameters()) << endl << endl;
|
||||||
|
}
|
||||||
|
ParamSet msg = result.getParameters();
|
||||||
|
result.freeDistribution();
|
||||||
|
return msg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParamSet
|
||||||
|
SPSolver::getVar2FactorMsg (const Link* link) const
|
||||||
|
{
|
||||||
|
const FgVarNode* src = link->getVariable();
|
||||||
|
const Factor* dest = link->getFactor();
|
||||||
|
ParamSet msg;
|
||||||
|
if (src->hasEvidence()) {
|
||||||
|
msg.resize (src->getDomainSize(), 0.0);
|
||||||
|
msg[src->getEvidence()] = 1.0;
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << Util::parametersToString (msg);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg.resize (src->getDomainSize(), 1.0);
|
||||||
|
}
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << Util::parametersToString (msg);
|
||||||
|
}
|
||||||
|
CLinkSet links = varsI_[src->getIndex()]->getLinks();
|
||||||
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
|
if (links[i]->getFactor() != dest) {
|
||||||
|
CParamSet msgFromFactor = links[i]->getMessage();
|
||||||
|
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
|
||||||
|
msg[j] *= msgFromFactor[j];
|
||||||
|
}
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << " x " << Util::parametersToString (msgFromFactor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (DL >= 5) {
|
||||||
|
cout << " = " << Util::parametersToString (msg);
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
#ifndef BP_SPSOLVER_H
|
#ifndef BP_SP_SOLVER_H
|
||||||
#define BP_SPSOLVER_H
|
#define BP_SP_SOLVER_H
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <set>
|
||||||
|
|
||||||
#include "Solver.h"
|
#include "Solver.h"
|
||||||
#include "FgVarNode.h"
|
#include "FgVarNode.h"
|
||||||
@ -15,157 +13,118 @@ using namespace std;
|
|||||||
class FactorGraph;
|
class FactorGraph;
|
||||||
class SPSolver;
|
class SPSolver;
|
||||||
|
|
||||||
struct Link
|
|
||||||
{
|
|
||||||
Link (Factor* s, FgVarNode* d)
|
|
||||||
{
|
|
||||||
source = s;
|
|
||||||
destination = d;
|
|
||||||
}
|
|
||||||
string toString (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << source->getLabel() << " --> " ;
|
|
||||||
ss << destination->getLabel();
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
Factor* source;
|
|
||||||
FgVarNode* destination;
|
|
||||||
static SPSolver* klass;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
class Link
|
||||||
|
|
||||||
class MessageBanket
|
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
MessageBanket (const FgVarNode* var)
|
Link (Factor* f, FgVarNode* v)
|
||||||
|
{
|
||||||
|
factor_ = f;
|
||||||
|
var_ = v;
|
||||||
|
currMsg_.resize (v->getDomainSize(), 1);
|
||||||
|
nextMsg_.resize (v->getDomainSize(), 1);
|
||||||
|
msgSended_ = false;
|
||||||
|
residual_ = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setMessage (ParamSet msg)
|
||||||
{
|
{
|
||||||
vector<Factor*> sources = var->getFactors();
|
Util::normalize (msg);
|
||||||
for (unsigned i = 0; i < sources.size(); i++) {
|
residual_ = Util::getMaxNorm (currMsg_, msg);
|
||||||
indexMap_.insert (make_pair (sources[i], i));
|
currMsg_ = msg;
|
||||||
currMsgs_.push_back (Message(var->getDomainSize(), 1));
|
|
||||||
nextMsgs_.push_back (Message(var->getDomainSize(), -10));
|
|
||||||
residuals_.push_back (0.0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void updateMessage (const Factor* source)
|
void setNextMessage (CParamSet msg)
|
||||||
{
|
{
|
||||||
unsigned idx = getIndex(source);
|
nextMsg_ = msg;
|
||||||
currMsgs_[idx] = nextMsgs_[idx];
|
Util::normalize (nextMsg_);
|
||||||
|
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void setNextMessage (const Factor* source, const Message& msg)
|
void updateMessage (void)
|
||||||
{
|
{
|
||||||
unsigned idx = getIndex(source);
|
currMsg_ = nextMsg_;
|
||||||
nextMsgs_[idx] = msg;
|
msgSended_ = true;
|
||||||
residuals_[idx] = computeResidual (source);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const Message& getMessage (const Factor* source) const
|
string toString (void) const
|
||||||
{
|
{
|
||||||
return currMsgs_[getIndex(source)];
|
stringstream ss;
|
||||||
}
|
ss << factor_->getLabel();
|
||||||
|
ss << " -- " ;
|
||||||
double getResidual (const Factor* source) const
|
ss << var_->getLabel();
|
||||||
{
|
return ss.str();
|
||||||
return residuals_[getIndex(source)];
|
|
||||||
}
|
|
||||||
|
|
||||||
void resetResidual (const Factor* source)
|
|
||||||
{
|
|
||||||
residuals_[getIndex(source)] = 0.0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Factor* getFactor (void) const { return factor_; }
|
||||||
|
FgVarNode* getVariable (void) const { return var_; }
|
||||||
|
CParamSet getMessage (void) const { return currMsg_; }
|
||||||
|
bool messageWasSended (void) const { return msgSended_; }
|
||||||
|
double getResidual (void) const { return residual_; }
|
||||||
|
void clearResidual (void) { residual_ = 0.0; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
double computeResidual (const Factor* source)
|
Factor* factor_;
|
||||||
{
|
FgVarNode* var_;
|
||||||
double change = 0.0;
|
ParamSet currMsg_;
|
||||||
unsigned idx = getIndex (source);
|
ParamSet nextMsg_;
|
||||||
const Message& currMessage = currMsgs_[idx];
|
bool msgSended_;
|
||||||
const Message& nextMessage = nextMsgs_[idx];
|
double residual_;
|
||||||
for (unsigned i = 0; i < currMessage.size(); i++) {
|
|
||||||
change += abs (currMessage[i] - nextMessage[i]);
|
|
||||||
}
|
|
||||||
return change;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned getIndex (const Factor* factor) const
|
|
||||||
{
|
|
||||||
assert (factor);
|
|
||||||
assert (indexMap_.find(factor) != indexMap_.end());
|
|
||||||
return indexMap_.find(factor)->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef map<const Factor*, unsigned> IndexMap;
|
|
||||||
|
|
||||||
IndexMap indexMap_;
|
|
||||||
vector<Message> currMsgs_;
|
|
||||||
vector<Message> nextMsgs_;
|
|
||||||
vector<double> residuals_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class SPNodeInfo
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
void addLink (Link* link) { links_.push_back (link); }
|
||||||
|
CLinkSet getLinks (void) { return links_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
LinkSet links_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class SPSolver : public Solver
|
class SPSolver : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
SPSolver (const FactorGraph&);
|
SPSolver (FactorGraph&);
|
||||||
~SPSolver (void);
|
virtual ~SPSolver (void);
|
||||||
|
|
||||||
void runSolver (void);
|
void runSolver (void);
|
||||||
ParamSet getPosterioriOf (const Variable* var) const;
|
virtual ParamSet getPosterioriOf (Vid) const;
|
||||||
|
ParamSet getJointDistributionOf (CVidSet);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual void initializeSolver (void);
|
||||||
|
void runTreeSolver (void);
|
||||||
|
bool readyToSendMessage (const Link*) const;
|
||||||
|
virtual void createLinks (void);
|
||||||
|
virtual void deleteJunction (Factor*, FgVarNode*);
|
||||||
|
bool converged (void);
|
||||||
|
virtual void maxResidualSchedule (void);
|
||||||
|
virtual ParamSet getFactor2VarMsg (const Link*) const;
|
||||||
|
virtual ParamSet getVar2FactorMsg (const Link*) const;
|
||||||
|
|
||||||
private:
|
struct CompareResidual {
|
||||||
bool converged (void);
|
inline bool operator() (const Link* link1, const Link* link2)
|
||||||
void maxResidualSchedule (void);
|
{
|
||||||
void updateMessage (const Link&);
|
return link1->getResidual() > link2->getResidual();
|
||||||
void updateMessage (const Factor*, const FgVarNode*);
|
}
|
||||||
void calculateNextMessage (const Link&);
|
};
|
||||||
void calculateNextMessage (const Factor*, const FgVarNode*);
|
|
||||||
void calculateVarFactorMessage (
|
FactorGraph* fg_;
|
||||||
const FgVarNode*, const Factor*, Message&) const;
|
LinkSet links_;
|
||||||
double getResidual (const Link&) const;
|
vector<SPNodeInfo*> varsI_;
|
||||||
void resetResidual (const Link&) const;
|
vector<SPNodeInfo*> factorsI_;
|
||||||
friend bool compareResidual (const Link&, const Link&);
|
unsigned nIter_;
|
||||||
|
|
||||||
|
typedef multiset<Link*, CompareResidual> SortedOrder;
|
||||||
|
SortedOrder sortedOrder_;
|
||||||
|
|
||||||
|
typedef map<Link*, SortedOrder::iterator> LinkMap;
|
||||||
|
LinkMap linkMap_;
|
||||||
|
|
||||||
const FactorGraph* fg_;
|
|
||||||
vector<MessageBanket*> msgs_;
|
|
||||||
Schedule schedule_;
|
|
||||||
int nIter_;
|
|
||||||
double accuracy_;
|
|
||||||
int maxIter_;
|
|
||||||
vector<Link> updateOrder_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#endif // BP_SP_SOLVER_H
|
||||||
|
|
||||||
inline double
|
|
||||||
SPSolver::getResidual (const Link& link) const
|
|
||||||
{
|
|
||||||
MessageBanket* mb = Link::klass->msgs_[link.destination->getIndex()];
|
|
||||||
return mb->getResidual (link.source);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
SPSolver::resetResidual (const Link& link) const
|
|
||||||
{
|
|
||||||
MessageBanket* mb = Link::klass->msgs_[link.destination->getIndex()];
|
|
||||||
mb->resetResidual (link.source);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline bool
|
|
||||||
compareResidual (const Link& link1, const Link& link2)
|
|
||||||
{
|
|
||||||
MessageBanket* mb1 = Link::klass->msgs_[link1.destination->getIndex()];
|
|
||||||
MessageBanket* mb2 = Link::klass->msgs_[link2.destination->getIndex()];
|
|
||||||
return mb1->getResidual(link1.source) > mb2->getResidual(link2.source);
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
@ -2,14 +2,15 @@
|
|||||||
#define BP_SHARED_H
|
#define BP_SHARED_H
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <iostream>
|
|
||||||
#include <fstream>
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
// Macro to disallow the copy constructor and operator= functions
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||||
TypeName(const TypeName&); \
|
TypeName(const TypeName&); \
|
||||||
void operator=(const TypeName&)
|
void operator=(const TypeName&)
|
||||||
@ -19,61 +20,162 @@ using namespace std;
|
|||||||
class Variable;
|
class Variable;
|
||||||
class BayesNode;
|
class BayesNode;
|
||||||
class FgVarNode;
|
class FgVarNode;
|
||||||
|
class Factor;
|
||||||
|
class Link;
|
||||||
|
class Edge;
|
||||||
|
|
||||||
typedef double Param;
|
typedef double Param;
|
||||||
typedef vector<Param> ParamSet;
|
typedef vector<Param> ParamSet;
|
||||||
typedef vector<Param> Message;
|
typedef const ParamSet& CParamSet;
|
||||||
|
typedef unsigned Vid;
|
||||||
|
typedef vector<Vid> VidSet;
|
||||||
|
typedef const VidSet& CVidSet;
|
||||||
typedef vector<Variable*> VarSet;
|
typedef vector<Variable*> VarSet;
|
||||||
typedef vector<BayesNode*> NodeSet;
|
typedef vector<BayesNode*> BnNodeSet;
|
||||||
|
typedef const BnNodeSet& CBnNodeSet;
|
||||||
typedef vector<FgVarNode*> FgVarSet;
|
typedef vector<FgVarNode*> FgVarSet;
|
||||||
|
typedef const FgVarSet& CFgVarSet;
|
||||||
|
typedef vector<Factor*> FactorSet;
|
||||||
|
typedef const FactorSet& CFactorSet;
|
||||||
|
typedef vector<Link*> LinkSet;
|
||||||
|
typedef const LinkSet& CLinkSet;
|
||||||
|
typedef vector<Edge*> EdgeSet;
|
||||||
|
typedef const EdgeSet& CEdgeSet;
|
||||||
typedef vector<string> Domain;
|
typedef vector<string> Domain;
|
||||||
typedef vector<unsigned> DomainConf;
|
typedef vector<unsigned> DConf;
|
||||||
typedef pair<unsigned, unsigned> DomainConstr;
|
typedef pair<unsigned, unsigned> DConstraint;
|
||||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
typedef map<unsigned, unsigned> IndexMap;
|
||||||
|
|
||||||
|
// level of debug information
|
||||||
//extern unsigned DL;
|
|
||||||
static const unsigned DL = 0;
|
static const unsigned DL = 0;
|
||||||
|
|
||||||
// number of digits to show when printing a parameter
|
static const int NO_EVIDENCE = -1;
|
||||||
static const unsigned PRECISION = 10;
|
|
||||||
|
|
||||||
// shared by bp and sp solver
|
// number of digits to show when printing a parameter
|
||||||
enum Schedule
|
static const unsigned PRECISION = 5;
|
||||||
|
|
||||||
|
static const bool EXPORT_TO_DOT = false;
|
||||||
|
static const unsigned EXPORT_MIN_SIZE = 30;
|
||||||
|
|
||||||
|
|
||||||
|
namespace SolverOptions
|
||||||
{
|
{
|
||||||
S_SEQ_FIXED,
|
enum Schedule
|
||||||
S_SEQ_RANDOM,
|
{
|
||||||
S_PARALLEL,
|
S_SEQ_FIXED,
|
||||||
S_MAX_RESIDUAL
|
S_SEQ_RANDOM,
|
||||||
|
S_PARALLEL,
|
||||||
|
S_MAX_RESIDUAL
|
||||||
|
};
|
||||||
|
extern bool runBayesBall;
|
||||||
|
extern bool convertBn2Fg;
|
||||||
|
extern bool compressFactorGraph;
|
||||||
|
extern Schedule schedule;
|
||||||
|
extern double accuracy;
|
||||||
|
extern unsigned maxIter;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace Util
|
||||||
|
{
|
||||||
|
void normalize (ParamSet&);
|
||||||
|
void pow (ParamSet&, unsigned);
|
||||||
|
double getL1dist (CParamSet, CParamSet);
|
||||||
|
double getMaxNorm (CParamSet, CParamSet);
|
||||||
|
bool isInteger (const string&);
|
||||||
|
string parametersToString (CParamSet);
|
||||||
|
vector<DConf> getDomainConfigurations (const VarSet&);
|
||||||
|
vector<string> getInstantiations (const VarSet&);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
struct NetInfo
|
struct NetInfo
|
||||||
{
|
{
|
||||||
NetInfo (unsigned c, double t)
|
NetInfo (void)
|
||||||
{
|
{
|
||||||
counting = c;
|
counting = 0;
|
||||||
solvingTime = t;
|
nIters = 0;
|
||||||
|
solvingTime = 0.0;
|
||||||
}
|
}
|
||||||
unsigned counting;
|
unsigned counting;
|
||||||
double solvingTime;
|
double solvingTime;
|
||||||
|
unsigned nIters;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct CompressInfo
|
||||||
|
{
|
||||||
|
CompressInfo (unsigned a, unsigned b, unsigned c,
|
||||||
|
unsigned d, unsigned e) {
|
||||||
|
nUncVars = a;
|
||||||
|
nUncFactors = b;
|
||||||
|
nCompVars = c;
|
||||||
|
nCompFactors = d;
|
||||||
|
nNeighborlessVars = e;
|
||||||
|
}
|
||||||
|
unsigned nUncVars;
|
||||||
|
unsigned nUncFactors;
|
||||||
|
unsigned nCompVars;
|
||||||
|
unsigned nCompFactors;
|
||||||
|
unsigned nNeighborlessVars;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
typedef map<unsigned, NetInfo> StatisticMap;
|
typedef map<unsigned, NetInfo> StatisticMap;
|
||||||
|
|
||||||
|
|
||||||
class Statistics
|
class Statistics
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
static void updateStats (unsigned size, double time)
|
static void updateStats (unsigned size, unsigned nIters, double time)
|
||||||
{
|
{
|
||||||
StatisticMap::iterator it = stats_.find(size);
|
StatisticMap::iterator it = stats_.find (size);
|
||||||
if (it == stats_.end()) {
|
if (it == stats_.end()) {
|
||||||
stats_.insert (make_pair (size, NetInfo (1, 0.0)));
|
it = (stats_.insert (make_pair (size, NetInfo()))).first;
|
||||||
} else {
|
} else {
|
||||||
it->second.counting ++;
|
it->second.counting ++;
|
||||||
|
it->second.nIters += nIters;
|
||||||
it->second.solvingTime += time;
|
it->second.solvingTime += time;
|
||||||
|
totalOfIterations += nIters;
|
||||||
|
if (nIters > maxIterations) {
|
||||||
|
maxIterations = nIters;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void updateCompressingStats (unsigned nUncVars,
|
||||||
|
unsigned nUncFactors,
|
||||||
|
unsigned nCompVars,
|
||||||
|
unsigned nCompFactors,
|
||||||
|
unsigned nNeighborlessVars) {
|
||||||
|
compressInfo_.push_back (CompressInfo (
|
||||||
|
nUncVars, nUncFactors, nCompVars, nCompFactors, nNeighborlessVars));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printCompressingStats (const char* fileName)
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "error: cannot open file to write at " ;
|
||||||
|
cerr << "BayesNet::printCompressingStats()" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
out << "--------------------------------------" ;
|
||||||
|
out << "--------------------------------------" << endl;
|
||||||
|
out << " Compression Stats" << endl;
|
||||||
|
out << "--------------------------------------" ;
|
||||||
|
out << "--------------------------------------" << endl;
|
||||||
|
out << left;
|
||||||
|
out << "Uncompress Compressed Uncompress Compressed Neighborless";
|
||||||
|
out << endl;
|
||||||
|
out << "Vars Vars Factors Factors Vars" ;
|
||||||
|
out << endl;
|
||||||
|
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
||||||
|
out << setw (13) << compressInfo_[i].nUncVars;
|
||||||
|
out << setw (13) << compressInfo_[i].nCompVars;
|
||||||
|
out << setw (13) << compressInfo_[i].nUncFactors;
|
||||||
|
out << setw (13) << compressInfo_[i].nCompFactors;
|
||||||
|
out << setw (13) << compressInfo_[i].nNeighborlessVars;
|
||||||
|
out << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,20 +186,12 @@ class Statistics
|
|||||||
return it->second.counting;
|
return it->second.counting;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void updateIterations (unsigned nIters)
|
|
||||||
{
|
|
||||||
totalOfIterations += nIters;
|
|
||||||
if (nIters > maxIterations) {
|
|
||||||
maxIterations = nIters;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void writeStats (void)
|
static void writeStats (void)
|
||||||
{
|
{
|
||||||
ofstream out ("../../stats.txt");
|
ofstream out ("../../stats.txt");
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "Statistics:::updateStats()" << endl;
|
cerr << "Statistics::updateStats()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
unsigned avgIterations = 0;
|
unsigned avgIterations = 0;
|
||||||
@ -117,17 +211,24 @@ class Statistics
|
|||||||
out << " average iterations: " << avgIterations << endl;
|
out << " average iterations: " << avgIterations << endl;
|
||||||
out << "total solving time " << totalSolvingTime << endl;
|
out << "total solving time " << totalSolvingTime << endl;
|
||||||
out << endl;
|
out << endl;
|
||||||
out << "Network Size\tCounting\tSolving Time\tAverage Time" << endl;
|
out << left << endl;
|
||||||
|
out << setw (15) << "Network Size" ;
|
||||||
|
out << setw (15) << "Counting" ;
|
||||||
|
out << setw (15) << "Solving Time" ;
|
||||||
|
out << setw (15) << "Average Time" ;
|
||||||
|
out << setw (15) << "#Iterations" ;
|
||||||
|
out << endl;
|
||||||
for (StatisticMap::iterator it = stats_.begin();
|
for (StatisticMap::iterator it = stats_.begin();
|
||||||
it != stats_.end(); it++) {
|
it != stats_.end(); it++) {
|
||||||
out << it->first;
|
out << setw (15) << it->first;
|
||||||
out << "\t\t" << it->second.counting;
|
out << setw (15) << it->second.counting;
|
||||||
out << "\t\t" << it->second.solvingTime;
|
out << setw (15) << it->second.solvingTime;
|
||||||
if (it->second.counting > 0) {
|
if (it->second.counting > 0) {
|
||||||
out << "\t\t" << it->second.solvingTime / it->second.counting;
|
out << setw (15) << it->second.solvingTime / it->second.counting;
|
||||||
} else {
|
} else {
|
||||||
out << "\t\t0.0" ;
|
out << setw (15) << "0.0" ;
|
||||||
}
|
}
|
||||||
|
out << setw (15) << it->second.nIters;
|
||||||
out << endl;
|
out << endl;
|
||||||
}
|
}
|
||||||
out.close();
|
out.close();
|
||||||
@ -142,62 +243,8 @@ class Statistics
|
|||||||
static StatisticMap stats_;
|
static StatisticMap stats_;
|
||||||
static unsigned maxIterations;
|
static unsigned maxIterations;
|
||||||
static unsigned totalOfIterations;
|
static unsigned totalOfIterations;
|
||||||
|
static vector<CompressInfo> compressInfo_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#endif //BP_SHARED_H
|
||||||
|
|
||||||
class Util
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
static void normalize (ParamSet& v)
|
|
||||||
{
|
|
||||||
double sum = 0.0;
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
sum += v[i];
|
|
||||||
}
|
|
||||||
assert (sum != 0.0);
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] /= sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static double getL1dist (const ParamSet& v1, const ParamSet& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
double dist = 0.0;
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
dist += abs (v1[i] - v2[i]);
|
|
||||||
}
|
|
||||||
return dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
static double getMaxNorm (const ParamSet& v1, const ParamSet& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
double max = 0.0;
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
double diff = abs (v1[i] - v2[i]);
|
|
||||||
if (diff > max) {
|
|
||||||
max = diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return max;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool isInteger (const string& s)
|
|
||||||
{
|
|
||||||
stringstream ss1 (s);
|
|
||||||
stringstream ss2;
|
|
||||||
int integer;
|
|
||||||
ss1 >> integer;
|
|
||||||
ss2 << integer;
|
|
||||||
return (ss1.str() == ss2.str());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
//unsigned Statistics::totalOfIterations = 0;
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
@ -15,19 +15,30 @@ class Solver
|
|||||||
{
|
{
|
||||||
gm_ = gm;
|
gm_ = gm;
|
||||||
}
|
}
|
||||||
|
virtual ~Solver() {} // to call subclass destructor
|
||||||
virtual void runSolver (void) = 0;
|
virtual void runSolver (void) = 0;
|
||||||
virtual ParamSet getPosterioriOf (const Variable*) const = 0;
|
virtual ParamSet getPosterioriOf (Vid) const = 0;
|
||||||
|
virtual ParamSet getJointDistributionOf (const VidSet&) = 0;
|
||||||
|
|
||||||
void printPosterioriOf (const Variable* var) const
|
void printAllPosterioris (void) const
|
||||||
{
|
{
|
||||||
|
VarSet vars = gm_->getVariables();
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
printPosterioriOf (vars[i]->getVarId());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void printPosterioriOf (Vid vid) const
|
||||||
|
{
|
||||||
|
Variable* var = gm_->getVariable (vid);
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << setw (20) << left << var->getLabel() << "posteriori" ;
|
cout << setw (20) << left << var->getLabel() << "posteriori" ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << "------------------------------" ;
|
cout << "------------------------------" ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
const Domain& domain = var->getDomain();
|
const Domain& domain = var->getDomain();
|
||||||
ParamSet results = getPosterioriOf (var);
|
ParamSet results = getPosterioriOf (vid);
|
||||||
for (int xi = 0; xi < var->getDomainSize(); xi++) {
|
for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
|
||||||
cout << setw (20) << domain[xi];
|
cout << setw (20) << domain[xi];
|
||||||
cout << setprecision (PRECISION) << results[xi];
|
cout << setprecision (PRECISION) << results[xi];
|
||||||
cout << endl;
|
cout << endl;
|
||||||
@ -35,16 +46,35 @@ class Solver
|
|||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void printAllPosterioris (void) const
|
void printJointDistributionOf (const VidSet& vids)
|
||||||
{
|
{
|
||||||
VarSet vars = gm_->getVariables();
|
const ParamSet& jointDist = getJointDistributionOf (vids);
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
cout << endl;
|
||||||
printPosterioriOf (vars[i]);
|
cout << "joint distribution of " ;
|
||||||
|
VarSet vars;
|
||||||
|
for (unsigned i = 0; i < vids.size() - 1; i++) {
|
||||||
|
Variable* var = gm_->getVariable (vids[i]);
|
||||||
|
cout << var->getLabel() << ", " ;
|
||||||
|
vars.push_back (var);
|
||||||
}
|
}
|
||||||
|
Variable* var = gm_->getVariable (vids[vids.size() - 1]);
|
||||||
|
cout << var->getLabel() ;
|
||||||
|
vars.push_back (var);
|
||||||
|
cout << endl;
|
||||||
|
cout << "------------------------------" ;
|
||||||
|
cout << endl;
|
||||||
|
const vector<string>& domainConfs = Util::getInstantiations (vars);
|
||||||
|
for (unsigned i = 0; i < jointDist.size(); i++) {
|
||||||
|
cout << left << setw (20) << domainConfs[i];
|
||||||
|
cout << setprecision (PRECISION) << jointDist[i];
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const GraphicalModel* gm_;
|
const GraphicalModel* gm_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif //BP_SOLVER_H
|
||||||
|
|
||||||
|
191
packages/CLPBN/clpbn/bp/Util.cpp
Normal file
191
packages/CLPBN/clpbn/bp/Util.cpp
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "Variable.h"
|
||||||
|
#include "Shared.h"
|
||||||
|
|
||||||
|
namespace SolverOptions {
|
||||||
|
|
||||||
|
bool runBayesBall = false;
|
||||||
|
bool convertBn2Fg = true;
|
||||||
|
bool compressFactorGraph = true;
|
||||||
|
Schedule schedule = S_SEQ_FIXED;
|
||||||
|
//Schedule schedule = S_SEQ_RANDOM;
|
||||||
|
//Schedule schedule = S_PARALLEL;
|
||||||
|
//Schedule schedule = S_MAX_RESIDUAL;
|
||||||
|
double accuracy = 0.0001;
|
||||||
|
unsigned maxIter = 1000; //FIXME
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned Statistics::numCreatedNets = 0;
|
||||||
|
unsigned Statistics::numSolvedPolyTrees = 0;
|
||||||
|
unsigned Statistics::numSolvedLoopyNets = 0;
|
||||||
|
unsigned Statistics::numUnconvergedRuns = 0;
|
||||||
|
unsigned Statistics::maxIterations = 0;
|
||||||
|
unsigned Statistics::totalOfIterations = 0;
|
||||||
|
vector<CompressInfo> Statistics::compressInfo_;
|
||||||
|
StatisticMap Statistics::stats_;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace Util {
|
||||||
|
|
||||||
|
void
|
||||||
|
normalize (ParamSet& v)
|
||||||
|
{
|
||||||
|
double sum = 0.0;
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
sum += v[i];
|
||||||
|
}
|
||||||
|
assert (sum != 0.0);
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
pow (ParamSet& v, unsigned expoent)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
double value = 1;
|
||||||
|
for (unsigned j = 0; j < expoent; j++) {
|
||||||
|
value *= v[i];
|
||||||
|
}
|
||||||
|
v[i] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
getL1dist (const ParamSet& v1, const ParamSet& v2)
|
||||||
|
{
|
||||||
|
assert (v1.size() == v2.size());
|
||||||
|
double dist = 0.0;
|
||||||
|
for (unsigned i = 0; i < v1.size(); i++) {
|
||||||
|
dist += abs (v1[i] - v2[i]);
|
||||||
|
}
|
||||||
|
return dist;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
getMaxNorm (const ParamSet& v1, const ParamSet& v2)
|
||||||
|
{
|
||||||
|
assert (v1.size() == v2.size());
|
||||||
|
double max = 0.0;
|
||||||
|
for (unsigned i = 0; i < v1.size(); i++) {
|
||||||
|
double diff = abs (v1[i] - v2[i]);
|
||||||
|
if (diff > max) {
|
||||||
|
max = diff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
isInteger (const string& s)
|
||||||
|
{
|
||||||
|
stringstream ss1 (s);
|
||||||
|
stringstream ss2;
|
||||||
|
int integer;
|
||||||
|
ss1 >> integer;
|
||||||
|
ss2 << integer;
|
||||||
|
return (ss1.str() == ss2.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
parametersToString (CParamSet v)
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "[" ;
|
||||||
|
for (unsigned i = 0; i < v.size() - 1; i++) {
|
||||||
|
ss << v[i] << ", " ;
|
||||||
|
}
|
||||||
|
if (v.size() != 0) {
|
||||||
|
ss << v[v.size() - 1];
|
||||||
|
}
|
||||||
|
ss << "]" ;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<DConf>
|
||||||
|
getDomainConfigurations (const VarSet& vars)
|
||||||
|
{
|
||||||
|
unsigned nConfs = 1;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
nConfs *= vars[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<DConf> confs (nConfs);
|
||||||
|
for (unsigned i = 0; i < nConfs; i++) {
|
||||||
|
confs[i].resize (vars.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned nReps = 1;
|
||||||
|
for (int i = vars.size() - 1; i >= 0; i--) {
|
||||||
|
unsigned index = 0;
|
||||||
|
while (index < nConfs) {
|
||||||
|
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
|
||||||
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
|
confs[index][i] = j;
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nReps *= vars[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
return confs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
vector<string>
|
||||||
|
getInstantiations (const VarSet& vars)
|
||||||
|
{
|
||||||
|
//FIXME handle variables without domain
|
||||||
|
/*
|
||||||
|
char c = 'a' ;
|
||||||
|
const DConf& conf = entries[i].getDomainConfiguration();
|
||||||
|
for (unsigned j = 0; j < conf.size(); j++) {
|
||||||
|
if (j != 0) ss << "," ;
|
||||||
|
ss << c << conf[j] + 1;
|
||||||
|
c ++;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
unsigned rowSize = 1;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
rowSize *= vars[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<string> headers (rowSize);
|
||||||
|
|
||||||
|
unsigned nReps = 1;
|
||||||
|
for (int i = vars.size() - 1; i >= 0; i--) {
|
||||||
|
Domain domain = vars[i]->getDomain();
|
||||||
|
unsigned index = 0;
|
||||||
|
while (index < rowSize) {
|
||||||
|
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
|
||||||
|
for (unsigned r = 0; r < nReps; r++) {
|
||||||
|
if (headers[index] != "") {
|
||||||
|
headers[index] = domain[j] + ", " + headers[index];
|
||||||
|
} else {
|
||||||
|
headers[index] = domain[j];
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nReps *= vars[i]->getDomainSize();
|
||||||
|
}
|
||||||
|
return headers;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -1,9 +1,10 @@
|
|||||||
#ifndef BP_GENERIC_VARIABLE_H
|
#ifndef BP_VARIABLE_H
|
||||||
#define BP_GENERIC_VARIABLE_H
|
#define BP_VARIABLE_H
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include "Shared.h"
|
#include "Shared.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -12,33 +13,61 @@ class Variable
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
Variable (unsigned varId)
|
Variable (const Variable* v)
|
||||||
{
|
{
|
||||||
this->varId_ = varId;
|
vid_ = v->getVarId();
|
||||||
this->dsize_ = 0;
|
dsize_ = v->getDomainSize();
|
||||||
this->evidence_ = -1;
|
if (v->hasDomain()) {
|
||||||
this->label_ = 0;
|
domain_ = v->getDomain();
|
||||||
|
dsize_ = domain_.size();
|
||||||
|
} else {
|
||||||
|
dsize_ = v->getDomainSize();
|
||||||
|
}
|
||||||
|
evidence_ = v->getEvidence();
|
||||||
|
if (v->hasLabel()) {
|
||||||
|
label_ = new string (v->getLabel());
|
||||||
|
} else {
|
||||||
|
label_ = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Variable (unsigned varId, unsigned dsize, int evidence = -1)
|
Variable (Vid vid)
|
||||||
|
{
|
||||||
|
this->vid_ = vid;
|
||||||
|
this->dsize_ = 0;
|
||||||
|
this->evidence_ = NO_EVIDENCE;
|
||||||
|
this->label_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Variable (Vid vid, unsigned dsize, int evidence = NO_EVIDENCE,
|
||||||
|
const string& lbl = string())
|
||||||
{
|
{
|
||||||
assert (dsize != 0);
|
assert (dsize != 0);
|
||||||
assert (evidence < (int)dsize);
|
assert (evidence < (int)dsize);
|
||||||
this->varId_ = varId;
|
this->vid_ = vid;
|
||||||
this->dsize_ = dsize;
|
this->dsize_ = dsize;
|
||||||
this->evidence_ = evidence;
|
this->evidence_ = evidence;
|
||||||
this->label_ = 0;
|
if (!lbl.empty()) {
|
||||||
|
this->label_ = new string (lbl);
|
||||||
|
} else {
|
||||||
|
this->label_ = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Variable (unsigned varId, const Domain& domain, int evidence = -1)
|
Variable (Vid vid, const Domain& domain, int evidence = NO_EVIDENCE,
|
||||||
|
const string& lbl = string())
|
||||||
{
|
{
|
||||||
assert (!domain.empty());
|
assert (!domain.empty());
|
||||||
assert (evidence < (int)domain.size());
|
assert (evidence < (int)domain.size());
|
||||||
this->varId_ = varId;
|
this->vid_ = vid;
|
||||||
this->dsize_ = domain.size();
|
this->dsize_ = domain.size();
|
||||||
this->domain_ = domain;
|
this->domain_ = domain;
|
||||||
this->evidence_ = evidence;
|
this->evidence_ = evidence;
|
||||||
this->label_ = 0;
|
if (!lbl.empty()) {
|
||||||
|
this->label_ = new string (lbl);
|
||||||
|
} else {
|
||||||
|
this->label_ = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
~Variable (void)
|
~Variable (void)
|
||||||
@ -46,19 +75,19 @@ class Variable
|
|||||||
delete label_;
|
delete label_;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned getVarId (void) const { return varId_; }
|
unsigned getVarId (void) const { return vid_; }
|
||||||
unsigned getIndex (void) const { return index_; }
|
unsigned getIndex (void) const { return index_; }
|
||||||
void setIndex (unsigned idx) { index_ = idx; }
|
void setIndex (unsigned idx) { index_ = idx; }
|
||||||
int getDomainSize (void) const { return dsize_; }
|
unsigned getDomainSize (void) const { return dsize_; }
|
||||||
bool hasEvidence (void) const { return evidence_ != -1; }
|
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
|
||||||
int getEvidence (void) const { return evidence_; }
|
int getEvidence (void) const { return evidence_; }
|
||||||
bool hasDomain (void) { return !domain_.empty(); }
|
bool hasDomain (void) const { return !domain_.empty(); }
|
||||||
bool hasLabel (void) { return label_ != 0; }
|
bool hasLabel (void) const { return label_ != 0; }
|
||||||
|
|
||||||
bool isValidStateIndex (int index)
|
bool isValidStateIndex (int index)
|
||||||
{
|
{
|
||||||
return index >= 0 && index < dsize_;
|
return index >= 0 && index < (int)dsize_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isValidState (const string& state)
|
bool isValidState (const string& state)
|
||||||
{
|
{
|
||||||
@ -70,7 +99,7 @@ class Variable
|
|||||||
assert (dsize_ != 0);
|
assert (dsize_ != 0);
|
||||||
if (domain_.size() == 0) {
|
if (domain_.size() == 0) {
|
||||||
Domain d;
|
Domain d;
|
||||||
for (int i = 0; i < dsize_; i++) {
|
for (unsigned i = 0; i < dsize_; i++) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "x" << i ;
|
ss << "x" << i ;
|
||||||
d.push_back (ss.str());
|
d.push_back (ss.str());
|
||||||
@ -110,7 +139,7 @@ class Variable
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void setLabel (string label)
|
void setLabel (const string& label)
|
||||||
{
|
{
|
||||||
label_ = new string (label);
|
label_ = new string (label);
|
||||||
}
|
}
|
||||||
@ -119,25 +148,25 @@ class Variable
|
|||||||
{
|
{
|
||||||
if (label_ == 0) {
|
if (label_ == 0) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "v" << varId_;
|
ss << "v" << vid_;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
} else {
|
} else {
|
||||||
return *label_;
|
return *label_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
|
||||||
unsigned varId_;
|
|
||||||
string* label_;
|
|
||||||
unsigned index_;
|
|
||||||
int evidence_;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (Variable);
|
DISALLOW_COPY_AND_ASSIGN (Variable);
|
||||||
Domain domain_;
|
|
||||||
int dsize_;
|
Vid vid_;
|
||||||
|
unsigned dsize_;
|
||||||
|
int evidence_;
|
||||||
|
Domain domain_;
|
||||||
|
string* label_;
|
||||||
|
unsigned index_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // BP_GENERIC_VARIABLE_H
|
#endif // BP_VARIABLE_H
|
||||||
|
|
||||||
|
34
packages/CLPBN/clpbn/bp/examples/1parentNchilds.yap
Executable file
34
packages/CLPBN/clpbn/bp/examples/1parentNchilds.yap
Executable file
@ -0,0 +1,34 @@
|
|||||||
|
|
||||||
|
:- use_module(library(clpbn)).
|
||||||
|
|
||||||
|
:- set_clpbn_flag(solver, bp).
|
||||||
|
|
||||||
|
%
|
||||||
|
% R
|
||||||
|
% / | \
|
||||||
|
% / | \
|
||||||
|
% A B C
|
||||||
|
%
|
||||||
|
|
||||||
|
|
||||||
|
r(R) :-
|
||||||
|
{ R = r with p([t, f], [0.35, 0.65]) }.
|
||||||
|
|
||||||
|
a(A) :-
|
||||||
|
r(R),
|
||||||
|
child_dist(R,Dist),
|
||||||
|
{ A = a with Dist }.
|
||||||
|
|
||||||
|
b(B) :-
|
||||||
|
r(R),
|
||||||
|
child_dist(R,Dist),
|
||||||
|
{ B = b with Dist }.
|
||||||
|
|
||||||
|
c(C) :-
|
||||||
|
r(R),
|
||||||
|
child_dist(R,Dist),
|
||||||
|
{ C = c with Dist }.
|
||||||
|
|
||||||
|
|
||||||
|
child_dist(R, p([t, f], [0.3, 0.4, 0.25, 0.05], [R])).
|
||||||
|
|
53
packages/CLPBN/clpbn/bp/examples/bp-example.xml
Executable file
53
packages/CLPBN/clpbn/bp/examples/bp-example.xml
Executable file
@ -0,0 +1,53 @@
|
|||||||
|
<?xml version="1.0" encoding="US-ASCII"?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
A B
|
||||||
|
\ /
|
||||||
|
\ /
|
||||||
|
C
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
<BIF VERSION="0.3">
|
||||||
|
<NETWORK>
|
||||||
|
<NAME>Neapolitan</NAME>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>A</NAME>
|
||||||
|
<OUTCOME>a1</OUTCOME>
|
||||||
|
<OUTCOME>a2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>B</NAME>
|
||||||
|
<OUTCOME>b1</OUTCOME>
|
||||||
|
<OUTCOME>b2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>C</NAME>
|
||||||
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>A</FOR>
|
||||||
|
<TABLE> .695 .305 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>B</FOR>
|
||||||
|
<TABLE> .25 .75 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>C</FOR>
|
||||||
|
<GIVEN>A</GIVEN>
|
||||||
|
<GIVEN>B</GIVEN>
|
||||||
|
<TABLE> .2 .8 .45 .55 .32 .68 .7 .3 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
</NETWORK>
|
||||||
|
</BIF>
|
||||||
|
|
@ -9,10 +9,10 @@ MARKOV
|
|||||||
2 4 2
|
2 4 2
|
||||||
|
|
||||||
2
|
2
|
||||||
.001 .009
|
.001 .999
|
||||||
|
|
||||||
2
|
2
|
||||||
.002 .008
|
.002 .998
|
||||||
|
|
||||||
8
|
8
|
||||||
.95 .94 .29 .001
|
.95 .94 .29 .001
|
||||||
|
@ -49,12 +49,12 @@
|
|||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>B</FOR>
|
<FOR>B</FOR>
|
||||||
<TABLE> .001 .009 </TABLE>
|
<TABLE> .001 .999 </TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>E</FOR>
|
<FOR>E</FOR>
|
||||||
<TABLE> .002 .008 </TABLE>
|
<TABLE> .002 .998 </TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
|
@ -1,54 +1,29 @@
|
|||||||
|
|
||||||
:- use_module(library(clpbn)).
|
:- use_module(library(clpbn)).
|
||||||
|
|
||||||
:- set_clpbn_flag(solver, vel).
|
:- set_clpbn_flag(solver, bp).
|
||||||
|
|
||||||
%
|
r(R) :- r_cpt(RCpt),
|
||||||
% B E
|
{ R = r with p([r1, r2], RCpt) }.
|
||||||
% \ /
|
|
||||||
% \ /
|
|
||||||
% A
|
|
||||||
% / \
|
|
||||||
% / \
|
|
||||||
% J M
|
|
||||||
%
|
|
||||||
|
|
||||||
|
t(T) :- t_cpt(TCpt),
|
||||||
|
{ T = t with p([t1, t2], TCpt) }.
|
||||||
|
|
||||||
b(B) :-
|
a(A) :- r(R), t(T), a_cpt(ACpt),
|
||||||
b_table(BDist),
|
{ A = a with p([a1, a2], ACpt, [R, T]) }.
|
||||||
{ B = b with p([b1, b2], BDist) }.
|
|
||||||
|
|
||||||
e(E) :-
|
j(J) :- a(A), j_cpt(JCpt),
|
||||||
e_table(EDist),
|
{ J = j with p([j1, j2], JCpt, [A]) }.
|
||||||
{ E = e with p([e1, e2], EDist) }.
|
|
||||||
|
|
||||||
a(A) :-
|
m(M) :- a(A), m_cpt(MCpt),
|
||||||
b(B),
|
{ M = m with p([m1, m2], MCpt, [A]) }.
|
||||||
e(E),
|
|
||||||
a_table(ADist),
|
|
||||||
{ A = a with p([a1, a2], ADist, [B, E]) }.
|
|
||||||
|
|
||||||
j(J):-
|
|
||||||
a(A),
|
|
||||||
j_table(JDist),
|
|
||||||
{ J = j with p([j1, j2], JDist, [A]) }.
|
|
||||||
|
|
||||||
m(M):-
|
|
||||||
a(A),
|
|
||||||
m_table(MDist),
|
|
||||||
{ M = m with p([m1, m2], MDist, [A]) }.
|
|
||||||
|
|
||||||
|
r_cpt([0.001, 0.999]).
|
||||||
b_table([0.001, 0.009]).
|
t_cpt([0.002, 0.998]).
|
||||||
|
a_cpt([0.95, 0.94, 0.29, 0.001,
|
||||||
e_table([0.002, 0.008]).
|
0.05, 0.06, 0.71, 0.999]).
|
||||||
|
j_cpt([0.9, 0.05,
|
||||||
a_table([0.95, 0.94, 0.29, 0.001,
|
0.1, 0.95]).
|
||||||
0.05, 0.06, 0.71, 0.999]).
|
m_cpt([0.7, 0.01,
|
||||||
|
0.3, 0.99]).
|
||||||
j_table([0.9, 0.05,
|
|
||||||
0.1, 0.95]).
|
|
||||||
|
|
||||||
m_table([0.7, 0.01,
|
|
||||||
0.3, 0.99]).
|
|
||||||
|
|
||||||
|
@ -16,34 +16,37 @@
|
|||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
<VARIABLE TYPE="nature">
|
||||||
<NAME>A</NAME>
|
<NAME>A</NAME>
|
||||||
<OUTCOME></OUTCOME>
|
<OUTCOME>a1</OUTCOME>
|
||||||
|
<OUTCOME>a2</OUTCOME>
|
||||||
</VARIABLE>
|
</VARIABLE>
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
<VARIABLE TYPE="nature">
|
||||||
<NAME>B</NAME>
|
<NAME>B</NAME>
|
||||||
<OUTCOME></OUTCOME>
|
<OUTCOME>b1</OUTCOME>
|
||||||
|
<OUTCOME>b2</OUTCOME>
|
||||||
</VARIABLE>
|
</VARIABLE>
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
<VARIABLE TYPE="nature">
|
||||||
<NAME>C</NAME>
|
<NAME>C</NAME>
|
||||||
<OUTCOME></OUTCOME>
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
</VARIABLE>
|
</VARIABLE>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>A</FOR>
|
<FOR>A</FOR>
|
||||||
<TABLE>1</TABLE>
|
<TABLE>.695 .305</TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>B</FOR>
|
<FOR>B</FOR>
|
||||||
<TABLE>1</TABLE>
|
<TABLE>0.25 0.75</TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
<DEFINITION>
|
<DEFINITION>
|
||||||
<FOR>C</FOR>
|
<FOR>C</FOR>
|
||||||
<GIVEN>A</GIVEN>
|
<GIVEN>A</GIVEN>
|
||||||
<GIVEN>B</GIVEN>
|
<GIVEN>B</GIVEN>
|
||||||
<TABLE>1</TABLE>
|
<TABLE>0.2 0.8 0.45 0.55 0.32 0.68 0.7 0.3</TABLE>
|
||||||
</DEFINITION>
|
</DEFINITION>
|
||||||
|
|
||||||
</NETWORK>
|
</NETWORK>
|
||||||
|
67
packages/CLPBN/clpbn/bp/examples/lambda fail.xml
Executable file
67
packages/CLPBN/clpbn/bp/examples/lambda fail.xml
Executable file
@ -0,0 +1,67 @@
|
|||||||
|
<?xml version="1.0" encoding="US-ASCII"?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
P1 P2 P3
|
||||||
|
\ | /
|
||||||
|
\ | /
|
||||||
|
-
|
||||||
|
C
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
<BIF VERSION="0.3">
|
||||||
|
<NETWORK>
|
||||||
|
|
||||||
|
<NAME>Simple Convergence</NAME>
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>P1</NAME>
|
||||||
|
<OUTCOME>p1</OUTCOME>
|
||||||
|
<OUTCOME>p2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>P2</NAME>
|
||||||
|
<OUTCOME>p1</OUTCOME>
|
||||||
|
<OUTCOME>p2</OUTCOME>
|
||||||
|
<OUTCOME>p3</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>P3</NAME>
|
||||||
|
<OUTCOME>p1</OUTCOME>
|
||||||
|
<OUTCOME>p2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>C</NAME>
|
||||||
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>P1</FOR>
|
||||||
|
<TABLE>.695 .305</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>P2</FOR>
|
||||||
|
<TABLE>0.2 0.3 0.5</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>P3</FOR>
|
||||||
|
<TABLE>0.25 0.75</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>C</FOR>
|
||||||
|
<GIVEN>P1</GIVEN>
|
||||||
|
<GIVEN>P2</GIVEN>
|
||||||
|
<GIVEN>P3</GIVEN>
|
||||||
|
<TABLE>0.2 0.8 0.45 0.55 0.32 0.68 0.7 0.3 0.3 0.7 0.55 0.45 0.22 0.78 0.25 0.75 0.11 0.89 0.34 0.66 0.1 0.9 0.6 0.4</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
</NETWORK>
|
||||||
|
</BIF>
|
||||||
|
|
@ -2,6 +2,7 @@
|
|||||||
:- use_module(library(clpbn)).
|
:- use_module(library(clpbn)).
|
||||||
|
|
||||||
:- set_clpbn_flag(solver, bp).
|
:- set_clpbn_flag(solver, bp).
|
||||||
|
%:- set_clpbn_flag(solver, jt).
|
||||||
|
|
||||||
%
|
%
|
||||||
% B F
|
% B F
|
||||||
|
17
packages/CLPBN/clpbn/bp/examples/sp-example.uai
Executable file
17
packages/CLPBN/clpbn/bp/examples/sp-example.uai
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
MARKOV
|
||||||
|
3
|
||||||
|
2 2 2
|
||||||
|
3
|
||||||
|
1 0
|
||||||
|
1 1
|
||||||
|
3 2 0 1
|
||||||
|
|
||||||
|
2
|
||||||
|
.695 .305
|
||||||
|
|
||||||
|
2
|
||||||
|
.25 .75
|
||||||
|
|
||||||
|
8
|
||||||
|
0.2 0.45 0.32 0.7
|
||||||
|
0.8 0.55 0.68 0.3
|
128
packages/CLPBN/clpbn/bp/examples/test_bn.xml
Executable file
128
packages/CLPBN/clpbn/bp/examples/test_bn.xml
Executable file
@ -0,0 +1,128 @@
|
|||||||
|
<?xml version="1.0" encoding="US-ASCII"?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
A B C
|
||||||
|
\ | /
|
||||||
|
\ | /
|
||||||
|
D
|
||||||
|
/ | \
|
||||||
|
/ | \
|
||||||
|
E F G
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
<BIF VERSION="0.3">
|
||||||
|
<NETWORK>
|
||||||
|
<NAME>Node with several parents and childs</NAME>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>A</NAME>
|
||||||
|
<OUTCOME>a1</OUTCOME>
|
||||||
|
<OUTCOME>a2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>B</NAME>
|
||||||
|
<OUTCOME>b1</OUTCOME>
|
||||||
|
<OUTCOME>b2</OUTCOME>
|
||||||
|
<OUTCOME>b3</OUTCOME>
|
||||||
|
<OUTCOME>b4</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>C</NAME>
|
||||||
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
|
<OUTCOME>c3</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>D</NAME>
|
||||||
|
<OUTCOME>d1</OUTCOME>
|
||||||
|
<OUTCOME>d2</OUTCOME>
|
||||||
|
<OUTCOME>d3</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>E</NAME>
|
||||||
|
<OUTCOME>e1</OUTCOME>
|
||||||
|
<OUTCOME>e2</OUTCOME>
|
||||||
|
<OUTCOME>e3</OUTCOME>
|
||||||
|
<OUTCOME>e4</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>F</NAME>
|
||||||
|
<OUTCOME>f1</OUTCOME>
|
||||||
|
<OUTCOME>f2</OUTCOME>
|
||||||
|
<OUTCOME>f3</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>G</NAME>
|
||||||
|
<OUTCOME>g1</OUTCOME>
|
||||||
|
<OUTCOME>g2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>A</FOR>
|
||||||
|
<TABLE> .1 .2 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>B</FOR>
|
||||||
|
<TABLE> .01 .02 .03 .04 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>C</FOR>
|
||||||
|
<TABLE> .11 .22 .33 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>D</FOR>
|
||||||
|
<GIVEN>A</GIVEN>
|
||||||
|
<GIVEN>B</GIVEN>
|
||||||
|
<GIVEN>C</GIVEN>
|
||||||
|
<TABLE>
|
||||||
|
.522 .008 .99 .01 .2 .8 .003 .457 .423 .007 .92 .04 .5 .232 .033 .227 .112 .048 .91 .21 .24 .18 .005 .227
|
||||||
|
.212 .04 .59 .21 .6 .1 .023 .215 .913 .017 .96 .01 .55 .422 .013 .417 .272 .068 .61 .11 .26 .28 .205 .322
|
||||||
|
.142 .028 .19 .11 .5 .67 .013 .437 .163 .067 .12 .06 .1 .262 .063 .167 .512 .028 .11 .41 .14 .68 .015 .92
|
||||||
|
</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>E</FOR>
|
||||||
|
<GIVEN>D</GIVEN>
|
||||||
|
<TABLE>
|
||||||
|
.111 .11 .1
|
||||||
|
.222 .22 .2
|
||||||
|
.333 .33 .3
|
||||||
|
.444 .44 .4
|
||||||
|
</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>F</FOR>
|
||||||
|
<GIVEN>D</GIVEN>
|
||||||
|
<TABLE>
|
||||||
|
.112 .111 .110
|
||||||
|
.223 .222 .221
|
||||||
|
.334 .333 .332
|
||||||
|
</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>G</FOR>
|
||||||
|
<GIVEN>D</GIVEN>
|
||||||
|
<TABLE>
|
||||||
|
.101 .102 .103
|
||||||
|
.201 .202 .203
|
||||||
|
</TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
</NETWORK>
|
||||||
|
</BIF>
|
||||||
|
|
36
packages/CLPBN/clpbn/bp/examples/test_mk.uai
Executable file
36
packages/CLPBN/clpbn/bp/examples/test_mk.uai
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
MARKOV
|
||||||
|
5
|
||||||
|
4 2 3 2 3
|
||||||
|
7
|
||||||
|
1 0
|
||||||
|
1 1
|
||||||
|
1 2
|
||||||
|
1 3
|
||||||
|
1 4
|
||||||
|
2 0 1
|
||||||
|
4 1 2 3 4
|
||||||
|
|
||||||
|
4
|
||||||
|
0.1 0.7 0.43 0.22
|
||||||
|
|
||||||
|
2
|
||||||
|
0.2 0.6
|
||||||
|
|
||||||
|
3
|
||||||
|
0.3 0.5 0.2
|
||||||
|
|
||||||
|
2
|
||||||
|
0.15 0.75
|
||||||
|
|
||||||
|
3
|
||||||
|
0.25 0.45 0.15
|
||||||
|
|
||||||
|
8
|
||||||
|
0.210 0.333 0.457 0.4
|
||||||
|
0.811 0.000 0.189 0.89
|
||||||
|
|
||||||
|
36
|
||||||
|
0.1 0.15 0.2 0.25 0.3 0.45 0.5 0.55 0.65 0.7 0.75 0.9
|
||||||
|
0.11 0.22 0.33 0.44 0.55 0.66 0.77 0.88 0.91 0.93 0.95 0.97
|
||||||
|
0.42 0.22 0.33 0.44 0.15 0.36 0.27 0.28 0.21 0.13 0.25 0.17
|
||||||
|
|
69
packages/CLPBN/clpbn/bp/examples/ve_example.xml
Executable file
69
packages/CLPBN/clpbn/bp/examples/ve_example.xml
Executable file
@ -0,0 +1,69 @@
|
|||||||
|
<?xml version="1.0" encoding="US-ASCII"?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
|
||||||
|
A B
|
||||||
|
\ /
|
||||||
|
\ /
|
||||||
|
C
|
||||||
|
|
|
||||||
|
|
|
||||||
|
D
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
|
||||||
|
<BIF VERSION="0.3">
|
||||||
|
<NETWORK>
|
||||||
|
<NAME>Simple Loop</NAME>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>A</NAME>
|
||||||
|
<OUTCOME>a1</OUTCOME>
|
||||||
|
<OUTCOME>a2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>B</NAME>
|
||||||
|
<OUTCOME>b1</OUTCOME>
|
||||||
|
<OUTCOME>b2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>C</NAME>
|
||||||
|
<OUTCOME>c1</OUTCOME>
|
||||||
|
<OUTCOME>c2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<VARIABLE TYPE="nature">
|
||||||
|
<NAME>D</NAME>
|
||||||
|
<OUTCOME>d1</OUTCOME>
|
||||||
|
<OUTCOME>d2</OUTCOME>
|
||||||
|
</VARIABLE>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>A</FOR>
|
||||||
|
<TABLE> .001 .009 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>B</FOR>
|
||||||
|
<TABLE> .002 .008 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>C</FOR>
|
||||||
|
<GIVEN>A</GIVEN>
|
||||||
|
<GIVEN>B</GIVEN>
|
||||||
|
<TABLE> .95 .05 .94 .06 .29 .71 .001 .999 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
<DEFINITION>
|
||||||
|
<FOR>D</FOR>
|
||||||
|
<GIVEN>C</GIVEN>
|
||||||
|
<TABLE> .9 .1 .05 .95 </TABLE>
|
||||||
|
</DEFINITION>
|
||||||
|
|
||||||
|
</NETWORK>
|
||||||
|
</BIF>
|
||||||
|
|
Reference in New Issue
Block a user