Conflicts:
	packages/CLPBN/clpbn/bp.yap
	packages/CLPBN/clpbn/fove.yap
	packages/CLPBN/clpbn/horus.yap
This commit is contained in:
Vítor Santos Costa 2012-04-03 15:07:22 +01:00
commit 6ccd458ea5
61 changed files with 81208 additions and 2903 deletions

View File

@ -1,9 +1,9 @@
/************************************************
/*******************************************************
Belief Propagation in CLP(BN)
Belief Propagation and Variable Elimination Interface
**************************************************/
********************************************************/
:- module(clpbn_bp,
[bp/3,
@ -23,18 +23,16 @@
:- use_module(library('clpbn/display'),
[clpbn_bind_vals/3]).
[clpbn_bind_vals/3]).
:- use_module(library('clpbn/aggregates'),
[check_for_agg_vars/2]).
[check_for_agg_vars/2]).
:- use_module(library(clpbn/horus)).
:- use_module(library(atts)).
:- use_module(library(lists)).
:- use_module(library(charsio)).
:- attribute id/1.
@ -51,42 +49,50 @@
:- set_horus_flag(accuracy, 0.0001).
:- set_horus_flag(max_iter, 1000).
:- use_module(library(charsio),
[term_to_atom/2]).
:- set_horus_flag(use_logarithms, false).
%:- set_horus_flag(use_logarithms, true).
:- set_horus_flag(order_factor_variables, false).
%:- set_horus_flag(order_factor_variables, true).
:- use_module(horus,
[create_ground_network/2,
set_bayes_net_params/2,
run_ground_solver/3,
set_extra_vars_info/2,
free_bayesian_network/1
]).
:- attribute id/1.
bp([[]],_,_) :- !.
bp([QueryVars], AllVars, Output) :-
init_bp_solver(_, AllVars, _, Network),
run_bp_solver([QueryVars], LPs, Network),
finalize_bp_solver(Network),
clpbn_bind_vals([QueryVars], LPs, Output).
init_bp_solver(_, AllVars, _, Network),
run_bp_solver([QueryVars], LPs, Network),
finalize_bp_solver(Network),
clpbn_bind_vals([QueryVars], LPs, Output).
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
check_for_agg_vars(AllVars0, AllVars),
writeln('clpbn_vars:'),
print_clpbn_vars(AllVars),
assign_ids(AllVars, 0),
get_vars_info(AllVars, VarsInfo, DistIds0),
sort(DistIds0, DistIds),
create_ground_network(VarsInfo, BayesNet).
%get_extra_vars_info(AllVars, ExtraVarsInfo),
%set_extra_vars_info(BayesNet, ExtraVarsInfo).
%writeln('init_bp_solver'),
check_for_agg_vars(AllVars0, AllVars),
%writeln('clpbn_vars:'), print_clpbn_vars(AllVars),
assign_ids(AllVars, 0),
get_vars_info(AllVars, VarsInfo, DistIds0),
sort(DistIds0, DistIds),
create_ground_network(VarsInfo, BayesNet),
%get_extra_vars_info(AllVars, ExtraVarsInfo),
%set_extra_vars_info(BayesNet, ExtraVarsInfo),
%writeln(extravarsinfo:ExtraVarsInfo),
true.
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
get_dists_parameters(DistIds, DistsParams),
%writeln('-> run_bp_solver'),
get_dists_parameters(DistIds, DistsParams),
set_bayes_net_params(Network, DistsParams),
flatten_1_element_sublists(QueryVars, QueryVars1),
vars_to_ids(QueryVars1, QueryVarsIds),
run_other_solvers(Network, QueryVarsIds, Solutions).
vars_to_ids(QueryVars, QueryVarsIds),
run_ground_solver(Network, QueryVarsIds, Solutions).
finalize_bp_solver(bp(Network, _)) :-
@ -95,82 +101,75 @@ finalize_bp_solver(bp(Network, _)) :-
assign_ids([], _).
assign_ids([V|Vs], Count) :-
put_atts(V, [id(Count)]),
Count1 is Count + 1,
assign_ids(Vs, Count1).
put_atts(V, [id(Count)]),
Count1 is Count + 1,
assign_ids(Vs, Count1).
get_vars_info([], [], []).
get_vars_info(V.Vs,
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
DistId.DistIds) :-
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
get_atts(V, [id(VarId)]),
get_dist_domain_size(DistId, DS),
get_evidence(V, Ev),
vars_to_ids(Parents, PIds),
get_vars_info(Vs, VarsInfo, DistIds).
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
DistId.DistIds) :-
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
get_atts(V, [id(VarId)]),
get_dist_domain_size(DistId, DS),
get_evidence(V, Ev),
vars_to_ids(Parents, PIds),
get_vars_info(Vs, VarsInfo, DistIds).
get_evidence(V, Ev) :-
clpbn:get_atts(V, [evidence(Ev)]), !.
clpbn:get_atts(V, [evidence(Ev)]), !.
get_evidence(_V, -1). % no evidence !!!
vars_to_ids([], []).
vars_to_ids([L|Vars], [LIds|Ids]) :-
is_list(L), !,
vars_to_ids(L, LIds),
vars_to_ids(Vars, Ids).
is_list(L), !,
vars_to_ids(L, LIds),
vars_to_ids(Vars, Ids).
vars_to_ids([V|Vars], [VarId|Ids]) :-
get_atts(V, [id(VarId)]),
vars_to_ids(Vars, Ids).
get_atts(V, [id(VarId)]),
vars_to_ids(Vars, Ids).
get_extra_vars_info([], []).
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
get_atts(V, [id(VarId)]), !,
clpbn:get_atts(V, [key(Key),dist(DistId, _)]),
term_to_atom(Key, Label),
get_dist_domain(DistId, Domain0),
numbers_to_atoms(Domain0, Domain),
get_extra_vars_info(Vs, VarsInfo).
get_atts(V, [id(VarId)]), !,
clpbn:get_atts(V, [key(Key), dist(DistId, _)]),
term_to_atom(Key, Label),
get_dist_domain(DistId, Domain0),
numbers_to_atoms(Domain0, Domain),
get_extra_vars_info(Vs, VarsInfo).
get_extra_vars_info([_|Vs], VarsInfo) :-
get_extra_vars_info(Vs, VarsInfo).
get_extra_vars_info(Vs, VarsInfo).
get_dists_parameters([],[]).
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
get_dist_params(Id, Params),
get_dists_parameters(Ids, DistsInfo).
get_dist_params(Id, Params),
get_dists_parameters(Ids, DistsInfo).
numbers_to_atoms([], []).
numbers_to_atoms([Atom|L0], [Atom|L]) :-
atom(Atom), !,
numbers_to_atoms(L0, L).
atom(Atom), !,
numbers_to_atoms(L0, L).
numbers_to_atoms([Number|L0], [Atom|L]) :-
number_atom(Number, Atom),
numbers_to_atoms(L0, L).
flatten_1_element_sublists([],[]).
flatten_1_element_sublists([[H|[]]|T],[H|R]) :- !,
flatten_1_element_sublists(T,R).
flatten_1_element_sublists([H|T],[H|R]) :-
flatten_1_element_sublists(T,R).
number_atom(Number, Atom),
numbers_to_atoms(L0, L).
print_clpbn_vars(Var.AllVars) :-
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
parents_to_keys(Parents, ParentKeys),
writeln(Var:Key:ParentKeys:DistId),
print_clpbn_vars(AllVars).
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
parents_to_keys(Parents, ParentKeys),
writeln(Var:Key:ParentKeys:DistId),
print_clpbn_vars(AllVars).
print_clpbn_vars([]).
parents_to_keys([], []).
parents_to_keys(Var.Parents, Key.Keys) :-
clpbn:get_atts(Var, [key(Key)]),
parents_to_keys(Parents, Keys).
clpbn:get_atts(Var, [key(Key)]),
parents_to_keys(Parents, Keys).

View File

@ -88,15 +88,22 @@ BayesNet::readFromBifFormat (const char* fileName)
abort();
}
params = reorderParameters (params, node->nrStates());
Distribution* dist = new Distribution (params);
node->setDistribution (dist);
addDistribution (dist);
if (Globals::logDomain) {
Util::toLog (params);
}
node->setParams (params);
}
setIndexes();
if (Globals::logDomain) {
distributionsToLogs();
}
}
BayesNode*
BayesNet::addNode (BayesNode* n)
{
varMap_.insert (make_pair (n->varId(), nodes_.size()));
nodes_.push_back (n);
return nodes_.back();
}
@ -114,15 +121,6 @@ BayesNet::addNode (string label, const States& states)
BayesNode*
BayesNet::addNode (VarId vid, unsigned dsize, int evidence, Distribution* dist)
{
varMap_.insert (make_pair (vid, nodes_.size()));
nodes_.push_back (new BayesNode (vid, dsize, evidence, dist));
return nodes_.back();
}
BayesNode*
BayesNet::getBayesNode (VarId vid) const
@ -176,29 +174,6 @@ BayesNet::getVariableNodes (void) const
void
BayesNet::addDistribution (Distribution* dist)
{
dists_.push_back (dist);
}
Distribution*
BayesNet::getDistribution (unsigned distId) const
{
Distribution* dist = 0;
for (unsigned i = 0; i < dists_.size(); i++) {
if (dists_[i]->id == (int) distId) {
dist = dists_[i];
break;
}
}
return dist;
}
const BnNodeSet&
BayesNet::getBayesNodes (void) const
{
@ -299,7 +274,7 @@ BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
/*
cout << "\t\ttop\tbottom" << endl;
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
cout << "----------------------------------------------------------" ;
Util::printDashedLine();
cout << endl;
for (unsigned i = 0; i < states.size(); i++) {
cout << nodes_[i]->label() << ":\t\t" ;
@ -350,10 +325,8 @@ BayesNet::constructGraph (BayesNet* bn,
}
}
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
BayesNode* mrnNode = bn->addNode (nodes_[i]->varId(),
nodes_[i]->nrStates(),
nodes_[i]->getEvidence(),
nodes_[i]->getDistribution());
BayesNode* mrnNode = new BayesNode (nodes_[i]);
bn->addNode (mrnNode);
mrnNodes.push_back (mrnNode);
}
}
@ -388,26 +361,6 @@ BayesNet::setIndexes (void)
void
BayesNet::distributionsToLogs (void)
{
for (unsigned i = 0; i < dists_.size(); i++) {
Util::toLog (dists_[i]->params);
}
}
void
BayesNet::freeDistributions (void)
{
for (unsigned i = 0; i < dists_.size(); i++) {
delete dists_[i];
}
}
void
BayesNet::printGraphicalModel (void) const
{
@ -504,8 +457,8 @@ BayesNet::exportToBifFormat (const char* fileName) const
out << "\t<GIVEN>" << parents[j]->label();
out << "</GIVEN>" << endl;
}
Params params = revertParameterReorder (nodes_[i]->getParameters(),
nodes_[i]->nrStates());
Params params = revertParameterReorder (
nodes_[i]->params(), nodes_[i]->nrStates());
out << "\t<TABLE>" ;
for (unsigned j = 0; j < params.size(); j++) {
out << " " << params[j];

View File

@ -13,16 +13,11 @@
using namespace std;
class Distribution;
struct ScheduleInfo
{
ScheduleInfo (BayesNode* n, bool vfp, bool vfc)
{
node = n;
visitedFromParent = vfp;
visitedFromChild = vfc;
}
ScheduleInfo (BayesNode* n, bool vfp, bool vfc) :
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
BayesNode* node;
bool visitedFromParent;
bool visitedFromChild;
@ -31,70 +26,84 @@ struct ScheduleInfo
struct StateInfo
{
StateInfo (void)
{
visited = true;
markedOnTop = false;
markedOnBottom = false;
}
StateInfo (void) : visited(false), markedOnTop(false),
markedOnBottom(false) { }
bool visited;
bool markedOnTop;
bool markedOnBottom;
};
typedef vector<Distribution*> DistSet;
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
class BayesNet : public GraphicalModel
{
public:
BayesNet (void) {};
BayesNet (void) { };
~BayesNet (void);
void readFromBifFormat (const char*);
BayesNode* addNode (string, const States&);
// BayesNode* addNode (VarId, unsigned, int, BnNodeSet&, Distribution*);
BayesNode* addNode (VarId, unsigned, int, Distribution*);
BayesNode* getBayesNode (VarId) const;
BayesNode* getBayesNode (string) const;
VarNode* getVariableNode (VarId) const;
VarNodes getVariableNodes (void) const;
void addDistribution (Distribution*);
Distribution* getDistribution (unsigned) const;
const BnNodeSet& getBayesNodes (void) const;
unsigned nrNodes (void) const;
BnNodeSet getRootNodes (void) const;
BnNodeSet getLeafNodes (void) const;
BayesNet* getMinimalRequesiteNetwork (VarId) const;
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
void constructGraph (
BayesNet*, const vector<StateInfo*>&) const;
bool isPolyTree (void) const;
void setIndexes (void);
void distributionsToLogs (void);
void freeDistributions (void);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
void exportToBifFormat (const char*) const;
void readFromBifFormat (const char*);
BayesNode* addNode (BayesNode*);
BayesNode* addNode (string, const States&);
BayesNode* getBayesNode (VarId) const;
BayesNode* getBayesNode (string) const;
VarNode* getVariableNode (VarId) const;
VarNodes getVariableNodes (void) const;
const BnNodeSet& getBayesNodes (void) const;
unsigned nrNodes (void) const;
BnNodeSet getRootNodes (void) const;
BnNodeSet getLeafNodes (void) const;
BayesNet* getMinimalRequesiteNetwork (VarId) const;
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
void constructGraph (BayesNet*, const vector<StateInfo*>&) const;
bool isPolyTree (void) const;
void setIndexes (void);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
void exportToBifFormat (const char*) const;
private:
DISALLOW_COPY_AND_ASSIGN (BayesNet);
bool containsUndirectedCycle (void) const;
bool containsUndirectedCycle (int, int, vector<bool>&)const;
vector<int> getAdjacentNodes (int) const;
Params reorderParameters (const Params&, unsigned) const;
Params revertParameterReorder (const Params&, unsigned) const;
void scheduleParents (const BayesNode*, Scheduling&) const;
void scheduleChilds (const BayesNode*, Scheduling&) const;
bool containsUndirectedCycle (void) const;
BnNodeSet nodes_;
DistSet dists_;
bool containsUndirectedCycle (int, int, vector<bool>&)const;
vector<int> getAdjacentNodes (int) const;
Params reorderParameters (const Params&, unsigned) const;
Params revertParameterReorder (const Params&, unsigned) const;
void scheduleParents (const BayesNode*, Scheduling&) const;
void scheduleChilds (const BayesNode*, Scheduling&) const;
BnNodeSet nodes_;
typedef unordered_map<unsigned, unsigned> IndexMap;
IndexMap varMap_;
IndexMap varMap_;
};

View File

@ -8,29 +8,10 @@
#include "BayesNode.h"
BayesNode::BayesNode (VarId vid,
unsigned dsize,
int evidence,
Distribution* dist)
: VarNode (vid, dsize, evidence)
void
BayesNode::setParams (const Params& params)
{
dist_ = dist;
}
BayesNode::BayesNode (VarId vid,
unsigned dsize,
int evidence,
const BnNodeSet& parents,
Distribution* dist)
: VarNode (vid, dsize, evidence)
{
parents_ = parents;
dist_ = dist;
for (unsigned int i = 0; i < parents.size(); i++) {
parents[i]->addChild (this);
}
params_ = params;
}
@ -54,31 +35,6 @@ BayesNode::addChild (BayesNode* node)
void
BayesNode::setDistribution (Distribution* dist)
{
assert (dist);
dist_ = dist;
}
Distribution*
BayesNode::getDistribution (void)
{
return dist_;
}
const Params&
BayesNode::getParameters (void)
{
return dist_->params;
}
Params
BayesNode::getRow (int rowIndex) const
{
@ -86,7 +42,7 @@ BayesNode::getRow (int rowIndex) const
int offset = rowSize * rowIndex;
Params row (rowSize);
for (int i = 0; i < rowSize; i++) {
row[i] = dist_->params[offset + i] ;
row[i] = params_[offset + i] ;
}
return row;
}
@ -119,13 +75,13 @@ BayesNode::hasNeighbors (void) const
int
BayesNode::getCptSize (void)
{
return dist_->params.size();
return params_.size();
}
int
BayesNode::getIndexOfParent (const BayesNode* parent) const
BayesNode::indexOfParent (const BayesNode* parent) const
{
for (unsigned int i = 0; i < parents_.size(); i++) {
if (parents_[i] == parent) {

View File

@ -4,7 +4,6 @@
#include <vector>
#include "VarNode.h"
#include "Distribution.h"
#include "Horus.h"
using namespace std;
@ -13,49 +12,70 @@ using namespace std;
class BayesNode : public VarNode
{
public:
BayesNode (const VarNode& v) : VarNode (v) {}
BayesNode (VarId, unsigned, int, Distribution*);
BayesNode (VarId, unsigned, int, const BnNodeSet&, Distribution*);
void setParents (const BnNodeSet&);
void addChild (BayesNode*);
void setDistribution (Distribution*);
Distribution* getDistribution (void);
const Params& getParameters (void);
Params getRow (int) const;
bool isRoot (void);
bool isLeaf (void);
bool hasNeighbors (void) const;
int getCptSize (void);
int getIndexOfParent (const BayesNode*) const;
string cptEntryToString (int, const vector<unsigned>&) const;
BayesNode (const VarNode& v) : VarNode (v) { }
const BnNodeSet& getParents (void) const { return parents_; }
const BnNodeSet& getChilds (void) const { return childs_; }
BayesNode (const BayesNode* n) :
VarNode (n->varId(), n->nrStates(), n->getEvidence()),
params_(n->params()), distId_(n->distId()) { }
BayesNode (VarId vid, unsigned nrStates, int ev,
const Params& ps, unsigned id)
: VarNode (vid, nrStates, ev) , params_(ps), distId_(id) { }
const BnNodeSet& getParents (void) const { return parents_; }
const BnNodeSet& getChilds (void) const { return childs_; }
const Params& params (void) const { return params_; }
unsigned distId (void) const { return distId_; }
unsigned getRowSize (void) const
{
return dist_->params.size() / nrStates();
return params_.size() / nrStates();
}
double getProbability (int row, unsigned col)
{
int idx = (row * getRowSize()) + col;
return dist_->params[idx];
return params_[idx];
}
void setParams (const Params& params);
void setParents (const BnNodeSet&);
void addChild (BayesNode*);
const Params& getParameters (void);
Params getRow (int) const;
bool isRoot (void);
bool isLeaf (void);
bool hasNeighbors (void) const;
int getCptSize (void);
int indexOfParent (const BayesNode*) const;
string cptEntryToString (int, const vector<unsigned>&) const;
friend ostream& operator << (ostream&, const BayesNode&);
private:
DISALLOW_COPY_AND_ASSIGN (BayesNode);
States getDomainHeaders (void) const;
friend ostream& operator << (ostream&, const BayesNode&);
States getDomainHeaders (void) const;
BnNodeSet parents_;
BnNodeSet childs_;
Distribution* dist_;
BnNodeSet parents_;
BnNodeSet childs_;
Params params_;
unsigned distId_;
};
ostream& operator << (ostream&, const BayesNode&);
#endif // HORUS_BAYESNODE_H

View File

@ -34,12 +34,12 @@ void
BnBpSolver::runSolver (void)
{
clock_t start;
if (COLLECT_STATISTICS) {
if (Constants::COLLECT_STATS) {
start = clock();
}
initializeSolver();
runLoopySolver();
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Belief propagation converged in " ;
@ -51,18 +51,13 @@ BnBpSolver::runSolver (void)
}
unsigned size = bayesNet_->nrNodes();
if (COLLECT_STATISTICS) {
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = bayesNet_->isPolyTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
stringstream ss;
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
bayesNet_->exportToGraphViz (ss.str().c_str());
}
}
@ -80,7 +75,7 @@ BnBpSolver::getPosterioriOf (VarId vid)
Params
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << "calculating joint distribution on: " ;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
@ -112,7 +107,7 @@ BnBpSolver::initializeSolver (void)
BnNodeSet roots = bayesNet_->getRootNodes();
for (unsigned i = 0; i < roots.size(); i++) {
const Params& params = roots[i]->getParameters();
const Params& params = roots[i]->params();
Params& piVals = ninf(roots[i])->getPiValues();
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
piVals[ri] = params[ri];
@ -143,11 +138,11 @@ BnBpSolver::initializeSolver (void)
Params& piVals = ninf(nodes[i])->getPiValues();
Params& ldVals = ninf(nodes[i])->getLambdaValues();
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
piVals[xi] = Util::noEvidence();
ldVals[xi] = Util::noEvidence();
piVals[xi] = LogAware::noEvidence();
ldVals[xi] = LogAware::noEvidence();
}
piVals[nodes[i]->getEvidence()] = Util::withEvidence();
ldVals[nodes[i]->getEvidence()] = Util::withEvidence();
piVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
ldVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
}
}
}
@ -161,13 +156,8 @@ BnBpSolver::runLoopySolver()
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_++;
if (DL >= 2) {
cout << "****************************************" ;
cout << "****************************************" ;
cout << endl;
cout << " Iteration " << nIters_ << endl;
cout << "****************************************" ;
cout << "****************************************" ;
if (Constants::DEBUG >= 2) {
Util::printHeader ("Iteration " + nIters_);
cout << endl;
}
@ -199,7 +189,7 @@ BnBpSolver::runLoopySolver()
break;
}
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
@ -228,7 +218,7 @@ BnBpSolver::converged (void) const
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual change = " ;
cout << residual << endl;
}
@ -256,7 +246,7 @@ BnBpSolver::maxResidualSchedule (void)
}
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
@ -300,9 +290,8 @@ BnBpSolver::maxResidualSchedule (void)
}
}
if (DL >= 2) {
cout << "----------------------------------------" ;
cout << "----------------------------------------" << endl;
if (Constants::DEBUG >= 2) {
Util::printDashedLine();
}
}
}
@ -313,7 +302,7 @@ void
BnBpSolver::updatePiValues (BayesNode* x)
{
// π(Xi)
if (DL >= 3) {
if (Constants::DEBUG >= 3) {
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
}
Params& piValues = ninf(x)->getPiValues();
@ -329,11 +318,11 @@ BnBpSolver::updatePiValues (BayesNode* x)
Params messageProducts (indexer.size());
for (unsigned k = 0; k < indexer.size(); k++) {
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double messageProduct = Util::multIdenty();
double messageProduct = LogAware::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < parentLinks.size(); i++) {
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
@ -341,7 +330,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
} else {
for (unsigned i = 0; i < parentLinks.size(); i++) {
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
if (i != 0) *calcs1 << " + " ;
if (i != 0) *calcs2 << " + " ;
*calcs1 << parentLinks[i]->toString (indexer[i]);
@ -350,7 +339,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
}
}
messageProducts[k] = messageProduct;
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " mp" << k;
cout << " = " << (*calcs1).str();
if (parentLinks.size() == 1) {
@ -366,27 +355,27 @@ BnBpSolver::updatePiValues (BayesNode* x)
}
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
double sum = Util::addIdenty();
if (DL >= 5) {
double sum = LogAware::addIdenty();
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
indexer.reset();
if (Globals::logDomain) {
for (unsigned k = 0; k < indexer.size(); k++) {
Util::logSum (sum,
x->getProbability(xi, indexer.linearIndex()) + messageProducts[k]);
sum = Util::logSum (sum,
x->getProbability(xi, indexer) + messageProducts[k]);
++ indexer;
}
} else {
for (unsigned k = 0; k < indexer.size(); k++) {
sum += x->getProbability (xi, indexer.linearIndex()) * messageProducts[k];
if (DL >= 5) {
sum += x->getProbability (xi, indexer) * messageProducts[k];
if (Constants::DEBUG >= 5) {
if (k != 0) *calcs1 << " + " ;
if (k != 0) *calcs2 << " + " ;
*calcs1 << x->cptEntryToString (xi, indexer.indices());
*calcs1 << ".mp" << k;
*calcs2 << Util::fl (x->getProbability (xi, indexer.linearIndex()));
*calcs2 << LogAware::fl (x->getProbability (xi, indexer));
*calcs2 << "*" << messageProducts[k];
}
++ indexer;
@ -394,7 +383,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
}
piValues[xi] = sum;
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
@ -412,7 +401,7 @@ void
BnBpSolver::updateLambdaValues (BayesNode* x)
{
// λ(Xi)
if (DL >= 3) {
if (Constants::DEBUG >= 3) {
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
}
Params& lambdaValues = ninf(x)->getLambdaValues();
@ -421,11 +410,11 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
stringstream* calcs2 = 0;
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double product = Util::multIdenty();
double product = LogAware::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < childLinks.size(); i++) {
product += childLinks[i]->getMessage()[xi];
@ -433,7 +422,7 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
} else {
for (unsigned i = 0; i < childLinks.size(); i++) {
product *= childLinks[i]->getMessage()[xi];
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
if (i != 0) *calcs1 << "." ;
if (i != 0) *calcs2 << "*" ;
*calcs1 << childLinks[i]->toString (xi);
@ -442,7 +431,7 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
}
}
lambdaValues[xi] = product;
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
@ -474,7 +463,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
const Params& zPiValues = ninf(z)->getPiValues();
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
double product = zPiValues[zi];
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
@ -491,7 +480,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
for (unsigned i = 0; i < zChildLinks.size(); i++) {
if (zChildLinks[i]->getSource() != x) {
product *= zChildLinks[i]->getMessage()[zi];
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
*calcs1 << "." << zChildLinks[i]->toString (zi);
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
}
@ -499,7 +488,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
}
}
zxPiNextMessage[zi] = product;
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " " << link->toString();
cout << "[" << z->states()[zi] << "]" ;
cout << " = " << (*calcs1).str();
@ -513,7 +502,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
delete calcs2;
}
}
Util::normalize (zxPiNextMessage);
LogAware::normalize (zxPiNextMessage);
}
@ -527,10 +516,10 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
if (x->hasEvidence()) {
return;
}
Params& yxLambdaNextMessage = link->getNextMessage();
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
const Params& yLambdaValues = ninf(y)->getLambdaValues();
int parentIndex = y->getIndexOfParent (x);
Params& yxLambdaNextMessage = link->getNextMessage();
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
const Params& yLambdaValues = ninf(y)->getLambdaValues();
int parentIndex = y->indexOfParent (x);
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
@ -548,11 +537,11 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
while (indexer[parentIndex] != 0) {
++ indexer;
}
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double messageProduct = Util::multIdenty();
double messageProduct = LogAware::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < yParentLinks.size(); i++) {
if (yParentLinks[i]->getSource() != x) {
@ -562,9 +551,9 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
} else {
for (unsigned i = 0; i < yParentLinks.size(); i++) {
if (yParentLinks[i]->getSource() != x) {
if (DL >= 5) {
if (messageProduct != Util::multIdenty()) *calcs1 << "*" ;
if (messageProduct != Util::multIdenty()) *calcs2 << "*" ;
if (Constants::DEBUG >= 5) {
if (messageProduct != LogAware::multIdenty()) *calcs1 << "*" ;
if (messageProduct != LogAware::multIdenty()) *calcs2 << "*" ;
*calcs1 << yParentLinks[i]->toString (indexer[i]);
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
}
@ -574,7 +563,7 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
}
messageProducts[k] = messageProduct;
++ indexer;
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " mp" << k;
cout << " = " << (*calcs1).str();
if (yParentLinks.size() == 1) {
@ -591,55 +580,54 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
}
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double outerSum = Util::addIdenty();
double outerSum = LogAware::addIdenty();
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
}
double innerSum = Util::addIdenty();
double innerSum = LogAware::addIdenty();
indexer.reset();
if (Globals::logDomain) {
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != xi) {
++ indexer;
}
Util::logSum (innerSum, y->getProbability (
yi, indexer.linearIndex()) + messageProducts[k]);
innerSum = Util::logSum (innerSum,
y->getProbability (yi, indexer) + messageProducts[k]);
++ indexer;
}
Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
outerSum = Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
} else {
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != xi) {
++ indexer;
}
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
if (k != 0) *calcs1 << " + " ;
if (k != 0) *calcs2 << " + " ;
*calcs1 << y->cptEntryToString (yi, indexer.indices());
*calcs1 << ".mp" << k;
*calcs2 << y->getProbability (yi, indexer.linearIndex());
*calcs2 << y->getProbability (yi, indexer);
*calcs2 << "*" << messageProducts[k];
}
innerSum += y->getProbability (
yi, indexer.linearIndex()) * messageProducts[k];
innerSum += y->getProbability (yi, indexer) * messageProducts[k];
++ indexer;
}
outerSum += innerSum * yLambdaValues[yi];
}
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
*calcs1 << "[" << y->states()[yi] << "]";
*calcs2 << "}*" << yLambdaValues[yi];
}
}
yxLambdaNextMessage[xi] = outerSum;
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " " << link->toString();
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
@ -649,7 +637,7 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
delete calcs2;
}
}
Util::normalize (yxLambdaNextMessage);
LogAware::normalize (yxLambdaNextMessage);
}
@ -674,7 +662,7 @@ BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
for (unsigned i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
VarIds reqVars = {jointVarIds[i]};
reqVars.insert (reqVars.end(), observedVids.begin(), observedVids.end());
Util::addToVector (reqVars, observedVids);
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
Params newBeliefs;
VarNodes observedVars;
@ -720,8 +708,7 @@ BnBpSolver::printPiLambdaValues (const BayesNode* var) const
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
cout << setw (16) << "belief" ;
cout << endl;
cout << "--------------------------------" ;
cout << "--------------------------------" ;
Util::printDashedLine();
cout << endl;
const States& states = var->states();
const Params& piVals = ninf(var)->getPiValues();
@ -731,7 +718,7 @@ BnBpSolver::printPiLambdaValues (const BayesNode* var) const
cout << setw (10) << states[xi];
cout << setw (19) << piVals[xi];
cout << setw (19) << ldVals[xi];
cout.precision (PRECISION);
cout.precision (Constants::PRECISION);
cout << setw (16) << beliefs[xi];
cout << endl;
}
@ -754,8 +741,8 @@ BnBpSolver::printAllMessageStatus (void) const
BpNodeInfo::BpNodeInfo (BayesNode* node)
{
node_ = node;
piVals_.resize (node->nrStates(), Util::one());
ldVals_.resize (node->nrStates(), Util::one());
piVals_.resize (node->nrStates(), LogAware::one());
ldVals_.resize (node->nrStates(), LogAware::one());
}

View File

@ -27,11 +27,11 @@ class BpLink
destin_ = d;
orientation_ = o;
if (orientation_ == LinkOrientation::DOWN) {
v1_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
v2_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
v1_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
v2_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
} else {
v1_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
v2_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
v1_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
v2_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
}
currMsg_ = &v1_;
nextMsg_ = &v2_;
@ -39,6 +39,22 @@ class BpLink
msgSended_ = false;
}
BayesNode* getSource (void) const { return source_; }
BayesNode* getDestination (void) const { return destin_; }
LinkOrientation getOrientation (void) const { return orientation_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_;}
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0;}
void updateMessage (void)
{
swap (currMsg_, nextMsg_);
@ -47,7 +63,7 @@ class BpLink
void updateResidual (void)
{
residual_ = Util::getMaxNorm (v1_, v2_);
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
string toString (void) const
@ -74,29 +90,19 @@ class BpLink
}
return ss.str();
}
BayesNode* getSource (void) const { return source_; }
BayesNode* getDestination (void) const { return destin_; }
LinkOrientation getOrientation (void) const { return orientation_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0;}
private:
BayesNode* source_;
BayesNode* destin_;
LinkOrientation orientation_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
};
typedef vector<BpLink*> BpLinkSet;
@ -105,32 +111,41 @@ class BpNodeInfo
public:
BpNodeInfo (BayesNode*);
Params getBeliefs (void) const;
bool receivedBottomInfluence (void) const;
Params& getPiValues (void) { return piVals_; }
Params& getPiValues (void) { return piVals_; }
Params& getLambdaValues (void) { return ldVals_; }
Params& getLambdaValues (void) { return ldVals_; }
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
Params getBeliefs (void) const;
bool receivedBottomInfluence (void) const;
private:
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
const BayesNode* node_;
Params piVals_; // pi values
Params ldVals_; // lambda values
BpLinkSet inParentLinks_;
BpLinkSet inChildLinks_;
BpLinkSet outParentLinks_;
BpLinkSet outChildLinks_;
const BayesNode* node_;
Params piVals_;
Params ldVals_;
BpLinkSet inParentLinks_;
BpLinkSet inChildLinks_;
BpLinkSet outParentLinks_;
BpLinkSet outChildLinks_;
};
@ -139,32 +154,43 @@ class BnBpSolver : public Solver
{
public:
BnBpSolver (const BayesNet&);
~BnBpSolver (void);
void runSolver (void);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
void runSolver (void);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private:
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
void initializeSolver (void);
void runLoopySolver (void);
void maxResidualSchedule (void);
bool converged (void) const;
void updatePiValues (BayesNode*);
void updateLambdaValues (BayesNode*);
void calculateLambdaMessage (BpLink*);
void calculatePiMessage (BpLink*);
Params getJointByJunctionNode (const VarIds&);
Params getJointByConditioning (const VarIds&) const;
void printPiLambdaValues (const BayesNode*) const;
void printAllMessageStatus (void) const;
void initializeSolver (void);
void runLoopySolver (void);
void maxResidualSchedule (void);
bool converged (void) const;
void updatePiValues (BayesNode*);
void updateLambdaValues (BayesNode*);
void calculateLambdaMessage (BpLink*);
void calculatePiMessage (BpLink*);
Params getJointByJunctionNode (const VarIds&);
Params getJointByConditioning (const VarIds&) const;
void printPiLambdaValues (const BayesNode*) const;
void printAllMessageStatus (void) const;
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
if (Constants::DEBUG >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
if (link->getOrientation() == LinkOrientation::DOWN) {
@ -180,7 +206,7 @@ class BnBpSolver : public Solver
void calculateMessage (BpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
if (Constants::DEBUG >= 3) {
cout << "calculating " << link->toString() << endl;
}
if (link->getOrientation() == LinkOrientation::DOWN) {
@ -195,7 +221,7 @@ class BnBpSolver : public Solver
void updateMessage (BpLink* link)
{
if (DL >= 3) {
if (Constants::DEBUG >= 3) {
cout << "updating " << link->toString() << endl;
}
link->updateMessage();

View File

@ -1,7 +1,6 @@
#include "CFactorGraph.h"
#include "Factor.h"
#include "Distribution.h"
bool CFactorGraph::checkForIdenticalFactors = true;
@ -73,27 +72,34 @@ CFactorGraph::setInitialColors (void)
const FgFacSet& facNodes = groundFg_->getFactorNodes();
if (checkForIdenticalFactors) {
for (unsigned i = 0, s = facNodes.size(); i < s; i++) {
Distribution* dist1 = facNodes[i]->getDistribution();
for (unsigned j = 0; j < i; j++) {
Distribution* dist2 = facNodes[j]->getDistribution();
if (dist1 != dist2 && dist1->params == dist2->params) {
if (facNodes[i]->factor()->getRanges() ==
facNodes[j]->factor()->getRanges()) {
facNodes[i]->factor()->setDistribution (dist2);
}
unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) {
Factor* f1 = facNodes[i]->factor();
if (f1->distId() != Util::maxUnsigned()) {
continue;
}
f1->setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor* f2 = facNodes[j]->factor();
if (f2->distId() != Util::maxUnsigned()) {
continue;
}
if (f1->size() == f2->size() &&
f1->ranges() == f2->ranges() &&
f1->params() == f2->params()) {
f2->setDistId (groupCount);
}
}
groupCount ++;
}
}
// create the initial factor colors
DistColorMap distColors;
for (unsigned i = 0; i < facNodes.size(); i++) {
const Distribution* dist = facNodes[i]->getDistribution();
DistColorMap::iterator it = distColors.find (dist);
unsigned distId = facNodes[i]->factor()->distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (dist, getFreeColor())).first;
it = distColors.insert (make_pair (distId, getFreeColor())).first;
}
setColor (facNodes[i], it->second);
}
@ -104,11 +110,11 @@ CFactorGraph::setInitialColors (void)
void
CFactorGraph::createGroups (void)
{
VarSignMap varGroups;
VarSignMap varGroups;
FacSignMap factorGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const FgVarSet& varNodes = groundFg_->getVarNodes();
const FgVarSet& varNodes = groundFg_->getVarNodes();
const FgFacSet& facNodes = groundFg_->getFactorNodes();
while (groupsHaveChanged || nIters == 1) {
@ -164,8 +170,9 @@ CFactorGraph::createGroups (void)
void
CFactorGraph::createClusters (const VarSignMap& varGroups,
const FacSignMap& factorGroups)
CFactorGraph::createClusters (
const VarSignMap& varGroups,
const FacSignMap& factorGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
@ -249,7 +256,7 @@ CFactorGraph::getCompressedFactorGraph (void)
myGroundVars.push_back (v);
}
Factor* newFactor = new Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->getDistribution());
facClusters_[i]->getGroundFactors()[0]->params());
FgFacNode* fn = new FgFacNode (newFactor);
facClusters_[i]->setRepresentativeFactor (fn);
fg->addFactor (fn);
@ -293,8 +300,9 @@ CFactorGraph::getGroundEdgeCount (
void
CFactorGraph::printGroups (const VarSignMap& varGroups,
const FacSignMap& factorGroups) const
CFactorGraph::printGroups (
const VarSignMap& varGroups,
const FacSignMap& factorGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;

View File

@ -15,23 +15,25 @@ class Signature;
class SignatureHash;
typedef long Color;
typedef unordered_map<unsigned, vector<Color> > VarColorMap;
typedef unordered_map<const Distribution*, Color> DistColorMap;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef vector<VarCluster*> VarClusterSet;
typedef vector<FacCluster*> FacClusterSet;
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
typedef long Color;
typedef unordered_map<unsigned, vector<Color>> VarColorMap;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef vector<VarCluster*> VarClusterSet;
typedef vector<FacCluster*> FacClusterSet;
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
struct Signature
{
Signature (unsigned size)
{
colors.resize (size);
}
Signature (unsigned size) : colors(size) { }
bool operator< (const Signature& sig) const
{
if (colors.size() < sig.colors.size()) {
@ -49,6 +51,7 @@ struct Signature
}
return false;
}
bool operator== (const Signature& sig) const
{
if (colors.size() != sig.colors.size()) {
@ -61,12 +64,14 @@ struct Signature
}
return true;
}
vector<Color> colors;
};
struct SignatureHash {
struct SignatureHash
{
size_t operator() (const Signature &sig) const
{
size_t val = hash<size_t>()(sig.colors.size());
@ -141,10 +146,12 @@ class FacCluster
{
return representFactor_;
}
void setRepresentativeFactor (FgFacNode* fn)
{
representFactor_ = fn;
}
const FgFacSet& getGroundFactors (void) const
{
return groundFactors_;
@ -162,31 +169,28 @@ class CFactorGraph
{
public:
CFactorGraph (const FactorGraph&);
~CFactorGraph (void);
FactorGraph* getCompressedFactorGraph (void);
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
const VarClusterSet& getVarClusters (void) { return varClusters_; }
const FacClusterSet& getFacClusters (void) { return facClusters_; }
FgVarNode* getEquivalentVariable (VarId vid)
{
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->getRepresentativeVariable();
}
const VarClusterSet& getVarClusters (void) { return varClusters_; }
const FacClusterSet& getFacClusters (void) { return facClusters_; }
FactorGraph* getCompressedFactorGraph (void);
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
static bool checkForIdenticalFactors;
private:
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const FgVarNode*);
const Signature& getSignature (const FgFacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color getFreeColor (void) {
Color getFreeColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
@ -214,14 +218,26 @@ class CFactorGraph
return vid2VarCluster_.find (vid)->second;
}
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const FgVarNode*);
const Signature& getSignature (const FgFacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_;
vector<Color> varColors_;
vector<Color> factorColors_;
vector<Signature> varSignatures_;
vector<Signature> factorSignatures_;
VarClusterSet varClusters_;
FacClusterSet facClusters_;
VarId2VarCluster vid2VarCluster_;
FacClusterSet facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
};

View File

@ -20,24 +20,24 @@ CbpSolver::getPosterioriOf (VarId vid)
FgVarNode* var = lfg_->getEquivalentVariable (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->nrStates(), Util::noEvidence());
probs[var->getEvidence()] = Util::withEvidence();
probs.resize (var->nrStates(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->nrStates(), Util::multIdenty());
probs.resize (var->nrStates(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->getPoweredMessage());
}
Util::normalize (probs);
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (probs, l->getPoweredMessage());
}
Util::normalize (probs);
LogAware::normalize (probs);
}
}
return probs;
@ -62,7 +62,7 @@ void
CbpSolver::initializeSolver (void)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (COLLECT_STATISTICS) {
if (Constants::COLLECT_STATS) {
nGroundVars = factorGraph_->getVarNodes().size();
nGroundFacs = factorGraph_->getFactorNodes().size();
const FgVarSet& vars = factorGraph_->getVarNodes();
@ -82,7 +82,7 @@ CbpSolver::initializeSolver (void)
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
factorGraph_ = lfg_->getCompressedFactorGraph();
if (COLLECT_STATISTICS) {
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = factorGraph_->getVarNodes().size();
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
@ -123,7 +123,7 @@ CbpSolver::maxResidualSchedule (void)
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
if (DL >= 2 && DL < 5) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << "calculating " << links_[i]->toString() << endl;
}
}
@ -131,7 +131,7 @@ CbpSolver::maxResidualSchedule (void)
}
for (unsigned c = 0; c < links_.size(); c++) {
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << endl << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
@ -142,7 +142,7 @@ CbpSolver::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it;
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->getResidual() < BpOptions::accuracy) {
@ -159,7 +159,7 @@ CbpSolver::maxResidualSchedule (void)
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) {
if (DL >= 2 && DL < 5) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << " calculating " << links[j]->toString() << endl;
}
calculateMessage (links[j]);
@ -174,7 +174,7 @@ CbpSolver::maxResidualSchedule (void)
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getVariable() != link->getVariable()) {
if (DL >= 2 && DL < 5) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << " calculating " << links[i]->toString() << endl;
}
calculateMessage (links[i]);
@ -196,15 +196,15 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
const FgFacNode* dst = link->getFactor();
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
if (src->hasEvidence()) {
msg.resize (src->nrStates(), Util::noEvidence());
msg.resize (src->nrStates(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()];
msg[src->getEvidence()] = Util::pow (value, l->getNumberOfEdges() - 1);
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1);
} else {
msg = link->getMessage();
Util::pow (msg, l->getNumberOfEdges() - 1);
LogAware::pow (msg, l->getNumberOfEdges() - 1);
}
if (DL >= 5) {
cout << " " << "init: " << Util::parametersToString (msg) << endl;
if (Constants::DEBUG >= 5) {
cout << " " << "init: " << msg << endl;
}
const SpLinkSet& links = ninf(src)->getLinks();
if (Globals::logDomain) {
@ -219,16 +219,16 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (msg, l->getPoweredMessage());
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
cout << l->getPoweredMessage() << endl;
}
}
}
}
if (DL >= 5) {
cout << " result = " << Util::parametersToString (msg) << endl;
if (Constants::DEBUG >= 5) {
cout << " result = " << msg << endl;
}
return msg;
}
@ -241,12 +241,9 @@ CbpSolver::printLinkInformation (void) const
for (unsigned i = 0; i < links_.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << Util::parametersToString (l->getMessage()) << endl;
cout << " next msg = " ;
cout << Util::parametersToString (l->getNextMessage()) << endl;
cout << " powered = " ;
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
cout << " curr msg = " << l->getMessage() << endl;
cout << " next msg = " << l->getNextMessage() << endl;
cout << " powered = " << l->getPoweredMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -12,23 +12,24 @@ class CbpSolverLink : public SpLink
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
{
edgeCount_ = c;
poweredMsg_.resize (vn->nrStates(), Util::one());
poweredMsg_.resize (vn->nrStates(), LogAware::one());
}
unsigned getNumberOfEdges (void) const { return edgeCount_; }
const Params& getPoweredMessage (void) const { return poweredMsg_; }
void updateMessage (void)
{
poweredMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
msgSended_ = true;
Util::pow (poweredMsg_, edgeCount_);
LogAware::pow (poweredMsg_, edgeCount_);
}
unsigned getNumberOfEdges (void) const { return edgeCount_; }
const Params& getPoweredMessage (void) const { return poweredMsg_; }
private:
Params poweredMsg_;
unsigned edgeCount_;
Params poweredMsg_;
unsigned edgeCount_;
};
@ -37,21 +38,22 @@ class CbpSolver : public FgBpSolver
{
public:
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
~CbpSolver (void);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private:
void initializeSolver (void);
void createLinks (void);
void initializeSolver (void);
void createLinks (void);
void maxResidualSchedule (void);
Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;
void maxResidualSchedule (void);
Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;
CFactorGraph* lfg_;
CFactorGraph* lfg_;
};
#endif // HORUS_CBP_H

View File

@ -1,10 +1,11 @@
#include <queue>
#include <fstream>
#include "ConstraintTree.h"
#include "Util.h"
void
CTNode::addChild (CTNode* child, bool updateLevels)
{
@ -42,6 +43,26 @@ CTNode::removeChild (CTNode* child)
void
CTNode::removeAndDeleteChild (CTNode* child)
{
removeChild (child);
CTNode::deleteSubtree (child);
}
void
CTNode::removeAndDeleteAllChilds (void)
{
for (unsigned i = 0; i < childs_.size(); i++) {
deleteSubtree (childs_[i]);
}
childs_.clear();
}
SymbolSet
CTNode::childSymbols (void) const
{
@ -66,6 +87,32 @@ CTNode::updateChildLevels (CTNode* n, unsigned level)
CTNode*
CTNode::copySubtree (const CTNode* n)
{
CTNode* newNode = new CTNode (*n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
newNode->addChild (copySubtree (childs[i]));
}
return newNode;
}
void
CTNode::deleteSubtree (CTNode* n)
{
assert (n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
deleteSubtree (childs[i]);
}
delete n;
}
ostream& operator<< (ostream &out, const CTNode& n)
{
// out << "(" << n.level() << ") " ;
@ -75,6 +122,17 @@ ostream& operator<< (ostream &out, const CTNode& n)
ConstraintTree::ConstraintTree (unsigned nrLvs)
{
for (unsigned i = 0; i < nrLvs; i++) {
logVars_.push_back (LogVar (i));
}
root_ = new CTNode (0, 0);
logVarSet_ = LogVarSet (logVars_);
}
ConstraintTree::ConstraintTree (const LogVars& logVars)
{
root_ = new CTNode (0, 0);
@ -99,7 +157,7 @@ ConstraintTree::ConstraintTree (const LogVars& logVars,
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
{
root_ = copySubtree (ct.root_);
root_ = CTNode::copySubtree (ct.root_);
logVars_ = ct.logVars_;
logVarSet_ = ct.logVarSet_;
}
@ -108,7 +166,7 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
ConstraintTree::~ConstraintTree (void)
{
deleteSubtree (root_);
CTNode::deleteSubtree (root_);
}
@ -200,21 +258,28 @@ ConstraintTree::moveToBottom (const LogVars& lvs)
void
ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
{
{
if (logVarSet_.empty()) {
delete root_;
root_ = CTNode::copySubtree (ct->root());
logVars_ = ct->logVars();
logVarSet_ = ct->logVarSet();
return;
}
LogVarSet intersect = logVarSet_ & ct->logVarSet_;
if (intersect.empty()) {
const CTNodes& childs = ct->root()->childs();
CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back()));
for (unsigned i = 0; i < leafs.size(); i++) {
for (unsigned j = 0; j < childs.size(); j++) {
leafs[i]->addChild (copySubtree (childs[j]));
leafs[i]->addChild (CTNode::copySubtree (childs[j]));
}
}
logVars_.insert (logVars_.end(), ct->logVars_.begin(), ct->logVars_.end());
Util::addToVector (logVars_, ct->logVars_);
logVarSet_ |= ct->logVarSet_;
} else {
moveToBottom (intersect.elements());
ct->moveToTop (intersect.elements());
@ -222,25 +287,27 @@ ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
CTNodes nodes = getNodesAtLevel (level);
Tuples tuples;
CTNodes continuationNodes;
CTNodes continNodes;
getTuples (ct->root(),
Tuples(),
intersect.size(),
tuples,
continuationNodes);
continNodes);
for (unsigned i = 0; i < tuples.size(); i++) {
bool tupleFounded = false;
for (unsigned j = 0; j < nodes.size(); j++) {
tupleFounded |= join (nodes[j], tuples[i], 0, continuationNodes[i]);
tupleFounded |= join (nodes[j], tuples[i], 0, continNodes[i]);
}
if (assertWhenNotFound) {
assert (tupleFounded);
}
}
LogVarSet newLvs = ct->logVarSet_ - intersect;
logVars_.insert (logVars_.end(), newLvs.begin(), newLvs.end());
logVarSet_ |= newLvs;
LogVars newLvs (ct->logVars().begin() + intersect.size(),
ct->logVars().end());
Util::addToVector (logVars_, newLvs);
logVarSet_ |= LogVarSet (newLvs);
}
}
@ -280,6 +347,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
void
ConstraintTree::applySubstitution (const Substitution& theta)
{
LogVars discardedLvs = theta.getDiscardedLogVars();
for (unsigned i = 0; i < discardedLvs.size(); i++) {
remove(discardedLvs[i]);
}
for (unsigned i = 0; i < logVars_.size(); i++) {
logVars_[i] = theta.newNameFor (logVars_[i]);
}
@ -308,11 +379,7 @@ ConstraintTree::remove (const LogVarSet& X)
unsigned level = getLevel (X.front()) - 1;
CTNodes nodes = getNodesAtLevel (level);
for (unsigned i = 0; i < nodes.size(); i++) {
CTNodes childs = nodes[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) {
nodes[i]->removeChild (childs[j]);
deleteSubtree (childs[j]);
}
nodes[i]->removeAndDeleteAllChilds();
}
logVars_.resize (logVars_.size() - X.size());
logVarSet_ -= X;
@ -545,16 +612,16 @@ ConstraintTree::split (
for (unsigned i = 0; i < commNodes.size(); i++) {
commCt->root()->addChild (commNodes[i]);
}
//cout << commCt->tupleSet() << " + " ;
//cout << exclCt->tupleSet() << " = " ;
//cout << tupleSet() << endl << endl;
// cout << commCt->tupleSet() << " + " ;
// cout << exclCt->tupleSet() << " = " ;
// cout << tupleSet() << endl << endl;
// if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) {
// exportToGraphViz ("_fail.dot", true);
// commCt->exportToGraphViz ("_fail_comm.dot", true);
// exclCt->exportToGraphViz ("_fail_excl.dot", true);
// }
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
// assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
// assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
return {commCt, exclCt};
}
@ -601,36 +668,32 @@ ConstraintTree::jointCountNormalize (
LogVar X_new1,
LogVar X_new2)
{
exportToGraphViz ("C.dot", true);
commCt->exportToGraphViz ("C_comm.dot", true);
exclCt->exportToGraphViz ("C_exlc.dot", true);
unsigned N = getConditionalCount (X);
cout << "My tuples: " << tupleSet() << endl;
cout << "CommCt tuples: " << commCt->tupleSet() << endl;
cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
cout << "Counted Lv: " << X << endl;
cout << "Original N: " << N << endl;
cout << endl;
// cout << "My tuples: " << tupleSet() << endl;
// cout << "CommCt tuples: " << commCt->tupleSet() << endl;
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
// cout << "Counted Lv: " << X << endl;
// cout << "X_new1: " << X_new1 << endl;
// cout << "X_new2: " << X_new2 << endl;
// cout << "Original N: " << N << endl;
// cout << endl;
ConstraintTrees normCts1 = commCt->countNormalize (X);
vector<unsigned> counts1 (normCts1.size());
for (unsigned i = 0; i < normCts1.size(); i++) {
counts1[i] = normCts1[i]->getConditionalCount (X);
cout << "normCts1[" << i << "] #" << counts1[i] ;
cout << " " << normCts1[i]->tupleSet() << endl;
// cout << "normCts1[" << i << "] #" << counts1[i] ;
// cout << " " << normCts1[i]->tupleSet() << endl;
}
ConstraintTrees normCts2 = exclCt->countNormalize (X);
vector<unsigned> counts2 (normCts2.size());
for (unsigned i = 0; i < normCts2.size(); i++) {
counts2[i] = normCts2[i]->getConditionalCount (X);
cout << "normCts2[" << i << "] #" << counts2[i] ;
cout << " " << normCts2[i]->tupleSet() << endl;
// cout << "normCts2[" << i << "] #" << counts2[i] ;
// cout << " " << normCts2[i]->tupleSet() << endl;
}
cout << endl;
cout << "1###### " << normCts1.size() << endl;
cout << "2###### " << normCts2.size() << endl;
// cout << endl;
ConstraintTree* excl1 = 0;
for (unsigned i = 0; i < normCts1.size(); i++) {
@ -638,7 +701,7 @@ ConstraintTree::jointCountNormalize (
excl1 = normCts1[i];
normCts1.erase (normCts1.begin() + i);
counts1.erase (counts1.begin() + i);
cout << ">joint-count(" << N << ",0)" << endl;
// cout << "joint-count(" << N << ",0)" << endl;
break;
}
}
@ -649,22 +712,21 @@ ConstraintTree::jointCountNormalize (
excl2 = normCts2[i];
normCts2.erase (normCts2.begin() + i);
counts2.erase (counts2.begin() + i);
cout << ">>joint-count(0," << N << ")" << endl;
// cout << "joint-count(0," << N << ")" << endl;
break;
}
}
cout << "3###### " << normCts1.size() << endl;
cout << "4###### " << normCts2.size() << endl;
for (unsigned i = 0; i < normCts1.size(); i++) {
unsigned j;
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
cout << "joint-count(" << counts1[i] << "," << counts2[j] << ")" << endl;
// cout << "joint-count(" << counts1[i] ;
// cout << "," << counts2[j] << ")" << endl;
const CTNodes& childs = normCts2[j]->root_->childs();
for (unsigned k = 0; k < childs.size(); k++) {
normCts1[i]->root_->addChild (childs[k]);
normCts1[i]->root_->addChild (CTNode::copySubtree (childs[k]));
}
delete normCts2[j];
}
ConstraintTrees cts = normCts1;
@ -683,11 +745,6 @@ ConstraintTree::jointCountNormalize (
cts.push_back (excl2);
}
for (unsigned i = 0; i < cts.size(); i++) {
stringstream ss;
ss << "aaacts_" << i + 1 << ".dot" ;
cts[i]->exportToGraphViz (ss.str().c_str(), true);
}
return cts;
}
@ -735,11 +792,11 @@ ConstraintTree::expand (LogVar X)
unsigned nrSymbols = getConditionalCount (X);
for (unsigned i = 0; i < nodes.size(); i++) {
Symbols symbols;
CTNodes childs = nodes[i]->childs();
const CTNodes& childs = nodes[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) {
symbols.push_back (childs[j]->symbol());
nodes[i]->removeChild (childs[j]);
}
nodes[i]->removeAndDeleteAllChilds();
CTNode* prev = nodes[i];
assert (symbols.size() == nrSymbols);
for (unsigned j = 0; j < nrSymbols; j++) {
@ -768,7 +825,7 @@ ConstraintTree::ground (LogVar X)
ConstraintTrees cts;
const CTNodes& nodes = root_->childs();
for (unsigned i = 0; i < nodes.size(); i++) {
CTNode* copy = copySubtree (nodes[i]);
CTNode* copy = CTNode::copySubtree (nodes[i]);
copy->setSymbol (nodes[i]->symbol());
ConstraintTree* newCt = new ConstraintTree (logVars_);
newCt->root()->addChild (copy);
@ -884,7 +941,7 @@ ConstraintTree::join (
if (currIdx == tuple.size() - 1) {
const CTNodes& childs = appendNode->childs();
for (unsigned i = 0; i < childs.size(); i++) {
n->addChild (copySubtree (childs[i]));
n->addChild (CTNode::copySubtree (childs[i]));
}
return true;
}
@ -985,7 +1042,7 @@ ConstraintTree::countNormalize (
{
if (n->level() == stopLevel) {
return vector<pair<CTNode*, unsigned>>() = {
make_pair (copySubtree (n), countTuples (n))
make_pair (CTNode::copySubtree (n), countTuples (n))
};
}
@ -1004,65 +1061,6 @@ ConstraintTree::countNormalize (
}
/*
void
ConstraintTree::split (
CTNode* n1,
CTNode* n2,
CTNodes& nodes,
unsigned stopLevel)
{
CTNodes& childs1 = n1->childs();
CTNodes& childs2 = n2->childs();
// cout << string (n1->level() * 8, '-') << "Level = " << n1->level() + 1;
// cout << ", #I = " << childs1.size();
// cout << ", #J = " << childs2.size() << endl;
for (unsigned i = 0; i < childs1.size(); i++) {
for (unsigned j = 0; j < childs2.size(); j++) {
if (childs1[i]->symbol() != childs2[j]->symbol()) {
continue;
}
if (childs1[i]->level() == stopLevel) {
CTNode* newNode = copySubtree (childs1[i]);
newNode->setSymbol (childs1[i]->symbol());
nodes.push_back (newNode);
childs1[i]->setSymbol (Symbol::invalid());
break;
} else {
CTNodes lowerNodes;
split (childs1[i], childs2[j], lowerNodes, stopLevel);
if (lowerNodes.empty() == false) {
CTNode* me = new CTNode (childs1[i]->symbol(), childs1[i]->level());
for (unsigned k = 0; k < lowerNodes.size(); k++) {
me->addChild (lowerNodes[k]);
}
nodes.push_back (me);
}
if (childs1[i]->isLeaf()) {
break;
}
}
}
}
for (int i = 0; i < (int)childs1.size(); i++) {
// cout << string (n1->level() * 8, '-') << childs1[i];
if (childs1[i]->symbol() == Symbol::invalid()) {
// cout << " empty, removing..." ;
n1->removeChild (childs1[i]);
i --;
} else if (childs1[i]->isLeaf() &&
childs1[i]->level() != stopLevel) {
// cout << " leaf, removing..." ;
n1->removeChild (childs1[i]);
i --;
}
// cout << endl;
}
}
*/
void
ConstraintTree::split (
@ -1085,7 +1083,7 @@ ConstraintTree::split (
continue;
}
if (childs1[i]->level() == stopLevel) {
CTNode* newNode = copySubtree (childs1[i]);
CTNode* newNode = CTNode::copySubtree (childs1[i]);
nodes.push_back (newNode);
childs1[i]->setSymbol (Symbol::invalid());
} else {
@ -1103,11 +1101,11 @@ ConstraintTree::split (
for (int i = 0; i < (int)childs1.size(); i++) {
if (childs1[i]->symbol() == Symbol::invalid()) {
n1->removeChild (childs1[i]);
n1->removeAndDeleteChild (childs1[i]);
i --;
} else if (childs1[i]->isLeaf() &&
childs1[i]->level() != stopLevel) {
n1->removeChild (childs1[i]);
n1->removeAndDeleteChild (childs1[i]);
i --;
}
}
@ -1141,29 +1139,3 @@ ConstraintTree::overlap (
return false;
}
CTNode*
ConstraintTree::copySubtree (const CTNode* n)
{
CTNode* newNode = new CTNode (*n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
newNode->addChild (copySubtree (childs[i]));
}
return newNode;
}
void
ConstraintTree::deleteSubtree (CTNode* n)
{
assert (n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
deleteSubtree (childs[i]);
}
delete n;
}

View File

@ -21,7 +21,6 @@ typedef vector<ConstraintTree*> ConstraintTrees;
class CTNode
{
public:
@ -47,29 +46,42 @@ class CTNode
bool isLeaf (void) const { return childs_.empty(); }
void addChild (CTNode*, bool = true);
void removeChild (CTNode*);
SymbolSet childSymbols (void) const;
void addChild (CTNode*, bool = true);
void removeChild (CTNode*);
void removeAndDeleteChild (CTNode*);
void removeAndDeleteAllChilds (void);
SymbolSet childSymbols (void) const;
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
private:
void updateChildLevels (CTNode*, unsigned);
void updateChildLevels (CTNode*, unsigned);
Symbol symbol_;
CTNodes childs_;
unsigned level_;
Symbol symbol_;
CTNodes childs_;
unsigned level_;
};
ostream& operator<< (ostream &out, const CTNode&);
class ConstraintTree
{
public:
ConstraintTree (unsigned);
ConstraintTree (const LogVars&);
ConstraintTree (const LogVars&, const Tuples&);
ConstraintTree (const ConstraintTree&);
~ConstraintTree (void);
CTNode* root (void) const { return root_; }
@ -94,94 +106,95 @@ class ConstraintTree
assert (LogVarSet (logVars_) == logVarSet_);
}
void addTuple (const Tuple&);
bool containsTuple (const Tuple&);
void moveToTop (const LogVars&);
void moveToBottom (const LogVars&);
void join (ConstraintTree*, bool = false);
unsigned getLevel (LogVar) const;
void rename (LogVar, LogVar);
void applySubstitution (const Substitution&);
void project (const LogVarSet&);
void remove (const LogVarSet&);
bool isSingleton (LogVar);
LogVarSet singletons (void);
TupleSet tupleSet (unsigned = 0) const;
TupleSet tupleSet (const LogVars&);
unsigned size (void) const;
unsigned nrSymbols (LogVar);
void exportToGraphViz (const char*, bool = false) const;
bool isCountNormalized (const LogVarSet&);
unsigned getConditionalCount (const LogVarSet&);
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
bool isCarteesianProduct (const LogVarSet&) const;
void addTuple (const Tuple&);
bool containsTuple (const Tuple&);
void moveToTop (const LogVars&);
void moveToBottom (const LogVars&);
void join (ConstraintTree*, bool = false);
unsigned getLevel (LogVar) const;
void rename (LogVar, LogVar);
void applySubstitution (const Substitution&);
void project (const LogVarSet&);
void remove (const LogVarSet&);
bool isSingleton (LogVar);
LogVarSet singletons (void);
TupleSet tupleSet (unsigned = 0) const;
TupleSet tupleSet (const LogVars&);
unsigned size (void) const;
unsigned nrSymbols (LogVar);
void exportToGraphViz (const char*, bool = false) const;
bool isCountNormalized (const LogVarSet&);
unsigned getConditionalCount (const LogVarSet&);
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
bool isCarteesianProduct (const LogVarSet&) const;
std::pair<ConstraintTree*, ConstraintTree*> split (
const Tuple&,
unsigned);
const Tuple&, unsigned);
std::pair<ConstraintTree*, ConstraintTree*> split (
const ConstraintTree*,
unsigned) const;
const ConstraintTree*, unsigned) const;
ConstraintTrees countNormalize (const LogVarSet&);
ConstraintTrees countNormalize (const LogVarSet&);
ConstraintTrees jointCountNormalize (
ConstraintTree*,
ConstraintTree*,
LogVar,
LogVar,
LogVar);
ConstraintTree*, ConstraintTree*, LogVar, LogVar, LogVar);
static bool identical (
const ConstraintTree*,
const ConstraintTree*,
unsigned);
static bool identical (
const ConstraintTree*, const ConstraintTree*, unsigned);
static bool overlap (
const ConstraintTree*,
const ConstraintTree*,
unsigned);
static bool overlap (
const ConstraintTree*, const ConstraintTree*, unsigned);
LogVars expand (LogVar);
ConstraintTrees ground (LogVar);
LogVars expand (LogVar);
ConstraintTrees ground (LogVar);
private:
unsigned countTuples (const CTNode*) const;
CTNodes getNodesBelow (CTNode*) const;
CTNodes getNodesAtLevel (unsigned) const;
void swapLogVar (LogVar);
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
unsigned countTuples (const CTNode*) const;
bool indenticalSubtrees (
const CTNode*,
const CTNode*,
bool) const;
CTNodes getNodesBelow (CTNode*) const;
void getTuples (
CTNode*,
Tuples,
unsigned,
Tuples&,
CTNodes&) const;
CTNodes getNodesAtLevel (unsigned) const;
void swapLogVar (LogVar);
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
bool indenticalSubtrees (
const CTNode*, const CTNode*, bool) const;
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
vector<std::pair<CTNode*, unsigned>> countNormalize (
const CTNode*,
unsigned);
const CTNode*, unsigned);
static void split (
CTNode*,
CTNode*,
CTNodes&,
unsigned);
static void split (
CTNode*, CTNode*, CTNodes&, unsigned);
static bool overlap (const CTNode*, const CTNode*, unsigned);
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
static bool overlap (const CTNode*, const CTNode*, unsigned);
CTNode* root_;
LogVars logVars_;
LogVarSet logVarSet_;
CTNode* root_;
LogVars logVars_;
LogVarSet logVarSet_;
};

View File

@ -1,45 +0,0 @@
#ifndef HORUS_DISTRIBUTION_H
#define HORUS_DISTRIBUTION_H
#include <vector>
#include "Horus.h"
//TODO die die die die die
using namespace std;
struct Distribution
{
public:
Distribution (int id)
{
this->id = id;
}
Distribution (const Params& params, int id = -1)
{
this->id = id;
this->params = params;
}
void updateParameters (const Params& params)
{
this->params = params;
}
bool shared (void)
{
return id != -1;
}
int id;
Params params;
private:
DISALLOW_COPY_AND_ASSIGN (Distribution);
};
#endif // HORUS_DISTRIBUTION_H

View File

@ -1,5 +1,7 @@
#include <limits>
#include <fstream>
#include "ElimGraph.h"
#include "BayesNet.h"

View File

@ -17,15 +17,15 @@ enum ElimHeuristic
};
class EgNode : public VarNode {
class EgNode : public VarNode
{
public:
EgNode (VarNode* var) : VarNode (var) { }
void addNeighbor (EgNode* n)
{
neighs_.push_back (n);
}
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
const vector<EgNode*>& neighbors (void) const { return neighs_; }
private:
vector<EgNode*> neighs_;
};
@ -35,6 +35,7 @@ class ElimGraph
{
public:
ElimGraph (const BayesNet&);
~ElimGraph (void);
void addEdge (EgNode* n1, EgNode* n2)
@ -43,13 +44,19 @@ class ElimGraph
n1->addNeighbor (n2);
n2->addNeighbor (n1);
}
void addNode (EgNode*);
EgNode* getEgNode (VarId) const;
VarIds getEliminatingOrder (const VarIds&);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
void setIndexes();
void addNode (EgNode*);
EgNode* getEgNode (VarId) const;
VarIds getEliminatingOrder (const VarIds&);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
void setIndexes();
static void setEliminationHeuristic (ElimHeuristic h)
{
@ -57,14 +64,19 @@ class ElimGraph
}
private:
EgNode* getLowestCostNode (void) const;
unsigned getNeighborsCost (const EgNode*) const;
unsigned getWeightCost (const EgNode*) const;
unsigned getFillCost (const EgNode*) const;
unsigned getWeightedFillCost (const EgNode*) const;
void connectAllNeighbors (const EgNode*);
bool neighbors (const EgNode*, const EgNode*) const;
EgNode* getLowestCostNode (void) const;
unsigned getNeighborsCost (const EgNode*) const;
unsigned getWeightCost (const EgNode*) const;
unsigned getFillCost (const EgNode*) const;
unsigned getWeightedFillCost (const EgNode*) const;
void connectAllNeighbors (const EgNode*);
bool neighbors (const EgNode*, const EgNode*) const;
vector<EgNode*> nodes_;
vector<bool> marked_;

View File

@ -8,7 +8,7 @@
#include "Factor.h"
#include "Indexer.h"
#include "Util.h"
Factor::Factor (const Factor& g)
@ -18,206 +18,73 @@ Factor::Factor (const Factor& g)
Factor::Factor (VarId vid, unsigned nStates)
Factor::Factor (VarId vid, unsigned nrStates)
{
varids_.push_back (vid);
ranges_.push_back (nStates);
dist_ = new Distribution (Params (nStates, 1.0));
args_.push_back (vid);
ranges_.push_back (nrStates);
params_.resize (nrStates, 1.0);
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (const VarNodes& vars)
{
int nParams = 1;
int nrParams = 1;
for (unsigned i = 0; i < vars.size(); i++) {
varids_.push_back (vars[i]->varId());
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
nParams *= vars[i]->nrStates();
nrParams *= vars[i]->nrStates();
}
// create a uniform distribution
double val = 1.0 / nParams;
dist_ = new Distribution (Params (nParams, val));
double val = 1.0 / nrParams;
params_.resize (nrParams, val);
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (VarId vid, unsigned nStates, const Params& params)
Factor::Factor (
VarId vid,
unsigned nrStates,
const Params& params)
{
varids_.push_back (vid);
ranges_.push_back (nStates);
dist_ = new Distribution (params);
}
Factor::Factor (VarNodes& vars, Distribution* dist)
{
for (unsigned i = 0; i < vars.size(); i++) {
varids_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
}
dist_ = dist;
}
Factor::Factor (const VarNodes& vars, const Params& params)
{
for (unsigned i = 0; i < vars.size(); i++) {
varids_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
}
dist_ = new Distribution (params);
}
Factor::Factor (const VarIds& vids,
const Ranges& ranges,
const Params& params)
{
varids_ = vids;
ranges_ = ranges;
dist_ = new Distribution (params);
}
Factor::~Factor (void)
{
if (dist_->shared() == false) {
delete dist_;
}
}
void
Factor::setParameters (const Params& params)
{
assert (dist_->params.size() == params.size());
dist_->params = params;
}
void
Factor::copyFromFactor (const Factor& g)
{
varids_ = g.getVarIds();
ranges_ = g.getRanges();
dist_ = new Distribution (g.getParameters());
}
void
Factor::multiply (const Factor& g)
{
if (varids_.size() == 0) {
copyFromFactor (g);
return;
}
const VarIds& g_varids = g.getVarIds();
const Ranges& g_ranges = g.getRanges();
const Params& g_params = g.getParameters();
if (varids_ == g_varids) {
// optimization: if the factors contain the same set of variables,
// we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) {
Util::add (dist_->params, g_params);
} else {
Util::multiply (dist_->params, g_params);
}
} else {
bool sharedVars = false;
vector<unsigned> gvarpos;
for (unsigned i = 0; i < g_varids.size(); i++) {
int idx = indexOf (g_varids[i]);
if (idx == -1) {
insertVariable (g_varids[i], g_ranges[i]);
gvarpos.push_back (varids_.size() - 1);
} else {
sharedVars = true;
gvarpos.push_back (idx);
}
}
if (sharedVars == false) {
// optimization: if the original factors doesn't have common variables,
// we don't need to marry the states of the common variables
unsigned count = 0;
for (unsigned i = 0; i < dist_->params.size(); i++) {
if (Globals::logDomain) {
dist_->params[i] += g_params[count];
} else {
dist_->params[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = gvarpos.size() - 1; j >= 0; j--) {
g_li += indexer[gvarpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
dist_->params[indexer] += g_params[g_li];
} else {
dist_->params[indexer] *= g_params[g_li];
}
++ indexer;
}
}
}
}
void
Factor::insertVariable (VarId varId, unsigned nrStates)
{
assert (indexOf (varId) == -1);
Params oldParams = dist_->params;
dist_->params.clear();
dist_->params.reserve (oldParams.size() * nrStates);
for (unsigned i = 0; i < oldParams.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
dist_->params.push_back (oldParams[i]);
}
}
varids_.push_back (varId);
args_.push_back (vid);
ranges_.push_back (nrStates);
params_ = params;
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
void
Factor::insertVariables (const VarIds& varIds, const Ranges& ranges)
Factor::Factor (
const VarNodes& vars,
const Params& params,
unsigned distId)
{
Params oldParams = dist_->params;
unsigned nrStates = 1;
for (unsigned i = 0; i < varIds.size(); i++) {
assert (indexOf (varIds[i]) == -1);
varids_.push_back (varIds[i]);
ranges_.push_back (ranges[i]);
nrStates *= ranges[i];
}
dist_->params.clear();
dist_->params.reserve (oldParams.size() * nrStates);
for (unsigned i = 0; i < oldParams.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
dist_->params.push_back (oldParams[i]);
}
for (unsigned i = 0; i < vars.size(); i++) {
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
}
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (
const VarIds& vids,
const Ranges& ranges,
const Params& params)
{
args_ = vids;
ranges_ = ranges;
params_ = params;
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
@ -226,10 +93,10 @@ void
Factor::sumOutAllExcept (VarId vid)
{
assert (indexOf (vid) != -1);
while (varids_.back() != vid) {
while (args_.back() != vid) {
sumOutLastVariable();
}
while (varids_.front() != vid) {
while (args_.front() != vid) {
sumOutFirstVariable();
}
}
@ -239,9 +106,10 @@ Factor::sumOutAllExcept (VarId vid)
void
Factor::sumOutAllExcept (const VarIds& vids)
{
for (unsigned i = 0; i < varids_.size(); i++) {
if (std::find (vids.begin(), vids.end(), varids_[i]) == vids.end()) {
sumOut (varids_[i]);
for (int i = 0; i < (int)args_.size(); i++) {
if (Util::contains (vids, args_[i]) == false) {
sumOut (args_[i]);
i --;
}
}
}
@ -254,11 +122,11 @@ Factor::sumOut (VarId vid)
int idx = indexOf (vid);
assert (idx != -1);
if (vid == varids_.back()) {
if (vid == args_.back()) {
sumOutLastVariable(); // optimization
return;
}
if (vid == varids_.front()) {
if (vid == args_.front()) {
sumOutFirstVariable(); // optimization
return;
}
@ -271,7 +139,7 @@ Factor::sumOut (VarId vid)
// on the left of `var', with the states of the remaining vars fixed
unsigned leftVarOffset = 1;
for (int i = varids_.size() - 1; i > idx; i--) {
for (int i = args_.size() - 1; i > idx; i--) {
varOffset *= ranges_[i];
leftVarOffset *= ranges_[i];
}
@ -280,25 +148,24 @@ Factor::sumOut (VarId vid)
unsigned offset = 0;
unsigned count1 = 0;
unsigned count2 = 0;
unsigned newpsSize = dist_->params.size() / ranges_[idx];
unsigned newpsSize = params_.size() / ranges_[idx];
Params newps;
newps.reserve (newpsSize);
Params& params = dist_->params;
while (newps.size() < newpsSize) {
double sum = Util::addIdenty();
double sum = LogAware::addIdenty();
for (unsigned i = 0; i < ranges_[idx]; i++) {
if (Globals::logDomain) {
Util::logSum (sum, params[offset]);
sum = Util::logSum (sum, params_[offset]);
} else {
sum += params[offset];
sum += params_[offset];
}
offset += varOffset;
}
newps.push_back (sum);
count1 ++;
if (idx == (int)varids_.size() - 1) {
if (idx == (int)args_.size() - 1) {
offset = count1 * ranges_[idx];
} else {
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
@ -308,9 +175,9 @@ Factor::sumOut (VarId vid)
offset = (leftVarOffset * count2) + count1;
}
}
varids_.erase (varids_.begin() + idx);
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
dist_->params = newps;
params_ = newps;
}
@ -318,20 +185,19 @@ Factor::sumOut (VarId vid)
void
Factor::sumOutFirstVariable (void)
{
Params& params = dist_->params;
unsigned nStates = ranges_.front();
unsigned sep = params.size() / nStates;
unsigned sep = params_.size() / nStates;
if (Globals::logDomain) {
for (unsigned i = sep; i < params.size(); i++) {
Util::logSum (params[i % sep], params[i]);
for (unsigned i = sep; i < params_.size(); i++) {
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
}
} else {
for (unsigned i = sep; i < params.size(); i++) {
params[i % sep] += params[i];
for (unsigned i = sep; i < params_.size(); i++) {
params_[i % sep] += params_[i];
}
}
params.resize (sep);
varids_.erase (varids_.begin());
params_.resize (sep);
args_.erase (args_.begin());
ranges_.erase (ranges_.begin());
}
@ -340,143 +206,55 @@ Factor::sumOutFirstVariable (void)
void
Factor::sumOutLastVariable (void)
{
Params& params = dist_->params;
unsigned nStates = ranges_.back();
unsigned idx1 = 0;
unsigned idx2 = 0;
if (Globals::logDomain) {
while (idx1 < params.size()) {
params[idx2] = params[idx1];
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
idx1 ++;
for (unsigned j = 1; j < nStates; j++) {
Util::logSum (params[idx2], params[idx1]);
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
idx1 ++;
}
idx2 ++;
}
} else {
while (idx1 < params.size()) {
params[idx2] = params[idx1];
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
idx1 ++;
for (unsigned j = 1; j < nStates; j++) {
params[idx2] += params[idx1];
params_[idx2] += params_[idx1];
idx1 ++;
}
idx2 ++;
}
}
params.resize (idx2);
varids_.pop_back();
params_.resize (idx2);
args_.pop_back();
ranges_.pop_back();
}
void
Factor::orderVariables (void)
Factor::multiply (Factor& g)
{
VarIds sortedVarIds = varids_;
sort (sortedVarIds.begin(), sortedVarIds.end());
reorderVariables (sortedVarIds);
}
void
Factor::reorderVariables (const VarIds& newVarIds)
{
assert (newVarIds.size() == varids_.size());
if (newVarIds == varids_) {
if (args_.size() == 0) {
copyFromFactor (g);
return;
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newVarIds.size(); i++) {
unsigned idx = indexOf (newVarIds[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (dist_->params.size());
for (unsigned i = 0; i < dist_->params.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N-1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = dist_->params[i];
}
varids_ = newVarIds;
ranges_ = newRanges;
dist_->params = newParams;
TFactor<VarId>::multiply (g);
}
void
Factor::absorveEvidence (VarId vid, unsigned evidence)
Factor::reorderAccordingVarIds (void)
{
int idx = indexOf (vid);
assert (idx != -1);
Params oldParams = dist_->params;
dist_->params.clear();
dist_->params.reserve (oldParams.size() / ranges_[idx]);
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx);
}
while (indexer.valid()) {
dist_->params.push_back (oldParams[indexer]);
indexer.incrementExcluding (idx);
}
varids_.erase (varids_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void
Factor::normalize (void)
{
Util::normalize (dist_->params);
}
bool
Factor::contains (const VarIds& vars) const
{
for (unsigned i = 0; i < vars.size(); i++) {
if (indexOf (vars[i]) == -1) {
return false;
}
}
return true;
}
int
Factor::indexOf (VarId vid) const
{
for (unsigned i = 0; i < varids_.size(); i++) {
if (varids_[i] == vid) {
return i;
}
}
return -1;
VarIds sortedVarIds = args_;
sort (sortedVarIds.begin(), sortedVarIds.end());
reorderArguments (sortedVarIds);
}
@ -486,9 +264,9 @@ Factor::getLabel (void) const
{
stringstream ss;
ss << "f(" ;
for (unsigned i = 0; i < varids_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
ss << VarNode (varids_[i], ranges_[i]).label();
ss << VarNode (args_[i], ranges_[i]).label();
}
ss << ")" ;
return ss.str();
@ -500,13 +278,13 @@ void
Factor::print (void) const
{
VarNodes vars;
for (unsigned i = 0; i < varids_.size(); i++) {
vars.push_back (new VarNode (varids_[i], ranges_[i]));
for (unsigned i = 0; i < args_.size(); i++) {
vars.push_back (new VarNode (args_[i], ranges_[i]));
}
vector<string> jointStrings = Util::getJointStateStrings (vars);
for (unsigned i = 0; i < dist_->params.size(); i++) {
for (unsigned i = 0; i < params_.size(); i++) {
cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << dist_->params[i] << endl;
cout << " = " << params_[i] << endl;
}
cout << endl;
for (unsigned i = 0; i < vars.size(); i++) {
@ -515,3 +293,13 @@ Factor::print (void) const
}
void
Factor::copyFromFactor (const Factor& g)
{
args_ = g.arguments();
ranges_ = g.ranges();
params_ = g.params();
distId_ = g.distId();
}

View File

@ -3,64 +3,293 @@
#include <vector>
#include "Distribution.h"
#include "VarNode.h"
#include "Indexer.h"
#include "Util.h"
using namespace std;
class Distribution;
template <typename T>
class TFactor
{
public:
const vector<T>& arguments (void) const { return args_; }
vector<T>& arguments (void) { return args_; }
const Ranges& ranges (void) const { return ranges_; }
const Params& params (void) const { return params_; }
Params& params (void) { return params_; }
unsigned nrArguments (void) const { return args_.size(); }
unsigned size (void) const { return params_.size(); }
unsigned distId (void) const { return distId_; }
void setDistId (unsigned id) { distId_ = id; }
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_));
}
void normalize (void)
{
LogAware::normalize (params_);
}
int indexOf (const T& t) const
{
int idx = -1;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i] == t) {
idx = i;
break;
}
}
return idx;
}
const T& argument (unsigned idx) const
{
assert (idx < args_.size());
return args_[idx];
}
T& argument (unsigned idx)
{
assert (idx < args_.size());
return args_[idx];
}
unsigned range (unsigned idx) const
{
assert (idx < ranges_.size());
return ranges_[idx];
}
void multiply (TFactor<T>& g)
{
const vector<T>& g_args = g.arguments();
const Ranges& g_ranges = g.ranges();
const Params& g_params = g.params();
if (args_ == g_args) {
// optimization: if the factors contain the same set of args,
// we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) {
Util::add (params_, g_params);
} else {
Util::multiply (params_, g_params);
}
} else {
bool sharedArgs = false;
vector<unsigned> gvarpos;
for (unsigned i = 0; i < g_args.size(); i++) {
int idx = indexOf (g_args[i]);
if (idx == -1) {
insertArgument (g_args[i], g_ranges[i]);
gvarpos.push_back (args_.size() - 1);
} else {
sharedArgs = true;
gvarpos.push_back (idx);
}
}
if (sharedArgs == false) {
// optimization: if the original factors doesn't have common args,
// we don't need to marry the states of the common args
unsigned count = 0;
for (unsigned i = 0; i < params_.size(); i++) {
if (Globals::logDomain) {
params_[i] += g_params[count];
} else {
params_[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = gvarpos.size() - 1; j >= 0; j--) {
g_li += indexer[gvarpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
params_[indexer] += g_params[g_li];
} else {
params_[indexer] *= g_params[g_li];
}
++ indexer;
}
}
}
}
void absorveEvidence (const T& arg, unsigned evidence)
{
int idx = indexOf (arg);
assert (idx != -1);
assert (evidence < ranges_[idx]);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() / ranges_[idx]);
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx);
}
while (indexer.valid()) {
params_.push_back (copy[indexer]);
indexer.incrementExcluding (idx);
}
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void reorderArguments (const vector<T> newArgs)
{
assert (newArgs.size() == args_.size());
if (newArgs == args_) {
return; // already in the wanted order
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newArgs.size(); i++) {
unsigned idx = indexOf (newArgs[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (params_.size());
for (unsigned i = 0; i < params_.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N - 1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = params_[i];
}
args_ = newArgs;
ranges_ = newRanges;
params_ = newParams;
}
bool contains (const T& arg) const
{
return Util::contains (args_, arg);
}
bool contains (const vector<T>& args) const
{
for (unsigned i = 0; i < args_.size(); i++) {
if (contains (args[i]) == false) {
return false;
}
}
return true;
}
protected:
vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void insertArgument (const T& arg, unsigned range)
{
assert (indexOf (arg) == -1);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() * range);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < range; reps++) {
params_.push_back (copy[i]);
}
}
args_.push_back (arg);
ranges_.push_back (range);
}
void insertArguments (const vector<T>& args, const Ranges& ranges)
{
Params copy = params_;
unsigned nrStates = 1;
for (unsigned i = 0; i < args.size(); i++) {
assert (indexOf (args[i]) == -1);
args_.push_back (args[i]);
ranges_.push_back (ranges[i]);
nrStates *= ranges[i];
}
params_.clear();
params_.reserve (copy.size() * nrStates);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
params_.push_back (copy[i]);
}
}
}
};
class Factor
class Factor : public TFactor<VarId>
{
public:
Factor (void) { }
Factor (const Factor&);
Factor (VarId, unsigned);
Factor (const VarNodes&);
Factor (VarId, unsigned, const Params&);
Factor (VarNodes&, Distribution*);
Factor (const VarNodes&, const Params&);
Factor (const VarNodes&, const Params&,
unsigned = Util::maxUnsigned());
Factor (const VarIds&, const Ranges&, const Params&);
~Factor (void);
void setParameters (const Params&);
void copyFromFactor (const Factor& f);
void multiply (const Factor&);
void insertVariable (VarId, unsigned);
void insertVariables (const VarIds&, const Ranges&);
void sumOutAllExcept (VarId);
void sumOutAllExcept (const VarIds&);
void sumOut (VarId);
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void orderVariables (void);
void reorderVariables (const VarIds&);
void absorveEvidence (VarId, unsigned);
void normalize (void);
bool contains (const VarIds&) const;
int indexOf (VarId) const;
string getLabel (void) const;
void print (void) const;
void sumOutAllExcept (VarId);
const VarIds& getVarIds (void) const { return varids_; }
const Ranges& getRanges (void) const { return ranges_; }
const Params& getParameters (void) const { return dist_->params; }
Distribution* getDistribution (void) const { return dist_; }
unsigned nrVariables (void) const { return varids_.size(); }
unsigned nrParameters() const { return dist_->params.size(); }
void sumOutAllExcept (const VarIds&);
void setDistribution (Distribution* dist)
{
dist_ = dist;
}
void sumOut (VarId);
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void multiply (Factor&);
void reorderAccordingVarIds (void);
string getLabel (void) const;
void print (void) const;
private:
void copyFromFactor (const Factor& f);
VarIds varids_;
Ranges ranges_;
Distribution* dist_;
};
#endif // HORUS_FACTOR_H

View File

@ -53,10 +53,10 @@ FactorGraph::FactorGraph (const BayesNet& bn)
neighs.push_back (varNodes_[parents[j]->getIndex()]);
}
FgFacNode* fn = new FgFacNode (
new Factor (neighs, nodes[i]->getDistribution()));
new Factor (neighs, nodes[i]->params(), nodes[i]->distId()));
if (orderFactorVariables) {
sort (neighs.begin(), neighs.end(), CompVarId());
fn->factor()->orderVariables();
fn->factor()->reorderAccordingVarIds();
}
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
@ -131,10 +131,10 @@ FactorGraph::readFromUaiFormat (const char* fileName)
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nParams;
is >> nParams;
if (facNodes_[i]->getParameters().size() != nParams) {
if (facNodes_[i]->params().size() != nParams) {
cerr << "error: invalid number of parameters for factor " ;
cerr << facNodes_[i]->getLabel() ;
cerr << ", expected: " << facNodes_[i]->getParameters().size();
cerr << ", expected: " << facNodes_[i]->params().size();
cerr << ", given: " << nParams << endl;
abort();
}
@ -147,7 +147,7 @@ FactorGraph::readFromUaiFormat (const char* fileName)
if (Globals::logDomain) {
Util::toLog (params);
}
facNodes_[i]->factor()->setParameters (params);
facNodes_[i]->factor()->setParams (params);
}
is.close();
setIndexes();
@ -335,21 +335,6 @@ FactorGraph::setIndexes (void)
void
FactorGraph::freeDistributions (void)
{
set<Distribution*> dists;
for (unsigned i = 0; i < facNodes_.size(); i++) {
dists.insert (facNodes_[i]->factor()->getDistribution());
}
for (set<Distribution*>::iterator it = dists.begin();
it != dists.end(); it++) {
delete *it;
}
}
void
FactorGraph::printGraphicalModel (void) const
{
@ -440,7 +425,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->getParameters();
Params params = facNodes_[i]->params();
if (Globals::logDomain) {
Util::fromLog (params);
}
@ -477,7 +462,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
out << factorVars[j]->nrStates() << " " ;
}
out << endl;
Params params = facNodes_[i]->factor()->getParameters();
Params params = facNodes_[i]->factor()->params();
if (Globals::logDomain) {
Util::fromLog (params);
}

View File

@ -4,7 +4,6 @@
#include <vector>
#include "GraphicalModel.h"
#include "Distribution.h"
#include "Factor.h"
#include "Horus.h"
@ -13,18 +12,21 @@ using namespace std;
class BayesNet;
class FgFacNode;
class FgVarNode : public VarNode
{
public:
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
FgVarNode (const VarNode* v) : VarNode (v) { }
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
const FgFacSet& neighbors (void) const { return neighs_; }
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
const FgFacSet& neighbors (void) const { return neighs_; }
private:
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
// members
FgFacSet neighs_;
};
@ -32,13 +34,18 @@ class FgVarNode : public VarNode
class FgFacNode
{
public:
FgFacNode (const FgFacNode* fn) {
FgFacNode (const FgFacNode* fn)
{
factor_ = new Factor (*fn->factor());
index_ = -1;
}
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
Factor* factor() const { return factor_; }
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
Factor* factor() const { return factor_; }
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
const FgVarSet& neighbors (void) const { return neighs_; }
int getIndex (void) const
@ -46,28 +53,28 @@ class FgFacNode
assert (index_ != -1);
return index_;
}
void setIndex (int index)
{
index_ = index;
}
Distribution* getDistribution (void)
const Params& params (void) const
{
return factor_->getDistribution();
}
const Params& getParameters (void) const
{
return factor_->getParameters();
return factor_->params();
}
string getLabel (void)
{
return factor_->getLabel();
}
private:
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
Factor* factor_;
int index_;
FgVarSet neighs_;
Factor* factor_;
FgVarSet neighs_;
int index_;
};
@ -83,29 +90,17 @@ struct CompVarId
class FactorGraph : public GraphicalModel
{
public:
FactorGraph (void) {};
FactorGraph (void) { };
FactorGraph (const FactorGraph&);
FactorGraph (const BayesNet&);
~FactorGraph (void);
void readFromUaiFormat (const char*);
void readFromLibDaiFormat (const char*);
void addVariable (FgVarNode*);
void addFactor (FgFacNode*);
void addEdge (FgVarNode*, FgFacNode*);
void addEdge (FgFacNode*, FgVarNode*);
VarNode* getVariableNode (unsigned) const;
VarNodes getVariableNodes (void) const;
bool isTree (void) const;
void setIndexes (void);
void freeDistributions (void);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*) const;
void exportToUaiFormat (const char*) const;
void exportToLibDaiFormat (const char*) const;
const FgVarSet& getVarNodes (void) const { return varNodes_; }
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
const FgVarSet& getVarNodes (void) const { return varNodes_; }
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
FgVarNode* getFgVarNode (VarId vid) const
{
@ -117,21 +112,52 @@ class FactorGraph : public GraphicalModel
}
}
void readFromUaiFormat (const char*);
void readFromLibDaiFormat (const char*);
void addVariable (FgVarNode*);
void addFactor (FgFacNode*);
void addEdge (FgVarNode*, FgFacNode*);
void addEdge (FgFacNode*, FgVarNode*);
VarNode* getVariableNode (unsigned) const;
VarNodes getVariableNodes (void) const;
bool isTree (void) const;
void setIndexes (void);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*) const;
void exportToUaiFormat (const char*) const;
void exportToLibDaiFormat (const char*) const;
static bool orderFactorVariables;
private:
//DISALLOW_COPY_AND_ASSIGN (FactorGraph);
bool containsCycle (void) const;
bool containsCycle (const FgVarNode*, const FgFacNode*,
vector<bool>&, vector<bool>&) const;
bool containsCycle (const FgFacNode*, const FgVarNode*,
vector<bool>&, vector<bool>&) const;
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
FgVarSet varNodes_;
FgFacSet facNodes_;
bool containsCycle (void) const;
bool containsCycle (const FgVarNode*, const FgFacNode*,
vector<bool>&, vector<bool>&) const;
bool containsCycle (const FgFacNode*, const FgVarNode*,
vector<bool>&, vector<bool>&) const;
FgVarSet varNodes_;
FgFacSet facNodes_;
typedef unordered_map<unsigned, unsigned> IndexMap;
IndexMap varMap_;
IndexMap varMap_;
};
#endif // HORUS_FACTORGRAPH_H

View File

@ -38,11 +38,11 @@ void
FgBpSolver::runSolver (void)
{
clock_t start;
if (COLLECT_STATISTICS) {
if (Constants::COLLECT_STATS) {
start = clock();
}
runLoopySolver();
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
@ -53,18 +53,13 @@ FgBpSolver::runSolver (void)
}
}
unsigned size = factorGraph_->getVarNodes().size();
if (COLLECT_STATISTICS) {
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
stringstream ss;
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
factorGraph_->exportToGraphViz (ss.str().c_str());
}
}
@ -76,22 +71,22 @@ FgBpSolver::getPosterioriOf (VarId vid)
FgVarNode* var = factorGraph_->getFgVarNode (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->nrStates(), Util::noEvidence());
probs[var->getEvidence()] = Util::withEvidence();
probs.resize (var->nrStates(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->nrStates(), Util::multIdenty());
probs.resize (var->nrStates(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
Util::add (probs, links[i]->getMessage());
}
Util::normalize (probs);
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
Util::multiply (probs, links[i]->getMessage());
}
Util::normalize (probs);
LogAware::normalize (probs);
}
}
return probs;
@ -102,9 +97,9 @@ FgBpSolver::getPosterioriOf (VarId vid)
Params
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
int idx = -1;
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
const FgFacSet& factorNodes = vn->neighbors();
int idx = -1;
for (unsigned i = 0; i < factorNodes.size(); i++) {
if (factorNodes[i]->factor()->contains (jointVarIds)) {
idx = i;
@ -114,18 +109,18 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == -1) {
return getJointByConditioning (jointVarIds);
} else {
Factor r (*factorNodes[idx]->factor());
Factor res (*factorNodes[idx]->factor());
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
Factor msg (links[i]->getVariable()->varId(),
links[i]->getVariable()->nrStates(),
getVar2FactorMsg (links[i]));
r.multiply (msg);
res.multiply (msg);
}
r.sumOutAllExcept (jointVarIds);
r.reorderVariables (jointVarIds);
r.normalize();
Params jointDist = r.getParameters();
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::fromLog (jointDist);
}
@ -144,13 +139,8 @@ FgBpSolver::runLoopySolver (void)
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (DL >= 2) {
cout << "****************************************" ;
cout << "****************************************" ;
cout << endl;
cout << " Iteration " << nIters_ << endl;
cout << "****************************************" ;
cout << "****************************************" ;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
}
@ -178,7 +168,7 @@ FgBpSolver::runLoopySolver (void)
maxResidualSchedule();
break;
}
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
@ -256,12 +246,12 @@ FgBpSolver::converged (void)
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (DL == 0) break;
if (Constants::DEBUG == 0) break;
}
}
}
@ -283,7 +273,7 @@ FgBpSolver::maxResidualSchedule (void)
}
for (unsigned c = 0; c < links_.size(); c++) {
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
@ -317,9 +307,8 @@ FgBpSolver::maxResidualSchedule (void)
}
}
}
if (DL >= 2) {
cout << "----------------------------------------" ;
cout << "----------------------------------------" << endl;
if (Constants::DEBUG >= 2) {
Util::printDashedLine();
}
}
}
@ -339,7 +328,7 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
msgSize *= links[i]->getVariable()->nrStates();
}
unsigned repetitions = 1;
Params msgProduct (msgSize, Util::multIdenty());
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) {
@ -354,7 +343,7 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
} else {
for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) {
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " << endl;
}
@ -368,34 +357,29 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
}
}
Factor result (src->factor()->getVarIds(),
src->factor()->getRanges(),
Factor result (src->factor()->arguments(),
src->factor()->ranges(),
msgProduct);
result.multiply (*(src->factor()));
if (DL >= 5) {
cout << " message product: " ;
cout << Util::parametersToString (msgProduct) << endl;
cout << " original factor: " ;
cout << Util::parametersToString (src->getParameters()) << endl;
cout << " factor product: " ;
cout << Util::parametersToString (result.getParameters()) << endl;
if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExcept (dst->varId());
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " marginalized: " ;
cout << Util::parametersToString (result.getParameters()) << endl;
cout << result.params() << endl;
}
const Params& resultParams = result.getParameters();
const Params& resultParams = result.params();
Params& message = link->getNextMessage();
for (unsigned i = 0; i < resultParams.size(); i++) {
message[i] = resultParams[i];
}
Util::normalize (message);
if (DL >= 5) {
cout << " curr msg: " ;
cout << Util::parametersToString (link->getMessage()) << endl;
cout << " next msg: " ;
cout << Util::parametersToString (message) << endl;
LogAware::normalize (message);
if (Constants::DEBUG >= 5) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << message << endl;
}
}
@ -408,16 +392,16 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
const FgFacNode* dst = link->getFactor();
Params msg;
if (src->hasEvidence()) {
msg.resize (src->nrStates(), Util::noEvidence());
msg[src->getEvidence()] = Util::withEvidence();
if (DL >= 5) {
cout << Util::parametersToString (msg);
msg.resize (src->nrStates(), LogAware::noEvidence());
msg[src->getEvidence()] = LogAware::withEvidence();
if (Constants::DEBUG >= 5) {
cout << msg;
}
} else {
msg.resize (src->nrStates(), Util::one());
msg.resize (src->nrStates(), LogAware::one());
}
if (DL >= 5) {
cout << Util::parametersToString (msg);
if (Constants::DEBUG >= 5) {
cout << msg;
}
const SpLinkSet& links = ninf (src)->getLinks();
if (Globals::logDomain) {
@ -430,14 +414,14 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
Util::multiply (msg, links[i]->getMessage());
if (DL >= 5) {
cout << " x " << Util::parametersToString (links[i]->getMessage());
if (Constants::DEBUG >= 5) {
cout << " x " << links[i]->getMessage();
}
}
}
}
if (DL >= 5) {
cout << " = " << Util::parametersToString (msg);
if (Constants::DEBUG >= 5) {
cout << " = " << msg;
}
return msg;
}
@ -503,9 +487,9 @@ FgBpSolver::printLinkInformation (void) const
SpLink* l = links_[i];
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << Util::parametersToString (l->getMessage()) << endl;
cout << l->getMessage() << endl;
cout << " next msg = " ;
cout << Util::parametersToString (l->getNextMessage()) << endl;
cout << l->getNextMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -13,7 +13,6 @@
using namespace std;
class SpLink
{
public:
@ -21,15 +20,34 @@ class SpLink
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
v1_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
v2_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
msgSended_ = false;
residual_ = 0.0;
}
virtual ~SpLink (void) {};
virtual ~SpLink (void) { };
FgFacNode* getFactor (void) const { return fac_; }
FgVarNode* getVariable (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
void updateResidual (void)
{
residual_ = LogAware::getMaxNorm (v1_,v2_);
}
virtual void updateMessage (void)
{
@ -37,11 +55,6 @@ class SpLink
msgSended_ = true;
}
void updateResidual (void)
{
residual_ = Util::getMaxNorm (v1_, v2_);
}
string toString (void) const
{
stringstream ss;
@ -50,38 +63,28 @@ class SpLink
ss << var_->label();
return ss.str();
}
FgFacNode* getFactor (void) const { return fac_; }
FgVarNode* getVariable (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
protected:
FgFacNode* fac_;
FgVarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
FgFacNode* fac_;
FgVarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
};
typedef vector<SpLink*> SpLinkSet;
class SPNodeInfo
{
public:
void addSpLink (SpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; }
void addSpLink (SpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; }
private:
SpLinkSet links_;
SpLinkSet links_;
};
@ -89,51 +92,29 @@ class FgBpSolver : public Solver
{
public:
FgBpSolver (const FactorGraph&);
virtual ~FgBpSolver (void);
void runSolver (void);
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
void runSolver (void);
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
protected:
virtual void initializeSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*) const;
virtual Params getVar2FactorMsg (const SpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
virtual void printLinkInformation (void) const;
virtual void initializeSolver (void);
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
virtual void createLinks (void);
void calculateMessage (SpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
virtual void maxResidualSchedule (void);
void updateMessage (SpLink* link)
{
link->updateMessage();
if (DL >= 3) {
cout << "updating " << link->toString() << endl;
}
}
virtual void calculateFactor2VariableMsg (SpLink*) const;
virtual Params getVar2FactorMsg (const SpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
virtual void printLinkInformation (void) const;
SPNodeInfo* ninf (const FgVarNode* var) const
{
@ -145,7 +126,39 @@ class FgBpSolver : public Solver
return facsI_[fac->getIndex()];
}
struct CompareResidual {
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (SpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
cout << "calculating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (SpLink* link)
{
link->updateMessage();
if (Constants::DEBUG >= 3) {
cout << "updating " << link->toString() << endl;
}
}
struct CompareResidual
{
inline bool operator() (const SpLink* link1, const SpLink* link2)
{
return link1->getResidual() > link2->getResidual();
@ -165,10 +178,8 @@ class FgBpSolver : public Solver
SpLinkMap linkMap_;
private:
void runLoopySolver (void);
bool converged (void);
void runLoopySolver (void);
bool converged (void);
};
#endif // HORUS_FGBPSOLVER_H

View File

@ -8,7 +8,9 @@
vector<LiftedOperator*>
LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
LiftedOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<LiftedOperator*> validOps;
vector<SumOutOperator*> sumOutOps;
@ -28,12 +30,15 @@ LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
void
LiftedOperator::printValidOps (ParfactorList& pfList, const Grounds& query)
LiftedOperator::printValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList, query);
for (unsigned i = 0; i < validOps.size(); i++) {
cout << "-> " << validOps[i]->toString() << endl;
delete validOps[i];
}
}
@ -56,14 +61,14 @@ SumOutOperator::getCost (void)
pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (groupSet[i])) {
int idx = (*pfIter)->indexOfFormulaWithGroup (groupSet[i]);
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
cost *= (*pfIter)->range (idx);
break;
}
++ pfIter;
}
}
return cost;
return cost;
}
@ -77,14 +82,13 @@ SumOutOperator::apply (void)
pfList_.remove (iters[0]);
for (unsigned i = 1; i < iters.size(); i++) {
product->multiply (**(iters[i]));
delete *(iters[i]);
pfList_.remove (iters[i]);
pfList_.removeAndDelete (iters[i]);
}
if (product->nrFormulas() == 1) {
if (product->nrArguments() == 1) {
delete product;
return;
}
int fIdx = product->indexOfFormulaWithGroup (group_);
int fIdx = product->indexOfGroup (group_);
LogVarSet excl = product->exclusiveLogVars (fIdx);
if (product->constr()->isCountNormalized (excl)) {
product->sumOut (fIdx);
@ -96,21 +100,21 @@ SumOutOperator::apply (void)
pfList_.add (pfs[i]);
}
delete product;
pfList_.shatter();
}
}
vector<SumOutOperator*>
SumOutOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
SumOutOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<SumOutOperator*> validOps;
set<unsigned> allGroups;
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
assert (*it);
const ProbFormulas& formulas = (*it)->formulas();
const ProbFormulas& formulas = (*it)->arguments();
for (unsigned i = 0; i < formulas.size(); i++) {
allGroups.insert (formulas[i].group());
}
@ -134,8 +138,8 @@ SumOutOperator::toString (void)
stringstream ss;
vector<ParfactorList::iterator> pfIters;
pfIters = parfactorsWithGroup (pfList_, group_);
int idx = (*pfIters[0])->indexOfFormulaWithGroup (group_);
ProbFormula f = (*pfIters[0])->formula (idx);
int idx = (*pfIters[0])->indexOfGroup (group_);
ProbFormula f = (*pfIters[0])->argument (idx);
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
@ -158,9 +162,9 @@ SumOutOperator::validOp (
}
unordered_map<unsigned, unsigned> groupToRange;
for (unsigned i = 0; i < pfIters.size(); i++) {
int fIdx = (*pfIters[i])->indexOfFormulaWithGroup (group);
if ((*pfIters[i])->formulas()[fIdx].contains (
(*pfIters[i])->elimLogVars()) == false) {
int fIdx = (*pfIters[i])->indexOfGroup (group);
if ((*pfIters[i])->argument (fIdx).contains (
(*pfIters[i])->elimLogVars()) == false) {
return false;
}
vector<unsigned> ranges = (*pfIters[i])->ranges();
@ -206,8 +210,8 @@ SumOutOperator::isToEliminate (
unsigned group,
const Grounds& query)
{
int fIdx = g->indexOfFormulaWithGroup (group);
const ProbFormula& formula = g->formula (fIdx);
int fIdx = g->indexOfGroup (group);
const ProbFormula& formula = g->argument (fIdx);
bool toElim = true;
for (unsigned i = 0; i < query.size(); i++) {
if (formula.functor() == query[i].functor() &&
@ -228,7 +232,7 @@ unsigned
CountingOperator::getCost (void)
{
unsigned cost = 0;
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned range = (*pfIter_)->range (fIdx);
unsigned size = (*pfIter_)->size() / range;
TinySet<unsigned> counts;
@ -247,18 +251,19 @@ CountingOperator::apply (void)
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
(*pfIter_)->countConvert (X_);
} else {
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
for (unsigned i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
(*pfIter_)->countedLogVars() | X_);
pfs[i]->countedLogVars() | X_);
if (condCount > 1 && cartProduct) {
pfs[i]->countConvert (X_);
}
pfList_.add (pfs[i]);
}
pfList_.deleteAndRemove (pfIter_);
pfList_.shatter();
delete pf;
}
}
@ -289,14 +294,17 @@ CountingOperator::toString (void)
{
stringstream ss;
ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getHeaderString();
ss << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (unsigned i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getHeaderString() << endl;
ss << " º " << pfs[i]->getLabel() << endl;
}
}
for (unsigned i = 0; i < pfs.size(); i++) {
delete pfs[i];
}
return ss.str();
}
@ -308,8 +316,8 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
if (g->nrFormulas (X) != 1) {
return false;
}
int fIdx = g->indexOfFormulaWithLogVar (X);
if (g->formulas()[fIdx].isCounting()) {
int fIdx = g->indexOfLogVar (X);
if (g->argument (fIdx).isCounting()) {
return false;
}
bool countNormalized = g->constr()->isCountNormalized (X);
@ -332,10 +340,10 @@ GroundOperator::getCost (void)
unsigned cost = 0;
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
if (isCountingLv) {
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned currSize = (*pfIter_)->size();
unsigned nrHists = (*pfIter_)->range (fIdx);
unsigned range = (*pfIter_)->formula(fIdx).range();
unsigned range = (*pfIter_)->argument (fIdx).range();
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
} else {
@ -350,18 +358,17 @@ void
GroundOperator::apply (void)
{
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
if (countedLv) {
(*pfIter_)->fullExpand (X_);
(*pfIter_)->setNewGroups();
pfList_.shatter();
pf->fullExpand (X_);
pfList_.add (pf);
} else {
ConstraintTrees cts = (*pfIter_)->constr()->ground (X_);
ConstraintTrees cts = pf->constr()->ground (X_);
for (unsigned i = 0; i < cts.size(); i++) {
Parfactor* newPf = new Parfactor (*pfIter_, cts[i]);
pfList_.add (newPf);
pfList_.add (new Parfactor (pf, cts[i]));
}
pfList_.deleteAndRemove (pfIter_);
pfList_.shatter();
delete pf;
}
}
@ -393,24 +400,13 @@ GroundOperator::toString (void)
((*pfIter_)->countedLogVars().contains (X_))
? ss << "full expanding "
: ss << "grounding " ;
ss << X_ << " in " << (*pfIter_)->getHeaderString();
ss << X_ << " in " << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl;
return ss.str();
}
FoveSolver::FoveSolver (const ParfactorList* pfList)
{
for (ParfactorList::const_iterator it = pfList->begin();
it != pfList->end();
it ++) {
pfList_.addShattered (new Parfactor (**it));
}
}
Params
FoveSolver::getPosterioriOf (const Ground& query)
{
@ -422,14 +418,12 @@ FoveSolver::getPosterioriOf (const Ground& query)
Params
FoveSolver::getJointDistributionOf (const Grounds& query)
{
shatterAgainstQuery (query);
runSolver (query);
(*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params();
if (Globals::logDomain) {
Util::fromLog (params);
}
delete *pfList_.begin();
return params;
}
@ -438,32 +432,38 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
void
FoveSolver::absorveEvidence (
ParfactorList& pfList,
const ObservedFormulas& obsFormulas)
ObservedFormulas& obsFormulas)
{
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
bool increment = true;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
if (absorved (pfList, it, obsFormulas[i])) {
it = pfList.deleteAndRemove (it);
increment = false;
break;
}
}
if (increment) {
++ it;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
Parfactor* pf = *it;
it = pfList.remove (it);
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
if (absorvedPfs.empty() == false) {
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
// just remove pf;
} else {
Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
}
pfList.add (newPfs);
}
pfList.shatter();
if (obsFormulas.empty() == false) {
cout << "*******************************************************" << endl;
if (Constants::DEBUG > 1 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << *obsFormulas[i] << endl;
cout << " -> " << obsFormulas[i] << endl;
}
cout << "*******************************************************" << endl;
Util::printAsteriskLine();
pfList.print();
}
pfList.print();
}
@ -473,14 +473,14 @@ FoveSolver::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
if (set.empty()) {
assert (false); // TODO
return {};
}
Parfactors normPfs;
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (unsigned i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
if (set.empty()) {
normPfs.push_back (new Parfactor (*g));
} else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (unsigned i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
}
return normPfs;
}
@ -490,17 +490,25 @@ FoveSolver::countNormalize (
void
FoveSolver::runSolver (const Grounds& query)
{
shatterAgainstQuery (query);
runWeakBayesBall (query);
while (true) {
cout << "---------------------------------------------------" << endl;
pfList_.print();
LiftedOperator::printValidOps (pfList_, query);
if (Constants::DEBUG > 1) {
Util::printDashedLine();
pfList_.print();
LiftedOperator::printValidOps (pfList_, query);
}
LiftedOperator* op = getBestOperation (query);
if (op == 0) {
break;
}
cout << "best operation: " << op->toString() << endl;
if (Constants::DEBUG > 1) {
cout << "best operation: " << op->toString() << endl;
}
op->apply();
delete op;
}
assert (pfList_.size() > 0);
if (pfList_.size() > 1) {
ParfactorList::iterator pfIter = pfList_.begin();
pfIter ++;
@ -514,26 +522,6 @@ FoveSolver::runSolver (const Grounds& query)
bool
FoveSolver::allEliminated (const Grounds&)
{
ParfactorList::iterator pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
const ProbFormulas formulas = (*pfIter)->formulas();
for (unsigned i = 0; i < formulas.size(); i++) {
//bool toElim = false;
//for (unsigned j = 0; j < queries.size(); j++) {
// if ((*pfIter)->containsGround (queries[i]) == false) {
// return
// }
}
++ pfIter;
}
return false;
}
LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query)
{
@ -548,156 +536,170 @@ FoveSolver::getBestOperation (const Grounds& query)
bestCost = cost;
}
}
for (unsigned i = 0; i < validOps.size(); i++) {
if (validOps[i] != bestOp) {
delete validOps[i];
}
}
return bestOp;
}
void
FoveSolver::runWeakBayesBall (const Grounds& query)
{
queue<unsigned> todo; // groups to process
set<unsigned> done; // processed or in queue
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
int group = (*it)->findGroup (query[i]);
if (group != -1) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
unsigned group = todo.front();
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<unsigned> groups = (*it)->getAllGroups();
for (unsigned i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false) {
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
if (Constants::DEBUG > 1) {
Util::printHeader ("REQUIRED PARFACTORS");
pfList_.print();
}
}
void
FoveSolver::shatterAgainstQuery (const Grounds& query)
{
// return;
return ;
for (unsigned i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
ParfactorList pfListCopy = pfList_;
pfList_.clear();
for (ParfactorList::iterator it = pfListCopy.begin();
it != pfListCopy.end(); ++ it) {
Parfactor* pf = *it;
if (pf->containsGround (query[i])) {
Parfactors newPfs;
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (query[i])) {
std::pair<ConstraintTree*, ConstraintTree*> split =
pf->constr()->split (query[i].args(), query[i].arity());
(*it)->constr()->split (query[i].args(), query[i].arity());
ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second;
pfList_.add (new Parfactor (pf, commCt));
newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) {
pfList_.add (new Parfactor (pf, exclCt));
newPfs.push_back (new Parfactor (*it, exclCt));
} else {
delete exclCt;
}
delete pf;
it = pfList_.removeAndDelete (it);
} else {
pfList_.add (pf);
++ it;
}
}
pfList_.shatter();
pfList_.add (newPfs);
}
cout << endl;
cout << "*******************************************************" << endl;
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (unsigned i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
if (Constants::DEBUG > 1) {
cout << endl;
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (unsigned i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList_.print();
}
cout << "*******************************************************" << endl;
pfList_.print();
}
bool
FoveSolver::absorved (
ParfactorList& pfList,
ParfactorList::iterator pfIter,
const ObservedFormula* obsFormula)
Parfactors
FoveSolver::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{
Parfactors absorvedPfs;
Parfactor* g = *pfIter;
const ProbFormulas& formulas = g->formulas();
const ProbFormulas& formulas = g->arguments();
for (unsigned i = 0; i < formulas.size(); i++) {
if (obsFormula->functor() == formulas[i].functor() &&
obsFormula->arity() == formulas[i].arity()) {
if (obsFormula.functor() == formulas[i].functor() &&
obsFormula.arity() == formulas[i].arity()) {
if (obsFormula->isAtom()) {
if (obsFormula.isAtom()) {
if (formulas.size() > 1) {
g->absorveEvidence (i, obsFormula->evidence());
g->absorveEvidence (formulas[i], obsFormula.evidence());
} else {
return true;
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
}
g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res
= g->constr()->split (obsFormula->constr(), formulas[i].arity());
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second;
if (commCt->empty()) {
delete commCt;
delete exclCt;
continue;
}
if (exclCt->empty() == false) {
pfList.add (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactors countNormPfs = countNormalize (g, excl);
for (unsigned j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (i, obsFormula->evidence());
absorvedPfs.push_back (countNormPfs[j]);
if (commCt->empty() == false) {
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactors countNormPfs = countNormalize (g, excl);
for (unsigned j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence());
absorvedPfs.push_back (countNormPfs[j]);
}
} else {
delete commCt;
}
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else {
delete commCt;
delete exclCt;
}
return true;
}
}
return false;
}
bool
FoveSolver::proper (
const ProbFormula& f1,
ConstraintTree* c1,
const ProbFormula& f2,
ConstraintTree* c2)
{
return disjoint (f1, c1, f2, c2)
|| identical (f1, c1, f2, c2);
}
bool
FoveSolver::identical (
const ProbFormula& f1,
ConstraintTree* c1,
const ProbFormula& f2,
ConstraintTree* c2)
{
if (f1.sameSkeletonAs (f2) == false) {
return false;
}
c1->moveToTop (f1.logVars());
c2->moveToTop (f2.logVars());
return ConstraintTree::identical (
c1, c2, f1.logVars().size());
}
bool
FoveSolver::disjoint (
const ProbFormula& f1,
ConstraintTree* c1,
const ProbFormula& f2,
ConstraintTree* c2)
{
if (f1.sameSkeletonAs (f2) == false) {
return true;
}
c1->moveToTop (f1.logVars());
c2->moveToTop (f2.logVars());
return ConstraintTree::overlap (
c1, c2, f1.arity()) == false;
return absorvedPfs;
}

View File

@ -9,10 +9,14 @@ class LiftedOperator
{
public:
virtual unsigned getCost (void) = 0;
virtual void apply (void) = 0;
virtual string toString (void) = 0;
static vector<LiftedOperator*> getValidOps (
ParfactorList&, const Grounds&);
static void printValidOps (ParfactorList&, const Grounds&);
};
@ -23,18 +27,26 @@ class SumOutOperator : public LiftedOperator
public:
SumOutOperator (unsigned group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { }
unsigned getCost (void);
void apply (void);
static vector<SumOutOperator*> getValidOps (
ParfactorList&, const Grounds&);
string toString (void);
private:
static bool validOp (unsigned, ParfactorList&, const Grounds&);
static vector<ParfactorList::iterator> parfactorsWithGroup (
ParfactorList& pfList, unsigned group);
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
unsigned group_;
ParfactorList& pfList_;
unsigned group_;
ParfactorList& pfList_;
};
@ -47,15 +59,21 @@ class CountingOperator : public LiftedOperator
LogVar X,
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void);
void apply (void);
static vector<CountingOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
static bool validOp (Parfactor*, LogVar);
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
};
@ -68,14 +86,19 @@ class GroundOperator : public LiftedOperator
LogVar X,
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void);
void apply (void);
static vector<GroundOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
};
@ -83,49 +106,29 @@ class GroundOperator : public LiftedOperator
class FoveSolver
{
public:
FoveSolver (const ParfactorList*);
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
Params getPosterioriOf (const Ground&);
Params getJointDistributionOf (const Grounds&);
Params getPosterioriOf (const Ground&);
static void absorveEvidence (
ParfactorList& pfList,
const ObservedFormulas& obsFormulas);
Params getJointDistributionOf (const Grounds&);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
private:
void runSolver (const Grounds&);
bool allEliminated (const Grounds&);
LiftedOperator* getBestOperation (const Grounds&);
void shatterAgainstQuery (const Grounds&);
void runSolver (const Grounds&);
static bool absorved (
ParfactorList& pfList,
ParfactorList::iterator pfIter,
const ObservedFormula*);
LiftedOperator* getBestOperation (const Grounds&);
public:
void runWeakBayesBall (const Grounds&);
static bool proper (
const ProbFormula&,
ConstraintTree*,
const ProbFormula&,
ConstraintTree*);
void shatterAgainstQuery (const Grounds&);
static bool identical (
const ProbFormula&,
ConstraintTree*,
const ProbFormula&,
ConstraintTree*);
static Parfactors absorve (ObservedFormula&, Parfactor*);
static bool disjoint (
const ProbFormula&,
ConstraintTree*,
const ProbFormula&,
ConstraintTree*);
ParfactorList pfList_;
ParfactorList pfList_;
};
#endif // HORUS_FOVESOLVER_H

View File

@ -1,66 +1,63 @@
#ifndef HORUS_GRAPHICALMODEL_H
#define HORUS_GRAPHICALMODEL_H
#include <cassert>
#include <unordered_map>
#include <sstream>
#include "VarNode.h"
#include "Distribution.h"
#include "Util.h"
#include "Horus.h"
using namespace std;
struct VariableInfo
struct VarInfo
{
VariableInfo (string l, const States& sts)
{
label = l;
states = sts;
}
string label;
States states;
VarInfo (string l, const States& sts) : label(l), states(sts) { }
string label;
States states;
};
class GraphicalModel
{
public:
virtual ~GraphicalModel (void) {};
virtual VarNode* getVariableNode (VarId) const = 0;
virtual VarNodes getVariableNodes (void) const = 0;
virtual void printGraphicalModel (void) const = 0;
virtual ~GraphicalModel (void) { };
virtual VarNode* getVariableNode (VarId) const = 0;
virtual VarNodes getVariableNodes (void) const = 0;
virtual void printGraphicalModel (void) const = 0;
static void addVariableInformation (VarId vid, string label,
const States& states)
static void addVariableInformation (
VarId vid, string label, const States& states)
{
assert (varsInfo_.find (vid) == varsInfo_.end());
varsInfo_.insert (make_pair (vid, VariableInfo (label, states)));
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
static VariableInfo getVariableInformation (VarId vid)
static VarInfo getVarInformation (VarId vid)
{
assert (varsInfo_.find (vid) != varsInfo_.end());
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
static bool variablesHaveInformation (void)
{
return varsInfo_.size() != 0;
}
static void clearVariablesInformation (void)
{
varsInfo_.clear();
}
static void addDistribution (unsigned id, Distribution* dist)
{
distsInfo_[id] = dist;
}
static void updateDistribution (unsigned id, const Params& params)
{
distsInfo_[id]->updateParameters (params);
}
private:
static unordered_map<VarId,VariableInfo> varsInfo_;
static unordered_map<unsigned,Distribution*> distsInfo_;
static unordered_map<VarId,VarInfo> varsInfo_;
};
#endif // HORUS_GRAPHICALMODEL_H

View File

@ -84,16 +84,34 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
unsigned
HistogramSet::findIndex (
const Histogram& hist,
const vector<Histogram>& histograms)
const Histogram& h,
const vector<Histogram>& hists)
{
vector<Histogram>::const_iterator it = std::lower_bound (
histograms.begin(),
histograms.end(),
hist,
std::greater<Histogram>());
assert (it != histograms.end() && *it == hist);
return std::distance (histograms.begin(), it);
hists.begin(), hists.end(), h, std::greater<Histogram>());
assert (it != hists.end() && *it == h);
return std::distance (hists.begin(), it);
}
vector<double>
HistogramSet::getNumAssigns (unsigned N, unsigned R)
{
HistogramSet hs (N, R);
unsigned N_factorial = Util::factorial (N);
unsigned H = hs.nrHistograms();
vector<double> numAssigns;
numAssigns.reserve (H);
for (unsigned h = 0; h < H; h++) {
unsigned prod = 1;
for (unsigned r = 0; r < R; r++) {
prod *= Util::factorial (hs[r]);
}
numAssigns.push_back (LogAware::tl (N_factorial / prod));
hs.nextHistogram();
}
return numAssigns;
}

View File

@ -26,8 +26,9 @@ class HistogramSet
static unsigned nrHistograms (unsigned, unsigned);
static unsigned findIndex (
const Histogram&,
const vector<Histogram>&);
const Histogram&, const vector<Histogram>&);
static vector<double> getNumAssigns (unsigned, unsigned);
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);

View File

@ -1,17 +1,9 @@
#ifndef HORUS_HORUS_H
#define HORUS_HORUS_H
#include <cmath>
#include <cassert>
#include <limits>
#include <algorithm>
#include <vector>
#include <unordered_map>
#include <iostream>
#include <fstream>
#include <sstream>
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&); \
@ -25,49 +17,48 @@ class FgVarNode;
class FgFacNode;
class Factor;
typedef vector<double> Params;
typedef unsigned VarId;
typedef vector<VarId> VarIds;
typedef vector<VarNode*> VarNodes;
typedef vector<BayesNode*> BnNodeSet;
typedef vector<FgVarNode*> FgVarSet;
typedef vector<FgFacNode*> FgFacSet;
typedef vector<Factor*> FactorSet;
typedef vector<string> States;
typedef vector<unsigned> Ranges;
typedef vector<double> Params;
typedef unsigned VarId;
typedef vector<VarId> VarIds;
typedef vector<VarNode*> VarNodes;
typedef vector<BayesNode*> BnNodeSet;
typedef vector<FgVarNode*> FgVarSet;
typedef vector<FgFacNode*> FgFacSet;
typedef vector<Factor*> FactorSet;
typedef vector<string> States;
typedef vector<unsigned> Ranges;
namespace Globals {
extern bool logDomain;
enum InfAlgorithms
{
VE, // variable elimination
BN_BP, // bayesian network belief propagation
FG_BP, // factor graph belief propagation
CBP // counting bp solver
};
// level of debug information
static const unsigned DL = 1;
namespace Globals {
static const int NO_EVIDENCE = -1;
extern bool logDomain;
extern InfAlgorithms infAlgorithm;
};
namespace Constants {
// level of debug information
const unsigned DEBUG = 2;
const int NO_EVIDENCE = -1;
// number of digits to show when printing a parameter
static const unsigned PRECISION = 5;
const unsigned PRECISION = 5;
static const bool COLLECT_STATISTICS = false;
const bool COLLECT_STATS = false;
static const bool EXPORT_TO_GRAPHVIZ = false;
static const unsigned EXPORT_MINIMAL_SIZE = 100;
static const double INF = -numeric_limits<double>::infinity();
namespace InfAlgorithms {
enum InfAlgs
{
VE, // variable elimination
BN_BP, // bayesian network belief propagation
FG_BP, // factor graph belief propagation
CBP // counting bp solver
};
extern InfAlgs infAlgorithm;
};

View File

@ -10,10 +10,6 @@
#include "FgBpSolver.h"
#include "CbpSolver.h"
//#include "TinySet.h"
#include "LiftedUtils.h"
using namespace std;
void processArguments (BayesNet&, int, const char* []);
@ -24,38 +20,9 @@ const string USAGE = "usage: \
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
class Cenas
{
public:
Cenas (int cc)
{
c = cc;
}
//operator int (void) const
//{
// cout << "return int" << endl;
// return c;
//}
operator double (void) const
{
cout << "return double" << endl;
return 0.0;
}
private:
int c;
};
int
main (int argc, const char* argv[])
{
LogVar X = 3;
LogVarSet Xs = X;
cout << "set: " << X << endl;
Cenas c1 (1);
Cenas c2 (3);
cout << (c1 < c2) << endl;
return 0;
if (!argv[1]) {
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
@ -99,7 +66,6 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
cerr << "error: there isn't a variable labeled of " ;
cerr << "`" << arg << "'" ;
cerr << endl;
bn.freeDistributions();
exit (0);
}
} else {
@ -109,13 +75,11 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
if (label.empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
bn.freeDistributions();
exit (0);
}
if (state.empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
bn.freeDistributions();
exit (0);
}
BayesNode* node = bn.getBayesNode (label);
@ -127,14 +91,12 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
cerr << "is not a valid state for " ;
cerr << "`" << node->label() << "'" ;
cerr << endl;
bn.freeDistributions();
exit (0);
}
} else {
cerr << "error: there isn't a variable labeled of " ;
cerr << "`" << label << "'" ;
cerr << endl;
bn.freeDistributions();
exit (0);
}
}
@ -142,7 +104,7 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
Solver* solver = 0;
FactorGraph* fg = 0;
switch (InfAlgorithms::infAlgorithm) {
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE:
fg = new FactorGraph (bn);
solver = new VarElimSolver (*fg);
@ -163,7 +125,6 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
}
runSolver (solver, queryVars);
delete fg;
bn.freeDistributions();
}
@ -179,7 +140,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
cerr << "error: `" << arg << "' " ;
cerr << "is not a valid variable id" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
VarId vid;
@ -193,7 +153,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
} else {
@ -201,20 +160,17 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
if (arg.substr (0, pos).empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
fg.freeDistributions();
exit (0);
}
if (arg.substr (pos + 1).empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
fg.freeDistributions();
exit (0);
}
if (!Util::isInteger (arg.substr (0, pos))) {
cerr << "error: `" << arg.substr (0, pos) << "' " ;
cerr << "is not a variable id" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
VarId vid;
@ -227,7 +183,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
cerr << "is not a state index" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
int stateIndex;
@ -241,28 +196,23 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
cerr << "is not a valid state index for variable " ;
cerr << "`" << var->varId() << "'" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
}
}
Solver* solver = 0;
switch (InfAlgorithms::infAlgorithm) {
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE:
solver = new VarElimSolver (fg);
break;
case InfAlgorithms::BN_BP:
case InfAlgorithms::FG_BP:
//cout << "here!" << endl;
//fg.printGraphicalModel();
//fg.exportToLibDaiFormat ("net.fg");
solver = new FgBpSolver (fg);
break;
case InfAlgorithms::CBP:
@ -272,7 +222,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
assert (false);
}
runSolver (solver, queryVars);
fg.freeDistributions();
}

View File

@ -7,22 +7,26 @@
#include <YapInterface.h>
#include "ParfactorList.h"
#include "BayesNet.h"
#include "FactorGraph.h"
#include "FoveSolver.h"
#include "VarElimSolver.h"
#include "BnBpSolver.h"
#include "FgBpSolver.h"
#include "CbpSolver.h"
#include "ElimGraph.h"
#include "FoveSolver.h"
#include "ParfactorList.h"
using namespace std;
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
Params readParams (YAP_Term);
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term);
int createLiftedNetwork (void)
@ -30,107 +34,124 @@ int createLiftedNetwork (void)
Parfactors parfactors;
YAP_Term parfactorList = YAP_ARG1;
while (parfactorList != YAP_TermNil()) {
YAP_Term parfactor = YAP_HeadOfTerm (parfactorList);
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, parfactor));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, parfactor);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, parfactor);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParams (YAP_ArgOfTerm (4, parfactor));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, parfactor);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: bad formed constraint" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
parfactors.push_back (new Parfactor (formulas, params, tuples, distId));
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
parfactors.push_back (readParfactor (pfTerm));
parfactorList = YAP_TailOfTerm (parfactorList);
}
// LiftedUtils::printSymbolDictionary();
cout << "*******************************************************" << endl;
cout << "INITIAL PARFACTORS" << endl;
cout << "*******************************************************" << endl;
for (unsigned i = 0; i < parfactors.size(); i++) {
parfactors[i]->print();
cout << endl;
if (Constants::DEBUG > 1) {
// Util::printHeader ("INITIAL PARFACTORS");
// for (unsigned i = 0; i < parfactors.size(); i++) {
// parfactors[i]->print();
// cout << endl;
// }
// parfactors[0]->countConvert (LogVar (0));
//parfactors[1]->fullExpand (LogVar (1));
Util::printHeader ("SHATTERED PARFACTORS");
}
ParfactorList* pfList = new ParfactorList();
for (unsigned i = 0; i < parfactors.size(); i++) {
pfList->add (parfactors[i]);
}
cout << endl;
cout << "*******************************************************" << endl;
cout << "SHATTERED PARFACTORS" << endl;
cout << "*******************************************************" << endl;
pfList->shatter();
pfList->print();
// insert the evidence
ObservedFormulas obsFormulas;
YAP_Term observedList = YAP_ARG2;
ParfactorList* pfList = new ParfactorList (parfactors);
if (Constants::DEBUG > 1) {
pfList->print();
}
// read evidence
ObservedFormulas* obsFormulas = new ObservedFormulas();
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
YAP_Int p = (YAP_Int) (net);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
}
Parfactor* readParfactor (YAP_Term pfTerm)
{
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParams (YAP_ArgOfTerm (4, pfTerm));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: constraint has free variables" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
return new Parfactor (formulas, params, tuples, distId);
}
void readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{
while (observedList != YAP_TermNil()) {
YAP_Term pair = YAP_HeadOfTerm (observedList);
YAP_Term ground = YAP_ArgOfTerm (1, pair);
@ -155,22 +176,18 @@ int createLiftedNetwork (void)
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
bool found = false;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
if (obsFormulas[i]->functor() == functor &&
obsFormulas[i]->arity() == args.size() &&
obsFormulas[i]->evidence() == evidence) {
obsFormulas[i]->addTuple (args);
if (obsFormulas[i].functor() == functor &&
obsFormulas[i].arity() == args.size() &&
obsFormulas[i].evidence() == evidence) {
obsFormulas[i].addTuple (args);
found = true;
}
}
if (found == false) {
obsFormulas.push_back (new ObservedFormula (functor, evidence, args));
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
}
observedList = YAP_TailOfTerm (observedList);
}
FoveSolver::absorveEvidence (*pfList, obsFormulas);
YAP_Int p = (YAP_Int) (pfList);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
}
}
@ -186,7 +203,6 @@ createGroundNetwork (void)
// }
BayesNet* bn = new BayesNet();
YAP_Term varList = YAP_ARG1;
BnNodeSet nodes;
vector<VarIds> parents;
while (varList != YAP_TermNil()) {
YAP_Term var = YAP_HeadOfTerm (varList);
@ -201,15 +217,13 @@ createGroundNetwork (void)
parents.back().push_back (parentId);
parentL = YAP_TailOfTerm (parentL);
}
Distribution* dist = bn->getDistribution (distId);
if (!dist) {
dist = new Distribution (distId);
bn->addDistribution (dist);
}
assert (bn->getBayesNode (vid) == 0);
nodes.push_back (bn->addNode (vid, dsize, evidence, dist));
BayesNode* newNode = new BayesNode (
vid, dsize, evidence, Params(), distId);
bn->addNode (newNode);
varList = YAP_TailOfTerm (varList);
}
const BnNodeSet& nodes = bn->getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
BnNodeSet ps;
for (unsigned j = 0; j < parents[i].size(); j++) {
@ -225,41 +239,6 @@ createGroundNetwork (void)
int
setBayesNetParams (void)
{
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
const Params params = readParams (YAP_ArgOfTerm (2, dist));
bn->getDistribution(distId)->updateParameters (params);
distList = YAP_TailOfTerm (distList);
}
return TRUE;
}
int
setParfactorGraphParams (void)
{
// FIXME
// ParfactorGraph* pfg = (ParfactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
while (distList != YAP_TermNil()) {
// YAP_Term dist = YAP_HeadOfTerm (distList);
// unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
// const Params params = readParams (YAP_ArgOfTerm (2, dist));
// pfg->getDistribution(distId)->setData (params);
distList = YAP_TailOfTerm (distList);
}
return TRUE;
}
Params
readParams (YAP_Term paramL)
{
@ -279,15 +258,14 @@ readParams (YAP_Term paramL)
int
runLiftedSolver (void)
{
ParfactorList* pfList = (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term taskList = YAP_ARG2;
vector<Params> results;
ParfactorList pfListCopy (*network->first);
FoveSolver::absorveEvidence (pfListCopy, *network->second);
while (taskList != YAP_TermNil()) {
YAP_Term jointList = YAP_HeadOfTerm (taskList);
Grounds queryVars;
assert (YAP_IsPairTerm (taskList));
assert (YAP_IsPairTerm (jointList));
YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) {
YAP_Term ground = YAP_HeadOfTerm (jointList);
if (YAP_IsAtomTerm (ground)) {
@ -310,11 +288,11 @@ runLiftedSolver (void)
}
jointList = YAP_TailOfTerm (jointList);
}
FoveSolver solver (pfList);
FoveSolver solver (pfListCopy);
if (queryVars.size() == 1) {
results.push_back (solver.getPosterioriOf (queryVars[0]));
} else {
assert (false); // TODO joint dist
results.push_back (solver.getJointDistributionOf (queryVars));
}
taskList = YAP_TailOfTerm (taskList);
}
@ -339,46 +317,40 @@ runLiftedSolver (void)
int
runOtherSolvers (void)
runGroundSolver (void)
{
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term taskList = YAP_ARG2;
vector<VarIds> tasks;
std::set<VarId> vids;
while (taskList != YAP_TermNil()) {
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
tasks.push_back (VarIds());
YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) {
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
assert (bn->getBayesNode (vid));
tasks.back().push_back (vid);
vids.insert (vid);
jointList = YAP_TailOfTerm (jointList);
}
} else {
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
VarIds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) {
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
assert (bn->getBayesNode (vid));
tasks.push_back (VarIds() = {vid});
queryVars.push_back (vid);
vids.insert (vid);
jointList = YAP_TailOfTerm (jointList);
}
tasks.push_back (queryVars);
taskList = YAP_TailOfTerm (taskList);
}
Solver* bpSolver = 0;
GraphicalModel* graphicalModel = 0;
CFactorGraph::checkForIdenticalFactors = false;
if (InfAlgorithms::infAlgorithm != InfAlgorithms::VE) {
if (Globals::infAlgorithm != InfAlgorithms::VE) {
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
VarIds (vids.begin(), vids.end()));
if (InfAlgorithms::infAlgorithm == InfAlgorithms::BN_BP) {
if (Globals::infAlgorithm == InfAlgorithms::BN_BP) {
graphicalModel = mrn;
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::FG_BP) {
} else if (Globals::infAlgorithm == InfAlgorithms::FG_BP) {
graphicalModel = new FactorGraph (*mrn);
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
delete mrn;
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::CBP) {
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
graphicalModel = new FactorGraph (*mrn);
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
delete mrn;
@ -389,8 +361,7 @@ runOtherSolvers (void)
vector<Params> results;
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
//if (i == 1) exit (0);
if (InfAlgorithms::infAlgorithm == InfAlgorithms::VE) {
if (Globals::infAlgorithm == InfAlgorithms::VE) {
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
VarElimSolver* veSolver = new VarElimSolver (*mrn);
if (tasks[i].size() == 1) {
@ -430,10 +401,57 @@ runOtherSolvers (void)
int
setParfactorsParams (void)
{
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
ParfactorList* pfList = network->first;
YAP_Term distList = YAP_ARG2;
unordered_map<unsigned, Params> paramsMap;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList);
}
ParfactorList::iterator it = pfList->begin();
while (it != pfList->end()) {
assert (Util::contains (paramsMap, (*it)->distId()));
// (*it)->setParams (paramsMap[(*it)->distId()]);
++ it;
}
return TRUE;
}
int
setBayesNetParams (void)
{
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
unordered_map<unsigned, Params> paramsMap;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList);
}
const BnNodeSet& nodes = bn->getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
assert (Util::contains (paramsMap, nodes[i]->distId()));
nodes[i]->setParams (paramsMap[nodes[i]->distId()]);
}
return TRUE;
}
int
setExtraVarsInfo (void)
{
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
GraphicalModel::clearVariablesInformation();
YAP_Term varsInfoL = YAP_ARG2;
while (varsInfoL != YAP_TermNil()) {
@ -463,13 +481,13 @@ setHorusFlag (void)
if (key == "inf_alg") {
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
if ( value == "ve") {
InfAlgorithms::infAlgorithm = InfAlgorithms::VE;
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (value == "bn_bp") {
InfAlgorithms::infAlgorithm = InfAlgorithms::BN_BP;
Globals::infAlgorithm = InfAlgorithms::BN_BP;
} else if (value == "fg_bp") {
InfAlgorithms::infAlgorithm = InfAlgorithms::FG_BP;
Globals::infAlgorithm = InfAlgorithms::FG_BP;
} else if (value == "cbp") {
InfAlgorithms::infAlgorithm = InfAlgorithms::CBP;
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
@ -543,19 +561,19 @@ setHorusFlag (void)
int
freeBayesNetwork (void)
{
//Statistics::writeStatisticsToFile ("stats.txt");
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
bn->freeDistributions();
delete bn;
delete (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
return TRUE;
}
int
freeParfactorGraph (void)
freeParfactors (void)
{
delete (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
delete network->first;
delete network->second;
delete network;
return TRUE;
}
@ -564,15 +582,15 @@ freeParfactorGraph (void)
extern "C" void
init_predicates (void)
{
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
YAP_UserCPredicate ("set_parfactor_graph_params", setParfactorGraphParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("run_other_solvers", runOtherSolvers, 3);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
YAP_UserCPredicate ("free_parfactor_graph", freeParfactorGraph, 1);
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
}

View File

@ -12,7 +12,9 @@
#include "Util.h"
class StatesIndexer {
class StatesIndexer
{
public:
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
@ -134,11 +136,11 @@ class StatesIndexer {
return size_ ;
}
friend ostream& operator<< (ostream &out, const StatesIndexer& idx)
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
{
out << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
out << idx.indices_;
return out;
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
os << idx.indices_;
return os;
}
private:
@ -274,21 +276,14 @@ class MapIndexer
index_ = 0;
}
friend ostream& operator<< (ostream &out, const MapIndexer& idx)
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
{
out << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
out << idx.indices_;
return out;
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
os << idx.indices_;
return os;
}
private:
MapIndexer (const Ranges& ranges) :
ranges_(ranges),
indices_(ranges.size(), 0),
offsets_(ranges.size())
{
index_ = 0;
}
unsigned index_;
bool valid_;
vector<unsigned> ranges_;

View File

@ -95,26 +95,37 @@ ostream& operator<< (ostream &os, const Ground& gr)
void
ObservedFormula::addTuple (const Tuple& t)
LogVars
Substitution::getDiscardedLogVars (void) const
{
if (constr_ == 0) {
LogVars lvs (arity_);
for (unsigned i = 0; i < arity_; i++) {
lvs[i] = i;
LogVars discardedLvs;
set<LogVar> doneLvs;
unordered_map<LogVar, LogVar>::const_iterator it;
it = subs_.begin();
while (it != subs_.end()) {
if (Util::contains (doneLvs, it->second)) {
discardedLvs.push_back (it->first);
} else {
doneLvs.insert (it->second);
}
constr_ = new ConstraintTree (lvs);
it ++;
}
constr_->addTuple (t);
return discardedLvs;
}
ostream& operator<< (ostream &os, const ObservedFormula of)
ostream& operator<< (ostream &os, const Substitution& theta)
{
os << of.functor_ << "/" << of.arity_;
os << "|" << of.constr_->tupleSet();
os << " [evidence=" << of.evidence_ << "]";
unordered_map<LogVar, LogVar>::const_iterator it;
os << "[" ;
it = theta.subs_.begin();
while (it != theta.subs_.end()) {
if (it != theta.subs_.begin()) os << ", " ;
os << it->first << "->" << it->second ;
++ it;
}
os << "]" ;
return os;
}

View File

@ -18,11 +18,17 @@ class Symbol
{
public:
Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
Symbol (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; }
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
static Symbol invalid (void) { return Symbol(); }
friend ostream& operator<< (ostream &os, const Symbol& s);
private:
unsigned id_;
};
@ -32,7 +38,9 @@ class LogVar
{
public:
LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
LogVar (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; }
LogVar& operator++ (void)
@ -48,6 +56,7 @@ class LogVar
}
friend ostream& operator<< (ostream &os, const LogVar& X);
private:
unsigned id_;
};
@ -79,8 +88,8 @@ ostream& operator<< (ostream &os, const Tuple& t);
namespace LiftedUtils {
Symbol getSymbol (const string&);
void printSymbolDictionary (void);
Symbol getSymbol (const string&);
void printSymbolDictionary (void);
}
@ -89,71 +98,56 @@ class Ground
{
public:
Ground (Symbol f) : functor_(f) { }
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
Symbol functor (void) const { return functor_; }
Symbols args (void) const { return args_; }
unsigned arity (void) const { return args_.size(); }
bool isAtom (void) const { return args_.size() == 0; }
Symbol functor (void) const { return functor_; }
Symbols args (void) const { return args_; }
unsigned arity (void) const { return args_.size(); }
bool isAtom (void) const { return args_.size() == 0; }
friend ostream& operator<< (ostream &os, const Ground& gr);
private:
Symbol functor_;
Symbols args_;
Symbol functor_;
Symbols args_;
};
typedef vector<Ground> Grounds;
class ConstraintTree;
class ObservedFormula
{
public:
ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(0) { }
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(0)
{
addTuple (tuple);
}
Symbol functor (void) const { return functor_; }
unsigned arity (void) const { return arity_; }
unsigned evidence (void) const { return evidence_; }
ConstraintTree* constr (void) const { return constr_; }
bool isAtom (void) const { return arity_ == 0; }
void addTuple (const Tuple& t);
friend ostream& operator<< (ostream &os, const ObservedFormula opv);
private:
Symbol functor_;
unsigned arity_;
unsigned evidence_;
ConstraintTree* constr_;
};
typedef vector<ObservedFormula*> ObservedFormulas;
class Substitution
{
public:
void add (LogVar X_old, LogVar X_new)
{
assert (Util::contains (subs_, X_old) == false);
subs_.insert (make_pair (X_old, X_new));
}
void rename (LogVar X_old, LogVar X_new)
{
assert (subs_.find (X_old) != subs_.end());
assert (Util::contains (subs_, X_old));
subs_.find (X_old)->second = X_new;
}
LogVar newNameFor (LogVar X) const
{
assert (subs_.find (X) != subs_.end());
assert (Util::contains (subs_, X));
return subs_.find (X)->second;
}
LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta);
private:
unordered_map<LogVar, LogVar> subs_;
};

View File

@ -60,7 +60,6 @@ HEADERS = \
$(srcdir)/CbpSolver.h \
$(srcdir)/FoveSolver.h \
$(srcdir)/VarNode.h \
$(srcdir)/Distribution.h \
$(srcdir)/Indexer.h \
$(srcdir)/Parfactor.h \
$(srcdir)/ProbFormula.h \

View File

@ -2,6 +2,7 @@
#include "Parfactor.h"
#include "Histogram.h"
#include "Indexer.h"
#include "Util.h"
#include "Horus.h"
@ -11,55 +12,58 @@ Parfactor::Parfactor (
const Tuples& tuples,
unsigned distId)
{
formulas_ = formulas;
params_ = params;
distId_ = distId;
args_ = formulas;
params_ = params;
distId_ = distId;
LogVars logVars;
for (unsigned i = 0; i < formulas_.size(); i++) {
ranges_.push_back (formulas_[i].range());
const LogVars& lvs = formulas_[i].logVars();
for (unsigned i = 0; i < args_.size(); i++) {
ranges_.push_back (args_[i].range());
const LogVars& lvs = args_[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) {
if (std::find (logVars.begin(), logVars.end(), lvs[j]) ==
logVars.end()) {
if (Util::contains (logVars, lvs[j]) == false) {
logVars.push_back (lvs[j]);
}
}
}
constr_ = new ConstraintTree (logVars, tuples);
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
{
formulas_ = g->formulas();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = new ConstraintTree (g->logVars(), {tuple});
args_ = g->arguments();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = new ConstraintTree (g->logVars(), {tuple});
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
{
formulas_ = g->formulas();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = constr;
args_ = g->arguments();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = constr;
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::Parfactor (const Parfactor& g)
{
formulas_ = g.formulas();
params_ = g.params();
ranges_ = g.ranges();
distId_ = g.distId();
constr_ = new ConstraintTree (*g.constr());
args_ = g.arguments();
params_ = g.params();
ranges_ = g.ranges();
distId_ = g.distId();
constr_ = new ConstraintTree (*g.constr());
assert (params_.size() == Util::expectedSize (ranges_));
}
@ -75,9 +79,9 @@ LogVarSet
Parfactor::countedLogVars (void) const
{
LogVarSet set;
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].isCounting()) {
set.insert (formulas_[i].countedLogVar());
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].isCounting()) {
set.insert (args_[i].countedLogVar());
}
}
return set;
@ -107,14 +111,14 @@ Parfactor::elimLogVars (void) const
LogVarSet
Parfactor::exclusiveLogVars (unsigned fIdx) const
{
assert (fIdx < formulas_.size());
assert (fIdx < args_.size());
LogVarSet remaining;
for (unsigned i = 0; i < formulas_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != fIdx) {
remaining |= formulas_[i].logVarSet();
remaining |= args_[i].logVarSet();
}
}
return formulas_[fIdx].logVarSet() - remaining;
return args_[fIdx].logVarSet() - remaining;
}
@ -131,44 +135,51 @@ Parfactor::setConstraintTree (ConstraintTree* newTree)
void
Parfactor::sumOut (unsigned fIdx)
{
assert (fIdx < formulas_.size());
assert (formulas_[fIdx].contains (elimLogVars()));
assert (fIdx < args_.size());
assert (args_[fIdx].contains (elimLogVars()));
LogVarSet excl = exclusiveLogVars (fIdx);
unsigned condCount = constr_->getConditionalCount (excl);
Util::pow (params_, condCount);
if (args_[fIdx].isCounting()) {
LogAware::pow (params_, constr_->getConditionalCount (
excl - args_[fIdx].countedLogVar()));
} else {
LogAware::pow (params_, constr_->getConditionalCount (excl));
}
vector<unsigned> numAssigns (ranges_[fIdx], 1);
if (formulas_[fIdx].isCounting()) {
if (args_[fIdx].isCounting()) {
unsigned N = constr_->getConditionalCount (
formulas_[fIdx].countedLogVar());
unsigned R = formulas_[fIdx].range();
unsigned H = ranges_[fIdx];
HistogramSet hs (N, R);
unsigned N_factorial = Util::factorial (N);
for (unsigned h = 0; h < H; h++) {
unsigned prod = 1;
for (unsigned r = 0; r < R; r++) {
prod *= Util::factorial (hs[r]);
args_[fIdx].countedLogVar());
unsigned R = args_[fIdx].range();
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
StatesIndexer sindexer (ranges_, fIdx);
while (sindexer.valid()) {
unsigned h = sindexer[fIdx];
if (Globals::logDomain) {
params_[sindexer] += numAssigns[h];
} else {
params_[sindexer] *= numAssigns[h];
}
numAssigns[h] = N_factorial / prod;
hs.nextHistogram();
++ sindexer;
}
cout << endl;
}
Params copy = params_;
params_.clear();
params_.resize (copy.size() / ranges_[fIdx], 0.0);
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
MapIndexer indexer (ranges_, fIdx);
for (unsigned i = 0; i < copy.size(); i++) {
unsigned h = indexer[fIdx];
// TODO NOT LOG DOMAIN AWARE :(
params_[indexer] += numAssigns[h] * copy[i];
++ indexer;
if (Globals::logDomain) {
for (unsigned i = 0; i < copy.size(); i++) {
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
++ indexer;
}
} else {
for (unsigned i = 0; i < copy.size(); i++) {
params_[indexer] += copy[i];
++ indexer;
}
}
formulas_.erase (formulas_.begin() + fIdx);
args_.erase (args_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
constr_->remove (excl);
}
@ -179,55 +190,7 @@ void
Parfactor::multiply (Parfactor& g)
{
alignAndExponentiate (this, &g);
bool sharedVars = false;
vector<unsigned> g_varpos;
const ProbFormulas& g_formulas = g.formulas();
const Params& g_params = g.params();
const Ranges& g_ranges = g.ranges();
for (unsigned i = 0; i < g_formulas.size(); i++) {
int group = g_formulas[i].group();
if (indexOfFormulaWithGroup (group) == -1) {
insertDimension (g.ranges()[i]);
formulas_.push_back (g_formulas[i]);
g_varpos.push_back (formulas_.size() - 1);
} else {
sharedVars = true;
g_varpos.push_back (indexOfFormulaWithGroup (group));
}
}
if (sharedVars == false) {
unsigned count = 0;
for (unsigned i = 0; i < params_.size(); i++) {
if (Globals::logDomain) {
params_[i] += g_params[count];
} else {
params_[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = g_varpos.size() - 1; j >= 0; j--) {
g_li += indexer[g_varpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
params_[indexer] += g_params[g_li];
} else {
params_[indexer] *= g_params[g_li];
}
++ indexer;
}
}
TFactor<ProbFormula>::multiply (g);
constr_->join (g.constr(), true);
}
@ -236,7 +199,7 @@ Parfactor::multiply (Parfactor& g)
void
Parfactor::countConvert (LogVar X)
{
int fIdx = indexOfFormulaWithLogVar (X);
int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (constr_->isCountNormalized (X));
assert (constr_->getConditionalCount (X) > 1);
@ -248,12 +211,12 @@ Parfactor::countConvert (LogVar X)
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
StatesIndexer indexer (ranges_);
vector<Params> summout (params_.size() / R);
vector<Params> sumout (params_.size() / R);
unsigned count = 0;
while (indexer.valid()) {
summout[count].reserve (R);
sumout[count].reserve (R);
for (unsigned r = 0; r < R; r++) {
summout[count].push_back (params_[indexer]);
sumout[count].push_back (params_[indexer]);
indexer.increment (fIdx);
}
count ++;
@ -262,45 +225,42 @@ Parfactor::countConvert (LogVar X)
}
params_.clear();
params_.reserve (summout.size() * H);
params_.reserve (sumout.size() * H);
vector<bool> mapDims (ranges_.size(), true);
ranges_[fIdx] = H;
mapDims[fIdx] = false;
MapIndexer mapIndexer (ranges_, mapDims);
MapIndexer mapIndexer (ranges_, fIdx);
while (mapIndexer.valid()) {
double prod = 1.0;
double prod = LogAware::multIdenty();
unsigned i = mapIndexer.mappedIndex();
unsigned h = mapIndexer[fIdx];
for (unsigned r = 0; r < R; r++) {
// TODO not log domain aware
prod *= Util::pow (summout[i][r], histograms[h][r]);
if (Globals::logDomain) {
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
} else {
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
}
}
params_.push_back (prod);
++ mapIndexer;
}
formulas_[fIdx].setCountedLogVar (X);
args_[fIdx].setCountedLogVar (X);
}
void
Parfactor::expandPotential (
LogVar X,
LogVar X_new1,
LogVar X_new2)
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
{
int fIdx = indexOfFormulaWithLogVar (X);
int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (formulas_[fIdx].isCounting());
assert (args_[fIdx].isCounting());
unsigned N1 = constr_->getConditionalCount (X_new1);
unsigned N2 = constr_->getConditionalCount (X_new2);
unsigned N = N1 + N2;
unsigned R = formulas_[fIdx].range();
unsigned R = args_[fIdx].range();
unsigned H1 = HistogramSet::nrHistograms (N1, R);
unsigned H2 = HistogramSet::nrHistograms (N2, R);
unsigned H = ranges_[fIdx];
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
@ -320,48 +280,11 @@ Parfactor::expandPotential (
}
}
unsigned size = (params_.size() / H) * H1 * H2;
Params copy = params_;
params_.clear();
params_.reserve (size);
expandPotential (fIdx, H1 * H2, sumIndexes);
unsigned prod = 1;
vector<unsigned> offsets_ (ranges_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = H1 * H2;
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
}
}
formulas_.insert (formulas_.begin() + fIdx + 1, formulas_[fIdx]);
formulas_[fIdx].rename (X, X_new1);
formulas_[fIdx + 1].rename (X, X_new2);
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
args_[fIdx].rename (X, X_new1);
args_[fIdx + 1].rename (X, X_new2);
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
ranges_[fIdx] = H1;
}
@ -371,13 +294,12 @@ Parfactor::expandPotential (
void
Parfactor::fullExpand (LogVar X)
{
int fIdx = indexOfFormulaWithLogVar (X);
int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (formulas_[fIdx].isCounting());
assert (args_[fIdx].isCounting());
unsigned N = constr_->getConditionalCount (X);
unsigned R = formulas_[fIdx].range();
unsigned H = ranges_[fIdx];
unsigned R = args_[fIdx].range();
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
@ -400,54 +322,17 @@ Parfactor::fullExpand (LogVar X)
++ indexer;
}
unsigned size = (params_.size() / H) * std::pow (R, N);
Params copy = params_;
params_.clear();
params_.reserve (size);
expandPotential (fIdx, std::pow (R, N), sumIndexes);
unsigned prod = 1;
vector<unsigned> offsets_ (ranges_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = std::pow (R, N);
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
}
}
ProbFormula f = formulas_[fIdx];
formulas_.erase (formulas_.begin() + fIdx);
ProbFormula f = args_[fIdx];
args_.erase (args_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
LogVars newLvs = constr_->expand (X);
assert (newLvs.size() == N);
for (unsigned i = 0 ; i < N; i++) {
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
newFormula.rename (X, newLvs[i]);
formulas_.insert (formulas_.begin() + fIdx + i, newFormula);
args_.insert (args_.begin() + fIdx + i, newFormula);
ranges_.insert (ranges_.begin() + fIdx + i, R);
}
}
@ -459,117 +344,43 @@ Parfactor::reorderAccordingGrounds (const Grounds& grounds)
{
ProbFormulas newFormulas;
for (unsigned i = 0; i < grounds.size(); i++) {
for (unsigned j = 0; j < formulas_.size(); j++) {
if (grounds[i].functor() == formulas_[j].functor() &&
grounds[i].arity() == formulas_[j].arity()) {
constr_->moveToTop (formulas_[j].logVars());
for (unsigned j = 0; j < args_.size(); j++) {
if (grounds[i].functor() == args_[j].functor() &&
grounds[i].arity() == args_[j].arity()) {
constr_->moveToTop (args_[j].logVars());
if (constr_->containsTuple (grounds[i].args())) {
newFormulas.push_back (formulas_[j]);
newFormulas.push_back (args_[j]);
break;
}
}
}
assert (newFormulas.size() == i + 1);
}
reorderFormulas (newFormulas);
reorderArguments (newFormulas);
}
void
Parfactor::reorderFormulas (const ProbFormulas& newFormulas)
{
assert (newFormulas.size() == formulas_.size());
if (newFormulas == formulas_) {
return;
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newFormulas.size(); i++) {
unsigned idx = indexOf (newFormulas[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (params_.size());
for (unsigned i = 0; i < params_.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N - 1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = params_[i];
}
formulas_ = newFormulas;
ranges_ = newRanges;
params_ = newParams;
}
void
Parfactor::absorveEvidence (unsigned fIdx, unsigned evidence)
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
{
int fIdx = indexOf (formula);
assert (fIdx != -1);
LogVarSet excl = exclusiveLogVars (fIdx);
assert (fIdx < formulas_.size());
assert (evidence < formulas_[fIdx].range());
assert (formulas_[fIdx].isCounting() == false);
assert (args_[fIdx].isCounting() == false);
assert (constr_->isCountNormalized (excl));
Util::pow (params_, constr_->getConditionalCount (excl));
Params copy = params_;
params_.clear();
params_.reserve (copy.size() / formulas_[fIdx].range());
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (fIdx);
}
while (indexer.valid()) {
params_.push_back (copy[indexer]);
indexer.incrementExcluding (fIdx);
}
formulas_.erase (formulas_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
LogAware::pow (params_, constr_->getConditionalCount (excl));
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
constr_->remove (excl);
}
void
Parfactor::normalize (void)
{
Util::normalize (params_);
}
void
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
{
assert (indexOf (f) != -1);
formulas_[indexOf (f)].setGroup (group);
}
void
Parfactor::setNewGroups (void)
{
for (unsigned i = 0; i < formulas_.size(); i++) {
formulas_[i].setGroup (ProbFormula::getNewGroup());
for (unsigned i = 0; i < args_.size(); i++) {
args_[i].setGroup (ProbFormula::getNewGroup());
}
}
@ -578,14 +389,14 @@ Parfactor::setNewGroups (void)
void
Parfactor::applySubstitution (const Substitution& theta)
{
for (unsigned i = 0; i < formulas_.size(); i++) {
LogVars& lvs = formulas_[i].logVars();
for (unsigned i = 0; i < args_.size(); i++) {
LogVars& lvs = args_[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) {
lvs[j] = theta.newNameFor (lvs[j]);
}
if (formulas_[i].isCounting()) {
LogVar clv = formulas_[i].countedLogVar();
formulas_[i].setCountedLogVar (theta.newNameFor (clv));
if (args_[i].isCounting()) {
LogVar clv = args_[i].countedLogVar();
args_[i].setCountedLogVar (theta.newNameFor (clv));
}
}
constr_->applySubstitution (theta);
@ -593,19 +404,29 @@ Parfactor::applySubstitution (const Substitution& theta)
bool
Parfactor::containsGround (const Ground& ground) const
int
Parfactor::findGroup (const Ground& ground) const
{
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].functor() == ground.functor() &&
formulas_[i].arity() == ground.arity()) {
constr_->moveToTop (formulas_[i].logVars());
int group = -1;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].functor() == ground.functor() &&
args_[i].arity() == ground.arity()) {
constr_->moveToTop (args_[i].logVars());
if (constr_->containsTuple (ground.args())) {
return true;
group = args_[i].group();
break;
}
}
}
return false;
return group;
}
bool
Parfactor::containsGround (const Ground& ground) const
{
return findGroup (ground) != -1;
}
@ -613,8 +434,8 @@ Parfactor::containsGround (const Ground& ground) const
bool
Parfactor::containsGroup (unsigned group) const
{
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].group() == group) {
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].group() == group) {
return true;
}
}
@ -623,30 +444,12 @@ Parfactor::containsGroup (unsigned group) const
const ProbFormula&
Parfactor::formula (unsigned fIdx) const
{
assert (fIdx < formulas_.size());
return formulas_[fIdx];
}
unsigned
Parfactor::range (unsigned fIdx) const
{
assert (fIdx < ranges_.size());
return ranges_[fIdx];
}
unsigned
Parfactor::nrFormulas (LogVar X) const
{
unsigned count = 0;
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].contains (X)) {
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].contains (X)) {
count ++;
}
}
@ -656,27 +459,12 @@ Parfactor::nrFormulas (LogVar X) const
int
Parfactor::indexOf (const ProbFormula& f) const
{
int idx = -1;
for (unsigned i = 0; i < formulas_.size(); i++) {
if (f == formulas_[i]) {
idx = i;
break;
}
}
return idx;
}
int
Parfactor::indexOfFormulaWithLogVar (LogVar X) const
Parfactor::indexOfLogVar (LogVar X) const
{
int idx = -1;
assert (nrFormulas (X) == 1);
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].contains (X)) {
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].contains (X)) {
idx = i;
break;
}
@ -687,11 +475,11 @@ Parfactor::indexOfFormulaWithLogVar (LogVar X) const
int
Parfactor::indexOfFormulaWithGroup (unsigned group) const
Parfactor::indexOfGroup (unsigned group) const
{
int pos = -1;
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].group() == group) {
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].group() == group) {
pos = i;
break;
}
@ -704,9 +492,9 @@ Parfactor::indexOfFormulaWithGroup (unsigned group) const
vector<unsigned>
Parfactor::getAllGroups (void) const
{
vector<unsigned> groups (formulas_.size());
for (unsigned i = 0; i < formulas_.size(); i++) {
groups[i] = formulas_[i].group();
vector<unsigned> groups (args_.size());
for (unsigned i = 0; i < args_.size(); i++) {
groups[i] = args_[i].group();
}
return groups;
}
@ -714,13 +502,13 @@ Parfactor::getAllGroups (void) const
string
Parfactor::getHeaderString (void) const
Parfactor::getLabel (void) const
{
stringstream ss;
ss << "phi(" ;
for (unsigned i = 0; i < formulas_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
ss << formulas_[i];
ss << args_[i];
}
ss << ")" ;
ConstraintTree copy (*constr_);
@ -735,32 +523,35 @@ void
Parfactor::print (bool printParams) const
{
cout << "Formulas: " ;
for (unsigned i = 0; i < formulas_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) cout << ", " ;
cout << formulas_[i];
cout << args_[i];
}
cout << endl;
vector<string> groups;
for (unsigned i = 0; i < formulas_.size(); i++) {
groups.push_back (string ("g") + Util::toString (formulas_[i].group()));
for (unsigned i = 0; i < args_.size(); i++) {
groups.push_back (string ("g") + Util::toString (args_[i].group()));
}
cout << "Groups: " << groups << endl;
cout << "LogVars: " << constr_->logVars() << endl;
cout << "LogVars: " << constr_->logVarSet() << endl;
cout << "Ranges: " << ranges_ << endl;
if (printParams == false) {
cout << "Params: " << params_ << endl;
}
cout << "Tuples: " << constr_->tupleSet() << endl;
ConstraintTree copy (*constr_);
copy.moveToTop (copy.logVarSet().elements());
cout << "Tuples: " << copy.tupleSet() << endl;
if (printParams) {
vector<string> jointStrings;
StatesIndexer indexer (ranges_);
while (indexer.valid()) {
stringstream ss;
for (unsigned i = 0; i < formulas_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << ", " ;
if (formulas_[i].isCounting()) {
unsigned N = constr_->getConditionalCount (formulas_[i].countedLogVar());
HistogramSet hs (N, formulas_[i].range());
if (args_[i].isCounting()) {
unsigned N = constr_->getConditionalCount (
args_[i].countedLogVar());
HistogramSet hs (N, args_[i].range());
unsigned c = 0;
while (c < indexer[i]) {
hs.nextHistogram();
@ -784,17 +575,50 @@ Parfactor::print (bool printParams) const
void
Parfactor::insertDimension (unsigned range)
Parfactor::expandPotential (
int fIdx,
unsigned newRange,
const vector<unsigned>& sumIndexes)
{
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
Params copy = params_;
params_.clear();
params_.reserve (copy.size() * range);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < range; reps++) {
params_.push_back (copy[i]);
params_.reserve (size);
unsigned prod = 1;
vector<unsigned> offsets_ (ranges_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = newRange;
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
assert (indices[i] - 1 < sumIndexes.size());
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
}
}
ranges_.push_back (range);
}
@ -803,29 +627,27 @@ void
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
{
LogVars X_1, X_2;
const ProbFormulas& formulas1 = g1->formulas();
const ProbFormulas& formulas2 = g2->formulas();
const ProbFormulas& formulas1 = g1->arguments();
const ProbFormulas& formulas2 = g2->arguments();
for (unsigned i = 0; i < formulas1.size(); i++) {
for (unsigned j = 0; j < formulas2.size(); j++) {
if (formulas1[i].group() == formulas2[j].group()) {
X_1.insert (X_1.end(),
formulas1[i].logVars().begin(),
formulas1[i].logVars().end());
X_2.insert (X_2.end(),
formulas2[j].logVars().begin(),
formulas2[j].logVars().end());
Util::addToVector (X_1, formulas1[i].logVars());
Util::addToVector (X_2, formulas2[j].logVars());
}
}
}
align (g1, X_1, g2, X_2);
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
assert (g1->constr()->isCountNormalized (Y_1));
assert (g2->constr()->isCountNormalized (Y_2));
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
Util::pow (g1->params(), 1.0 / condCount2);
Util::pow (g2->params(), 1.0 / condCount1);
LogAware::pow (g1->params(), 1.0 / condCount2);
LogAware::pow (g2->params(), 1.0 / condCount1);
// this must be done in the end or else X_1 and X_2
// will refer the old log var names in the code above
align (g1, X_1, g2, X_2);
}
@ -838,7 +660,6 @@ Parfactor::align (
LogVar freeLogVar = 0;
Substitution theta1;
Substitution theta2;
const LogVarSet& allLvs1 = g1->logVarSet();
for (unsigned i = 0; i < allLvs1.size(); i++) {
theta1.add (allLvs1[i], freeLogVar);
@ -850,7 +671,7 @@ Parfactor::align (
theta2.add (allLvs2[i], freeLogVar);
++ freeLogVar;
}
assert (alignLvs1.size() == alignLvs2.size());
for (unsigned i = 0; i < alignLvs1.size(); i++) {
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));

View File

@ -9,8 +9,9 @@
#include "LiftedUtils.h"
#include "Horus.h"
#include "Factor.h"
class Parfactor
class Parfactor : public TFactor<ProbFormula>
{
public:
Parfactor (
@ -18,27 +19,15 @@ class Parfactor
const Params&,
const Tuples&,
unsigned);
Parfactor (const Parfactor*, const Tuple&);
Parfactor (const Parfactor*, ConstraintTree*);
Parfactor (const Parfactor&);
~Parfactor (void);
ProbFormulas& formulas (void) { return formulas_; }
const ProbFormulas& formulas (void) const { return formulas_; }
unsigned nrFormulas (void) const { return formulas_.size(); }
Params& params (void) { return params_; }
const Params& params (void) const { return params_; }
unsigned size (void) const { return params_.size(); }
const Ranges& ranges (void) const { return ranges_; }
unsigned distId (void) const { return distId_; }
ConstraintTree* constr (void) { return constr_; }
const ConstraintTree* constr (void) const { return constr_; }
@ -57,64 +46,52 @@ class Parfactor
void setConstraintTree (ConstraintTree*);
void sumOut (unsigned);
void sumOut (unsigned fIdx);
void multiply (Parfactor&);
void countConvert (LogVar);
void expandPotential (LogVar, LogVar, LogVar);
void expand (LogVar, LogVar, LogVar);
void fullExpand (LogVar);
void reorderAccordingGrounds (const Grounds&);
void reorderFormulas (const ProbFormulas&);
void absorveEvidence (unsigned, unsigned);
void normalize (void);
void setFormulaGroup (const ProbFormula&, int);
void absorveEvidence (const ProbFormula&, unsigned);
void setNewGroups (void);
void applySubstitution (const Substitution&);
int findGroup (const Ground&) const;
bool containsGround (const Ground&) const;
bool containsGroup (unsigned) const;
const ProbFormula& formula (unsigned) const;
unsigned range (unsigned) const;
unsigned nrFormulas (LogVar) const;
int indexOf (const ProbFormula&) const;
int indexOfLogVar (LogVar) const;
int indexOfFormulaWithLogVar (LogVar) const;
int indexOfFormulaWithGroup (unsigned) const;
int indexOfGroup (unsigned) const;
vector<unsigned> getAllGroups (void) const;
void print (bool = false) const;
string getHeaderString (void) const;
string getLabel (void) const;
private:
void expandPotential (int fIdx, unsigned newRange,
const vector<unsigned>& sumIndexes);
static void alignAndExponentiate (Parfactor*, Parfactor*);
static void align (
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
void insertDimension (unsigned);
ProbFormulas formulas_;
Ranges ranges_;
Params params_;
unsigned distId_;
ConstraintTree* constr_;
ConstraintTree* constr_;
};

View File

@ -3,9 +3,32 @@
#include "ParfactorList.h"
ParfactorList::ParfactorList (Parfactors& pfs)
ParfactorList::ParfactorList (const ParfactorList& pfList)
{
pfList_.insert (pfList_.end(), pfs.begin(), pfs.end());
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
addShattered (new Parfactor (**it));
++ it;
}
}
ParfactorList::ParfactorList (const Parfactors& pfs)
{
add (pfs);
}
ParfactorList::~ParfactorList (void)
{
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
delete *it;
++ it;
}
}
@ -14,17 +37,17 @@ void
ParfactorList::add (Parfactor* pf)
{
pf->setNewGroups();
pfList_.push_back (pf);
addToShatteredList (pf);
}
void
ParfactorList::add (Parfactors& pfs)
ParfactorList::add (const Parfactors& pfs)
{
for (unsigned i = 0; i < pfs.size(); i++) {
pfs[i]->setNewGroups();
pfList_.push_back (pfs[i]);
addToShatteredList (pfs[i]);
}
}
@ -33,7 +56,20 @@ ParfactorList::add (Parfactors& pfs)
void
ParfactorList::addShattered (Parfactor* pf)
{
assert (isAllShattered());
pfList_.push_back (pf);
assert (isAllShattered());
}
list<Parfactor*>::iterator
ParfactorList::insertShattered (
list<Parfactor*>::iterator it,
Parfactor* pf)
{
return pfList_.insert (it, pf);
assert (isAllShattered());
}
@ -47,7 +83,7 @@ ParfactorList::remove (list<Parfactor*>::iterator it)
list<Parfactor*>::iterator
ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
{
delete *it;
return pfList_.erase (it);
@ -55,58 +91,21 @@ ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
void
ParfactorList::shatter (void)
bool
ParfactorList::isAllShattered (void) const
{
list<Parfactor*> tempList;
Parfactors newPfs;
newPfs.insert (newPfs.end(), pfList_.begin(), pfList_.end());
while (newPfs.empty() == false) {
tempList.insert (tempList.end(), newPfs.begin(), newPfs.end());
newPfs.clear();
list<Parfactor*>::iterator iter1 = tempList.begin();
while (tempList.size() > 1 && iter1 != -- tempList.end()) {
list<Parfactor*>::iterator iter2 = iter1;
++ iter2;
bool incIter1 = true;
while (iter2 != tempList.end()) {
assert (iter1 != iter2);
std::pair<Parfactors, Parfactors> res = shatter (
(*iter1)->formulas(), *iter1, (*iter2)->formulas(), *iter2);
bool incIter2 = true;
if (res.second.empty() == false) {
// cout << "second unshattered" << endl;
delete *iter2;
iter2 = tempList.erase (iter2);
incIter2 = false;
newPfs.insert (
newPfs.begin(), res.second.begin(), res.second.end());
}
if (res.first.empty() == false) {
// cout << "first unshattered" << endl;
delete *iter1;
iter1 = tempList.erase (iter1);
newPfs.insert (
newPfs.begin(), res.first.begin(), res.first.end());
incIter1 = false;
break;
}
if (incIter2) {
++ iter2;
}
}
if (incIter1) {
++ iter1;
if (pfList_.size() <= 1) {
return true;
}
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
for (unsigned i = 0; i < pfs.size() - 1; i++) {
for (unsigned j = i + 1; j < pfs.size(); j++) {
if (isShattered (pfs[i], pfs[j]) == false) {
return false;
}
}
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
// cout << "||||||||||||| SHATTERING ITERATION ||||||||||||||" << endl;
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
// printParfactors (newPfs);
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
}
pfList_.clear();
pfList_.insert (pfList_.end(), tempList.begin(), tempList.end());
return true;
}
@ -123,19 +122,83 @@ ParfactorList::print (void) const
std::pair<Parfactors, Parfactors>
ParfactorList::shatter (
ProbFormulas& formulas1,
Parfactor* g1,
ProbFormulas& formulas2,
Parfactor* g2)
bool
ParfactorList::isShattered (
const Parfactor* g1,
const Parfactor* g2) const
{
assert (g1 != g2);
const ProbFormulas& fms1 = g1->arguments();
const ProbFormulas& fms2 = g2->arguments();
for (unsigned i = 0; i < fms1.size(); i++) {
for (unsigned j = 0; j < fms2.size(); j++) {
if (fms1[i].group() == fms2[j].group()) {
if (identical (
fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) {
return false;
}
} else {
if (disjoint (
fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) {
return false;
}
}
}
}
return true;
}
void
ParfactorList::addToShatteredList (Parfactor* g)
{
queue<Parfactor*> residuals;
residuals.push (g);
while (residuals.empty() == false) {
Parfactor* pf = residuals.front();
bool pfSplitted = false;
list<Parfactor*>::iterator pfIter;
pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
std::pair<Parfactors, Parfactors> shattRes;
shattRes = shatter (*pfIter, pf);
if (shattRes.first.empty() == false) {
pfIter = removeAndDelete (pfIter);
Util::addToQueue (residuals, shattRes.first);
} else {
++ pfIter;
}
if (shattRes.second.empty() == false) {
delete pf;
Util::addToQueue (residuals, shattRes.second);
pfSplitted = true;
break;
}
}
residuals.pop();
if (pfSplitted == false) {
addShattered (pf);
}
}
assert (isAllShattered());
}
std::pair<Parfactors, Parfactors>
ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
{
ProbFormulas& formulas1 = g1->arguments();
ProbFormulas& formulas2 = g2->arguments();
assert (g1 != 0 && g2 != 0 && g1 != g2);
for (unsigned i = 0; i < formulas1.size(); i++) {
for (unsigned j = 0; j < formulas2.size(); j++) {
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
std::pair<Parfactors, Parfactors> res
= shatter (formulas1[i], g1, formulas2[j], g2);
std::pair<Parfactors, Parfactors> res;
res = shatter (i, g1, j, g2);
if (res.first.empty() == false ||
res.second.empty() == false) {
return res;
@ -150,21 +213,22 @@ ParfactorList::shatter (
std::pair<Parfactors, Parfactors>
ParfactorList::shatter (
ProbFormula& f1,
Parfactor* g1,
ProbFormula& f2,
Parfactor* g2)
unsigned fIdx1, Parfactor* g1,
unsigned fIdx2, Parfactor* g2)
{
ProbFormula& f1 = g1->argument (fIdx1);
ProbFormula& f2 = g2->argument (fIdx2);
// cout << endl;
// cout << "-------------------------------------------------" << endl;
// Util::printDashLine();
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
// g1->print();
// cout << "-> WITH" << endl;
// g2->print();
// cout << "-> ON: " << f1.toString (g1->constr()) << endl;
// cout << "-> ON: " << f2.toString (g2->constr()) << endl;
// cout << "-------------------------------------------------" << endl;
// cout << "-> ON: " << f1 << "|" ;
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
// cout << "-> ON: " << f2 << "|" ;
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
// Util::printDashLine();
if (f1.isAtom()) {
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
f1.setGroup (group);
@ -174,7 +238,7 @@ ParfactorList::shatter (
assert (g1->constr()->empty() == false);
assert (g2->constr()->empty() == false);
if (f1.group() == f2.group()) {
// assert (identical (f1, g1->constr(), f2, g2->constr()));
assert (identical (f1, *(g1->constr()), f2, *(g2->constr())));
return { };
}
@ -215,7 +279,9 @@ ParfactorList::shatter (
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
if (exclCt1->empty() && exclCt2->empty()) {
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
unsigned group = (f1.group() < f2.group())
? f1.group()
: f2.group();
// identical
f1.setGroup (group);
f2.setGroup (group);
@ -235,8 +301,8 @@ ParfactorList::shatter (
} else {
group = ProbFormula::getNewGroup();
}
Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group);
Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group);
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
return make_pair (res1, res2);
}
@ -245,11 +311,19 @@ ParfactorList::shatter (
Parfactors
ParfactorList::shatter (
Parfactor* g,
const ProbFormula& f,
unsigned fIdx,
ConstraintTree* commCt,
ConstraintTree* exclCt,
unsigned commGroup)
{
ProbFormula& f = g->argument (fIdx);
if (exclCt->empty()) {
delete commCt;
delete exclCt;
f.setGroup (commGroup);
return { };
}
Parfactors result;
if (f.isCounting()) {
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
@ -259,7 +333,7 @@ ParfactorList::shatter (
for (unsigned i = 0; i < cts.size(); i++) {
Parfactor* newPf = new Parfactor (g, cts[i]);
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
newPf->expandPotential (f.countedLogVar(), X_new1, X_new2);
newPf->expand (f.countedLogVar(), X_new1, X_new2);
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
cts[i]->getConditionalCount (X_new1) +
cts[i]->getConditionalCount (X_new2));
@ -270,20 +344,16 @@ ParfactorList::shatter (
newPf->setNewGroups();
result.push_back (newPf);
}
delete commCt;
delete exclCt;
} else {
if (exclCt->empty()) {
delete commCt;
delete exclCt;
g->setFormulaGroup (f, commGroup);
} else {
Parfactor* newPf = new Parfactor (g, commCt);
newPf->setNewGroups();
newPf->setFormulaGroup (f, commGroup);
result.push_back (newPf);
newPf = new Parfactor (g, exclCt);
newPf->setNewGroups();
result.push_back (newPf);
}
Parfactor* newPf = new Parfactor (g, commCt);
newPf->setNewGroups();
newPf->argument (fIdx).setGroup (commGroup);
result.push_back (newPf);
newPf = new Parfactor (g, exclCt);
newPf->setNewGroups();
result.push_back (newPf);
}
return result;
}
@ -296,7 +366,7 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
unsigned newGroup = ProbFormula::getNewGroup();
for (ParfactorList::iterator it = pfList_.begin();
it != pfList_.end(); it++) {
ProbFormulas& formulas = (*it)->formulas();
ProbFormulas& formulas = (*it)->arguments();
for (unsigned i = 0; i < formulas.size(); i++) {
if (formulas[i].group() == group1 ||
formulas[i].group() == group2) {
@ -306,3 +376,52 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
}
}
bool
ParfactorList::proper (
const ProbFormula& f1, ConstraintTree c1,
const ProbFormula& f2, ConstraintTree c2) const
{
return disjoint (f1, c1, f2, c2)
|| identical (f1, c1, f2, c2);
}
bool
ParfactorList::identical (
const ProbFormula& f1, ConstraintTree c1,
const ProbFormula& f2, ConstraintTree c2) const
{
if (f1.sameSkeletonAs (f2) == false) {
return false;
}
if (f1.isAtom()) {
return true;
}
c1.moveToTop (f1.logVars());
c2.moveToTop (f2.logVars());
return ConstraintTree::identical (
&c1, &c2, f1.logVars().size());
}
bool
ParfactorList::disjoint (
const ProbFormula& f1, ConstraintTree c1,
const ProbFormula& f2, ConstraintTree c2) const
{
if (f1.sameSkeletonAs (f2) == false) {
return true;
}
if (f1.isAtom()) {
return true;
}
c1.moveToTop (f1.logVars());
c2.moveToTop (f2.logVars());
return ConstraintTree::overlap (
&c1, &c2, f1.arity()) == false;
}

View File

@ -2,6 +2,7 @@
#define HORUS_PARFACTORLIST_H
#include <list>
#include <queue>
#include "Parfactor.h"
#include "ProbFormula.h"
@ -14,56 +15,82 @@ class ParfactorList
{
public:
ParfactorList (void) { }
ParfactorList (Parfactors&);
list<Parfactor*>& getParfactors (void) { return pfList_; }
const list<Parfactor*>& getParfactors (void) const { return pfList_; }
void add (Parfactor* pf);
void add (Parfactors& pfs);
void addShattered (Parfactor* pf);
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
list<Parfactor*>::iterator deleteAndRemove (list<Parfactor*>::iterator);
ParfactorList (const ParfactorList&);
void clear (void) { pfList_.clear(); }
unsigned size (void) const { return pfList_.size(); }
ParfactorList (const Parfactors&);
void shatter (void);
~ParfactorList (void);
const list<Parfactor*>& parfactors (void) const { return pfList_; }
void clear (void) { pfList_.clear(); }
unsigned size (void) const { return pfList_.size(); }
typedef std::list<Parfactor*>::iterator iterator;
iterator begin (void) { return pfList_.begin(); }
iterator end (void) { return pfList_.end(); }
iterator end (void) { return pfList_.end(); }
typedef std::list<Parfactor*>::const_iterator const_iterator;
const_iterator begin (void) const { return pfList_.begin(); }
const_iterator end (void) const { return pfList_.end(); }
const_iterator end (void) const { return pfList_.end(); }
void add (Parfactor* pf);
void add (const Parfactors& pfs);
void addShattered (Parfactor* pf);
list<Parfactor*>::iterator insertShattered (
list<Parfactor*>::iterator, Parfactor*);
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
bool isAllShattered (void) const;
void print (void) const;
private:
bool isShattered (const Parfactor*, const Parfactor*) const;
static std::pair<Parfactors, Parfactors> shatter (
ProbFormulas&,
Parfactor*,
ProbFormulas&,
Parfactor*);
void addToShatteredList (Parfactor*);
std::pair<Parfactors, Parfactors> shatter (
Parfactor*, Parfactor*);
static std::pair<Parfactors, Parfactors> shatter (
ProbFormula&,
Parfactor*,
ProbFormula&,
Parfactor*);
std::pair<Parfactors, Parfactors> shatter (
unsigned, Parfactor*, unsigned, Parfactor*);
static Parfactors shatter (
Parfactor*,
const ProbFormula&,
ConstraintTree*,
ConstraintTree*,
unsigned);
Parfactors shatter (
Parfactor*,
unsigned,
ConstraintTree*,
ConstraintTree*,
unsigned);
void unifyGroups (unsigned group1, unsigned group2);
void unifyGroups (unsigned group1, unsigned group2);
list<Parfactor*> pfList_;
bool proper (
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
bool identical (
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
bool disjoint (
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
list<Parfactor*> pfList_;
};
#endif // HORUS_PARFACTORLIST_H

View File

@ -16,8 +16,7 @@ ProbFormula::sameSkeletonAs (const ProbFormula& f) const
bool
ProbFormula::contains (LogVar lv) const
{
return std::find (logVars_.begin(), logVars_.end(), lv) !=
logVars_.end();
return Util::contains (logVars_, lv);
}
@ -77,16 +76,15 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
}
bool
ProbFormula::operator== (const ProbFormula& f) const
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
{
return functor_ == f.functor_ && logVars_ == f.logVars_ ;
return f1.group_ == f2.group_;
//return functor_ == f.functor_ && logVars_ == f.logVars_ ;
}
ostream& operator<< (ostream &os, const ProbFormula& f)
std::ostream& operator<< (ostream &os, const ProbFormula& f)
{
os << f.functor_;
if (f.isAtom() == false) {
@ -113,3 +111,13 @@ ProbFormula::getNewGroup (void)
return freeGroup_;
}
ostream& operator<< (ostream &os, const ObservedFormula& of)
{
os << of.functor_ << "/" << of.arity_;
os << "|" << of.constr_.tupleSet();
os << " [evidence=" << of.evidence_ << "]";
return os;
}

View File

@ -8,14 +8,16 @@
#include "Horus.h"
class ProbFormula
{
public:
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
: functor_(f), logVars_(lvs), range_(range),
countedLogVar_() { }
countedLogVar_(), group_(Util::maxUnsigned()) { }
ProbFormula (Symbol f, unsigned r) : functor_(f), range_(r) { }
ProbFormula (Symbol f, unsigned r)
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
Symbol functor (void) const { return functor_; }
@ -29,9 +31,9 @@ class ProbFormula
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
unsigned group (void) const { return groupId_; }
unsigned group (void) const { return group_; }
void setGroup (unsigned g) { groupId_ = g; }
void setGroup (unsigned g) { group_ = g; }
bool sameSkeletonAs (const ProbFormula&) const;
@ -49,23 +51,58 @@ class ProbFormula
void rename (LogVar, LogVar);
bool operator== (const ProbFormula& f) const;
friend ostream& operator<< (ostream &out, const ProbFormula& f);
static unsigned getNewGroup (void);
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
private:
Symbol functor_;
LogVars logVars_;
unsigned range_;
LogVar countedLogVar_;
unsigned groupId_;
static int freeGroup_;
Symbol functor_;
LogVars logVars_;
unsigned range_;
LogVar countedLogVar_;
unsigned group_;
static int freeGroup_;
};
typedef vector<ProbFormula> ProbFormulas;
class ObservedFormula
{
public:
ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
{
constr_.addTuple (tuple);
}
Symbol functor (void) const { return functor_; }
unsigned arity (void) const { return arity_; }
unsigned evidence (void) const { return evidence_; }
ConstraintTree& constr (void) { return constr_; }
bool isAtom (void) const { return arity_ == 0; }
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
private:
Symbol functor_;
unsigned arity_;
unsigned evidence_;
ConstraintTree constr_;
};
typedef vector<ObservedFormula> ObservedFormulas;
#endif // HORUS_PROBFORMULA_H

View File

@ -21,7 +21,7 @@ Solver::printPosterioriOf (VarId vid)
const States& states = var->states();
for (unsigned i = 0; i < states.size(); i++) {
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
cout << setprecision (PRECISION) << posterioriDist[i];
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
cout << endl;
}
cout << endl;
@ -45,7 +45,7 @@ Solver::printJointDistributionOf (const VarIds& vids)
vector<string> jointStrings = Util::getJointStateStrings (vars);
for (unsigned i = 0; i < jointDist.size(); i++) {
cout << "P(" << jointStrings[i] << ") = " ;
cout << setprecision (PRECISION) << jointDist[i];
cout << setprecision (Constants::PRECISION) << jointDist[i];
cout << endl;
}
cout << endl;

View File

@ -11,17 +11,20 @@ using namespace std;
class Solver
{
public:
Solver (const GraphicalModel* gm)
{
gm_ = gm;
}
virtual ~Solver() {} // to ensure that subclass destructor is called
virtual void runSolver (void) = 0;
virtual Params getPosterioriOf (VarId) = 0;
virtual Params getJointDistributionOf (const VarIds&) = 0;
Solver (const GraphicalModel* gm) : gm_(gm) { }
virtual ~Solver() { } // ensure that subclass destructor is called
virtual void runSolver (void) = 0;
virtual Params getPosterioriOf (VarId) = 0;
virtual Params getJointDistributionOf (const VarIds&) = 0;
void printAllPosterioris (void);
void printPosterioriOf (VarId vid);
void printJointDistributionOf (const VarIds& vids);
private:

View File

@ -0,0 +1,5 @@
TODO
- add way to calculate combinations and factorials with large numbers
- refactor sumOut in parfactor -> is really ugly code
- Indexer: start receiving ranges as constant reference

View File

@ -1,4 +1,7 @@
#include <limits>
#include <sstream>
#include <fstream>
#include "Util.h"
#include "Indexer.h"
@ -6,16 +9,15 @@
namespace Globals {
bool logDomain = false;
bool logDomain = false;
//InfAlgs infAlgorithm = InfAlgorithms::VE;
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
};
namespace InfAlgorithms {
//InfAlgs infAlgorithm = InfAlgorithms::VE;
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
//InfAlgs infAlgorithm = InfAlgorithms::CBP;
}
namespace BpOptions {
@ -28,8 +30,7 @@ unsigned maxIter = 1000;
}
unordered_map<VarId,VariableInfo> GraphicalModel::varsInfo_;
unordered_map<unsigned,Distribution*> GraphicalModel::distsInfo_;
unordered_map<VarId, VarInfo> GraphicalModel::varsInfo_;
vector<NetInfo> Statistics::netInfo_;
vector<CompressInfo> Statistics::compressInfo_;
@ -58,76 +59,6 @@ fromLog (Params& v)
void
normalize (Params& v)
{
double sum;
if (Globals::logDomain) {
sum = addIdenty();
for (unsigned i = 0; i < v.size(); i++) {
logSum (sum, v[i]);
}
assert (sum != -numeric_limits<double>::infinity());
for (unsigned i = 0; i < v.size(); i++) {
v[i] -= sum;
}
} else {
sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
}
}
void
pow (Params& v, double expoent)
{
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
void
pow (Params& v, unsigned expoent)
{
if (expoent == 1) {
return;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
double
pow (double p, unsigned expoent)
{
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
double
factorial (double num)
{
@ -153,52 +84,21 @@ nrCombinations (unsigned n, unsigned r)
double
getL1Distance (const Params& v1, const Params& v2)
unsigned
expectedSize (const Ranges& ranges)
{
assert (v1.size() == v2.size());
double dist = 0.0;
if (Globals::logDomain) {
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (exp(v1[i]) - exp(v2[i]));
}
} else {
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (v1[i] - v2[i]);
}
unsigned prod = 1;
for (unsigned i = 0; i < ranges.size(); i++) {
prod *= ranges[i];
}
return dist;
}
double
getMaxNorm (const Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
double max = 0.0;
if (Globals::logDomain) {
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (exp(v1[i]) - exp(v2[i]));
if (diff > max) {
max = diff;
}
}
} else {
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (v1[i] - v2[i]);
if (diff > max) {
max = diff;
}
}
}
return max;
return prod;
}
unsigned
getNumberOfDigits (int number) {
getNumberOfDigits (int number)
{
unsigned count = 1;
while (number >= 10) {
number /= 10;
@ -257,6 +157,168 @@ getJointStateStrings (const VarNodes& vars)
void printHeader (string header, std::ostream& os)
{
printAsteriskLine (os);
os << header << endl;
printAsteriskLine (os);
}
void printSubHeader (string header, std::ostream& os)
{
printDashedLine (os);
os << header << endl;
printDashedLine (os);
}
void printAsteriskLine (std::ostream& os)
{
os << "********************************" ;
os << "********************************" ;
os << endl;
}
void printDashedLine (std::ostream& os)
{
os << "--------------------------------" ;
os << "--------------------------------" ;
os << endl;
}
}
namespace LogAware {
void
normalize (Params& v)
{
double sum;
if (Globals::logDomain) {
sum = LogAware::addIdenty();
for (unsigned i = 0; i < v.size(); i++) {
sum = Util::logSum (sum, v[i]);
}
assert (sum != -numeric_limits<double>::infinity());
for (unsigned i = 0; i < v.size(); i++) {
v[i] -= sum;
}
} else {
sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
}
}
double
getL1Distance (const Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
double dist = 0.0;
if (Globals::logDomain) {
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (exp(v1[i]) - exp(v2[i]));
}
} else {
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (v1[i] - v2[i]);
}
}
return dist;
}
double
getMaxNorm (const Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
double max = 0.0;
if (Globals::logDomain) {
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (exp(v1[i]) - exp(v2[i]));
if (diff > max) {
max = diff;
}
}
} else {
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (v1[i] - v2[i]);
if (diff > max) {
max = diff;
}
}
}
return max;
}
double
pow (double p, unsigned expoent)
{
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
double
pow (double p, double expoent)
{
// assumes that `expoent' is never in log domain
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
void
pow (Params& v, unsigned expoent)
{
if (expoent == 1) {
return;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
void
pow (Params& v, double expoent)
{
// assumes that `expoent' is never in log domain
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
}
@ -286,8 +348,11 @@ Statistics::getPrimaryNetworksCounting (void)
void
Statistics::updateStatistics (unsigned size, bool loopy,
unsigned nIters, double time)
Statistics::updateStatistics (
unsigned size,
bool loopy,
unsigned nIters,
double time)
{
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
}
@ -318,11 +383,12 @@ Statistics::writeStatisticsToFile (const char* fileName)
void
Statistics::updateCompressingStatistics (unsigned nGroundVars,
unsigned nGroundFactors,
unsigned nClusterVars,
unsigned nClusterFactors,
unsigned nWithoutNeighs) {
Statistics::updateCompressingStatistics (
unsigned nGroundVars,
unsigned nGroundFactors,
unsigned nClusterVars,
unsigned nClusterFactors,
unsigned nWithoutNeighs) {
compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors,
nClusterVars, nClusterFactors, nWithoutNeighs));
}
@ -334,7 +400,7 @@ Statistics::getStatisticString (void)
{
stringstream ss2, ss3, ss4, ss1;
ss1 << "running mode: " ;
switch (InfAlgorithms::infAlgorithm) {
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
@ -342,18 +408,23 @@ Statistics::getStatisticString (void)
}
ss1 << "message schedule: " ;
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_FIXED: ss1 << "sequential fixed" << endl; break;
case BpOptions::Schedule::SEQ_RANDOM: ss1 << "sequential random" << endl; break;
case BpOptions::Schedule::PARALLEL: ss1 << "parallel" << endl; break;
case BpOptions::Schedule::MAX_RESIDUAL: ss1 << "max residual" << endl; break;
case BpOptions::Schedule::SEQ_FIXED:
ss1 << "sequential fixed" << endl;
break;
case BpOptions::Schedule::SEQ_RANDOM:
ss1 << "sequential random" << endl;
break;
case BpOptions::Schedule::PARALLEL:
ss1 << "parallel" << endl;
break;
case BpOptions::Schedule::MAX_RESIDUAL:
ss1 << "max residual" << endl;
break;
}
ss1 << "max iterations: " << BpOptions::maxIter << endl;
ss1 << "accuracy " << BpOptions::accuracy << endl;
ss1 << endl << endl;
ss2 << "---------------------------------------------------" << endl;
ss2 << " Network information" << endl;
ss2 << "---------------------------------------------------" << endl;
Util::printSubHeader ("Network information", ss2);
ss2 << left;
ss2 << setw (15) << "Network Size" ;
ss2 << setw (9) << "Loopy" ;
@ -387,9 +458,7 @@ Statistics::getStatisticString (void)
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
if (compressInfo_.size() > 0) {
ss3 << "---------------------------------------------------" << endl;
ss3 << " Compression information" << endl;
ss3 << "---------------------------------------------------" << endl;
Util::printSubHeader ("Compress information", ss3);
ss3 << left;
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
ss3 << "Vars Vars Factors Factors Vars" << endl;

View File

@ -1,53 +1,131 @@
#ifndef HORUS_UTIL_H
#define HORUS_UTIL_H
#include <cmath>
#include <cassert>
#include <limits>
#include <vector>
#include <set>
#include <queue>
#include <unordered_map>
#include <sstream>
#include <iostream>
#include "Horus.h"
using namespace std;
namespace Util {
void toLog (Params&);
void fromLog (Params&);
void normalize (Params&);
void logSum (double&, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void add (Params&, const Params&, unsigned);
void pow (Params&, double);
void pow (Params&, unsigned);
double pow (double, unsigned);
double factorial (double);
unsigned nrCombinations (unsigned, unsigned);
double getL1Distance (const Params&, const Params&);
double getMaxNorm (const Params&, const Params&);
unsigned getNumberOfDigits (int);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = PRECISION);
vector<string> getJointStateStrings (const VarNodes&);
double tl (double);
double fl (double);
double multIdenty();
double addIdenty();
double withEvidence();
double noEvidence();
double one();
double zero();
template <typename T> void addToVector (vector<T>&, const vector<T>&);
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
template <typename T> bool contains (const vector<T>&, const T&);
template <typename T> bool contains (const set<T>&, const T&);
template <typename K, typename V> bool contains (
const unordered_map<K, V>&, const K&);
template <typename T> std::string toString (const T&);
void toLog (Params&);
void fromLog (Params&);
double logSum (double, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void add (Params&, const Params&, unsigned);
double factorial (double);
unsigned nrCombinations (unsigned, unsigned);
unsigned expectedSize (const Ranges&);
unsigned getNumberOfDigits (int);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION);
vector<string> getJointStateStrings (const VarNodes&);
void printHeader (string, std::ostream& os = std::cout);
void printSubHeader (string, std::ostream& os = std::cout);
void printAsteriskLine (std::ostream& os = std::cout);
void printDashedLine (std::ostream& os = std::cout);
unsigned maxUnsigned (void);
};
template <class T>
std::string toString (const T& t)
template <typename T> void
Util::addToVector (vector<T>& v, const vector<T>& elements)
{
v.insert (v.end(), elements.begin(), elements.end());
}
template <typename T> void
Util::addToQueue (queue<T>& q, const vector<T>& elements)
{
for (unsigned i = 0; i < elements.size(); i++) {
q.push (elements[i]);
}
}
template <typename T> bool
Util::contains (const vector<T>& v, const T& e)
{
return std::find (v.begin(), v.end(), e) != v.end();
}
template <typename T> bool
Util::contains (const set<T>& s, const T& e)
{
return s.find (e) != s.end();
}
template <typename K, typename V> bool
Util::contains (
const unordered_map<K, V>& m, const K& k)
{
return m.find (k) != m.end();
}
template <typename T> std::string
Util::toString (const T& t)
{
std::stringstream ss;
ss << t;
return ss.str();
}
};
template <typename T>
@ -62,28 +140,31 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
}
namespace {
const double INF = -numeric_limits<double>::infinity();
};
inline void
Util::logSum (double& x, double y)
inline double
Util::logSum (double x, double y)
{
x = log (exp (x) + exp (y)); return;
return log (exp (x) + exp (y));
assert (isfinite (x) && isfinite (y));
// If one value is much smaller than the other, keep the larger value.
if (x < (y - log (1e200))) {
x = y;
return;
return y;
}
if (y < (x - log (1e200))) {
return;
return x;
}
double diff = x - y;
assert (isfinite (diff) && isfinite (x) && isfinite (y));
if (!isfinite (exp (diff))) { // difference is too large
x = x > y ? x : y;
} else { // otherwise return the sum.
x = y + log (static_cast<double>(1.0) + exp (diff));
if (!isfinite (exp (diff))) {
// difference is too large
return x > y ? x : y;
}
// otherwise return the sum.
return y + log (static_cast<double>(1.0) + exp (diff));
}
@ -140,52 +221,87 @@ Util::add (Params& v1, const Params& v2, unsigned repetitions)
inline double
Util::tl (double v)
inline unsigned
Util::maxUnsigned (void)
{
return Globals::logDomain ? log(v) : v;
return numeric_limits<unsigned>::max();
}
inline double
Util::fl (double v)
{
return Globals::logDomain ? exp(v) : v;
}
namespace LogAware {
inline double
Util::multIdenty() {
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::addIdenty()
{
return Globals::logDomain ? INF : 0.0;
}
inline double
Util::withEvidence()
one()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::noEvidence() {
return Globals::logDomain ? INF : 0.0;
}
inline double
Util::one()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::zero() {
zero() {
return Globals::logDomain ? INF : 0.0 ;
}
inline double
addIdenty()
{
return Globals::logDomain ? INF : 0.0;
}
inline double
multIdenty()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
withEvidence()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
noEvidence() {
return Globals::logDomain ? INF : 0.0;
}
inline double
tl (double v)
{
return Globals::logDomain ? log (v) : v;
}
inline double
fl (double v)
{
return Globals::logDomain ? exp (v) : v;
}
void normalize (Params&);
double getL1Distance (const Params&, const Params&);
double getMaxNorm (const Params&, const Params&);
double pow (double, unsigned);
double pow (double, double);
void pow (Params&, unsigned);
void pow (Params&, double);
};
struct NetInfo
{
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
@ -224,11 +340,17 @@ class Statistics
{
public:
static unsigned getSolvedNetworksCounting (void);
static void incrementPrimaryNetworksCounting (void);
static unsigned getPrimaryNetworksCounting (void);
static void updateStatistics (unsigned, bool, unsigned, double);
static void printStatistics (void);
static void writeStatisticsToFile (const char*);
static void updateCompressingStatistics (
unsigned, unsigned, unsigned, unsigned, unsigned);

View File

@ -56,7 +56,7 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
introduceEvidence();
chooseEliminationOrder (vids);
processFactorList (vids);
Params params = factorList_.back()->getParameters();
Params params = factorList_.back()->params();
if (Globals::logDomain) {
Util::fromLog (params);
}
@ -98,7 +98,7 @@ VarElimSolver::introduceEvidence (void)
varFactors_.find (varNodes[i]->varId())->second;
for (unsigned j = 0; j < idxs.size(); j++) {
Factor* factor = factorList_[idxs[j]];
if (factor->nrVariables() == 1) {
if (factor->nrArguments() == 1) {
factorList_[idxs[j]] = 0;
} else {
factorList_[idxs[j]]->absorveEvidence (
@ -121,8 +121,8 @@ VarElimSolver::chooseEliminationOrder (const VarIds& vids)
const FgVarSet& varNodes = factorGraph_->getVarNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
VarId vid = varNodes[i]->varId();
if (std::find (vids.begin(), vids.end(), vid) == vids.end()
&& !varNodes[i]->hasEvidence()) {
if (Util::contains (vids, vid) == false &&
varNodes[i]->hasEvidence() == false) {
elimOrder_.push_back (vid);
}
}
@ -154,7 +154,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
}
}
finalFactor->reorderVariables (unobservedVids);
finalFactor->reorderArguments (unobservedVids);
finalFactor->normalize();
factorList_.push_back (finalFactor);
}
@ -179,10 +179,10 @@ VarElimSolver::eliminate (VarId elimVar)
factorList_[idx] = 0;
}
}
if (result != 0 && result->nrVariables() != 1) {
if (result != 0 && result->nrArguments() != 1) {
result->sumOut (vn->varId());
factorList_.push_back (result);
const VarIds& resultVarIds = result->getVarIds();
const VarIds& resultVarIds = result->arguments();
for (unsigned i = 0; i < resultVarIds.size(); i++) {
vector<unsigned>& idxs =
varFactors_.find (resultVarIds[i])->second;

View File

@ -16,18 +16,28 @@ class VarElimSolver : public Solver
{
public:
VarElimSolver (const BayesNet&);
VarElimSolver (const FactorGraph&);
~VarElimSolver (void);
void runSolver (void) { }
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
void runSolver (void) { }
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private:
void createFactorList (void);
void introduceEvidence (void);
void chooseEliminationOrder (const VarIds&);
void processFactorList (const VarIds&);
void eliminate (VarId);
void printActiveFactors (void);
const BayesNet* bayesNet_;

View File

@ -40,8 +40,8 @@ VarNode::isValidState (int stateIndex)
bool
VarNode::isValidState (const string& stateName)
{
States states = GraphicalModel::getVariableInformation (varId_).states;
return find (states.begin(), states.end(), stateName) != states.end();
States states = GraphicalModel::getVarInformation (varId_).states;
return Util::contains (states, stateName);
}
@ -58,7 +58,7 @@ VarNode::setEvidence (int ev)
void
VarNode::setEvidence (const string& ev)
{
States states = GraphicalModel::getVariableInformation (varId_).states;
States states = GraphicalModel::getVarInformation (varId_).states;
for (unsigned i = 0; i < states.size(); i++) {
if (states[i] == ev) {
evidence_ = i;
@ -74,7 +74,7 @@ string
VarNode::label (void) const
{
if (GraphicalModel::variablesHaveInformation()) {
return GraphicalModel::getVariableInformation (varId_).label;
return GraphicalModel::getVarInformation (varId_).label;
}
stringstream ss;
ss << "x" << varId_;
@ -87,7 +87,7 @@ States
VarNode::states (void) const
{
if (GraphicalModel::variablesHaveInformation()) {
return GraphicalModel::getVariableInformation (varId_).states;
return GraphicalModel::getVarInformation (varId_).states;
}
States states;
for (unsigned i = 0; i < nrStates_; i++) {

View File

@ -1,6 +1,10 @@
#ifndef HORUS_VARNODE_H
#define HORUS_VARNODE_H
#include <cassert>
#include <iostream>
#include "Horus.h"
using namespace std;
@ -9,25 +13,28 @@ class VarNode
{
public:
VarNode (const VarNode*);
VarNode (VarId, unsigned, int = NO_EVIDENCE);
virtual ~VarNode (void) {};
bool isValidState (int);
bool isValidState (const string&);
void setEvidence (int);
void setEvidence (const string&);
string label (void) const;
States states (void) const;
VarNode (VarId, unsigned, int = Constants::NO_EVIDENCE);
unsigned varId (void) const { return varId_; }
unsigned nrStates (void) const { return nrStates_; }
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
int getEvidence (void) const { return evidence_; }
unsigned getIndex (void) const { return index_; }
void setIndex (unsigned idx) { index_ = idx; }
virtual ~VarNode (void) { };
unsigned varId (void) const { return varId_; }
unsigned nrStates (void) const { return nrStates_; }
int getEvidence (void) const { return evidence_; }
unsigned getIndex (void) const { return index_; }
void setIndex (unsigned idx) { index_ = idx; }
operator unsigned () const { return index_; }
bool hasEvidence (void) const
{
return evidence_ != Constants::NO_EVIDENCE;
}
bool operator== (const VarNode& var) const
{
cout << "equal operator called" << endl;
@ -42,11 +49,23 @@ class VarNode
return varId_ != var.varId();
}
bool isValidState (int);
bool isValidState (const string&);
void setEvidence (int);
void setEvidence (const string&);
string label (void) const;
States states (void) const;
private:
VarId varId_;
unsigned nrStates_;
int evidence_;
unsigned index_;
VarId varId_;
unsigned nrStates_;
int evidence_;
unsigned index_;
};

View File

@ -13,8 +13,8 @@ function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
@ -22,7 +22,7 @@ fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
clpbn_horus:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),

View File

@ -13,8 +13,8 @@ function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
@ -22,7 +22,7 @@ fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
clpbn_horus:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
@ -37,6 +37,8 @@ function run_all_graphs
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_3 $1 town_3 $3 $4 $5
return
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5

View File

@ -13,8 +13,8 @@ function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
@ -22,7 +22,7 @@ fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
clpbn_horus:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),

View File

@ -0,0 +1,29 @@
:- source.
:- style_check(all).
:- yap_flag(unknown,error).
:- yap_flag(write_strings,on).
:- use_module(library(clpbn)).
:- set_clpbn_flag(solver, bp).
:- [-schema].
lives(_joe, nyc).
run_query(Guilty) :-
guilty(joe, Guilty),
witness(nyc, t),
runall(X, ev(X)).
runall(G, Wrapper) :-
findall(G, Wrapper, L),
execute_all(L).
execute_all([]).
execute_all(G.L) :-
call(G),
execute_all(L).
ev(descn(p2, t)).
ev(descn(p3, t)).

View File

@ -9,7 +9,7 @@ OUT_FILE_NAME=results.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
# yap -g "['../../../../examples/School/school_32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
function run_solver
{
@ -17,11 +17,11 @@ if [ $2 = bp ]
then
if [ $4 = ve ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(elim_heuristic,$5\)
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_horus:set_horus_flag\(elim_heuristic,$5\)
else
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
fi
else
extra_flag1=true
@ -29,7 +29,7 @@ else
fi
/usr/bin/time -o "$OUT_FILE_NAME" -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF &>> "ignore.$OUT_FILE_NAME"
:- [pos:train].
:- ['../../../../examples/School/school_32'].
:- ['../../../../examples/School/sch32'].
:- use_module(library(clpbn/learning/em)).
:- use_module(library(clpbn/bp)).
[$1].
@ -57,12 +57,11 @@ function run_all_graphs
#run_solver missing50 $1 missing50 $3 $4 $5
}
run_solver missing5 ve missing5 $3 $4 $5
exit
run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
#run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
#run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
exit

View File

@ -0,0 +1,18 @@
:- use_module(library(pfl)).
:- set_clpbn_flag(solver,fove).
c(x1,y1,z1).
c(x1,y1,z2).
c(x2,y2,z1).
c(x3,y2,z1).
bayes p(X)::[t,f] ; [0.2, 0.4] ; [c(X,_,_)].
bayes q(Y)::[t,f] ; [0.5, 0.6] ; [c(_,Y,_)].
bayes s(Z)::[t,f] , p(X) , q(Y) ; [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] ; [c(X,Y,Z)].
% bayes series::[t,f] , attends(X) ; [0.5, 0.6, 0.7, 0.8] ; [c(X,_)].

File diff suppressed because it is too large Load Diff

View File

@ -25,6 +25,8 @@ markov attends(P)::[t,f] , hot(W)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P,W)].
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
:- clpbn_horus:set_horus_flag(use_logarithms,true).
?- series(X).

View File

@ -0,0 +1,20 @@
:- use_module(library(pfl)).
:- set_clpbn_flag(solver,ve).
%:- set_clpbn_flag(solver,fove).
t(ann).
t(dave).
% p(ann,t).
bayes p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
% use standard Prolog queries: provide evidence first.
?- p(ann,t), p(ann,X).
% ?- p(ann,X).

View File

@ -0,0 +1,32 @@
:- use_module(library(pfl)).
%:- set_clpbn_flag(solver,ve).
%:- set_clpbn_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
:- set_clpbn_flag(solver,fove).
c(p1).
c(p2).
c(p3).
c(p4).
c(p5).
markov attends(P)::[t,f] , attr1::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f] , attr2::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f] , attr3::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f] , attr4::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f] , attr5::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f] , attr6::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P)].
%:- clpbn_horus:set_horus_flag(use_logarithms,true).
?- series(X).

View File

@ -1,9 +1,9 @@
/************************************************
/*******************************************************
(GC) First Order Variable Elimination Interface
First Order Variable Elimination Interface
**************************************************/
********************************************************/
:- module(clpbn_fove,
[fove/3,
@ -23,75 +23,88 @@
:- use_module(library(pfl),
[factor/6,
skolem/2
[factor/5,
skolem/2,
get_pfl_parameters/2
]).
:- use_module(horus).
:- use_module(horus,
[create_lifted_network/3,
set_parfactors_params/2,
run_lifted_solver/3,
free_parfactors/1
]).
:- set_horus_flag(use_logarithms, false).
%:- set_horus_flag(use_logarithms, true).
fove([[]], _, _) :- !.
fove([QueryVars], AllVars, Output) :-
writeln(queryVars:QueryVars),
writeln(allVars:AllVars),
init_fove_solver(_, AllVars, _, ParfactorGraph),
run_fove_solver([QueryVars], LPs, ParfactorGraph),
finalize_fove_solver(ParfactorGraph),
clpbn_bind_vals([QueryVars], LPs, Output).
init_fove_solver(_, AllVars, _, ParfactorList),
run_fove_solver([QueryVars], LPs, ParfactorList),
finalize_fove_solver(ParfactorList),
clpbn_bind_vals([QueryVars], LPs, Output).
init_fove_solver(_, AllAttVars, _, fove(ParfactorGraph, DistIds)) :-
writeln(allattvars:AllAttVars), writeln(''),
get_parfactors(Parfactors),
get_dist_ids(Parfactors, DistIds0),
sort(DistIds0, DistIds),
get_observed_vars(AllAttVars, ObservedVars),
init_fove_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
get_parfactors(Parfactors),
get_dist_ids(Parfactors, DistIds0),
sort(DistIds0, DistIds),
get_observed_vars(AllAttVars, ObservedVars),
writeln(factors:Parfactors:'\n'),
writeln(evidence:ObservedVars:'\n'),
create_lifted_network(Parfactors,ObservedVars,ParfactorGraph).
writeln(evidence:ObservedVars:'\n'),
create_lifted_network(Parfactors,ObservedVars,ParfactorList).
:- table get_parfactors/1.
%
% enumerate all parfactors and enumerate their domain as tuples.
%
% output is list of pf(
% Id: an unique number
% Ks: a list of keys, also known as the pf formula [a(X),b(Y),c(X,Y)]
% Vs: the list of free variables [X,Y]
% Phi: the table following usual CLP(BN) convention
% Tuples: tuples with all ground bindings for variables in Vs, of the form [fv(x,y)]
% Id: an unique number
% Ks: a list of keys, also known as the pf formula [a(X),b(Y),c(X,Y)]
% Vs: the list of free variables [X,Y]
% Phi: the table following usual CLP(BN) convention
% Tuples: ground bindings for variables in Vs, of the form [fv(x,y)]
%
get_parfactors(Factors) :-
findall(F, is_factor(F), Factors).
findall(F, is_factor(F), Factors).
is_factor(pf(Id, Ks, Rs, Phi, Tuples)) :-
<<<<<<< HEAD
factor(_Type, Id, Ks, Vs, Table, Constraints),
get_ranges(Ks,Rs),
Table \= avg,
gen_table(Table, Phi),
all_tuples(Constraints, Vs, Tuples).
=======
factor(Id, Ks, Vs, Table, Constraints),
get_ranges(Ks,Rs),
Table \= avg,
gen_table(Table, Phi),
all_tuples(Constraints, Vs, Tuples).
>>>>>>> 911b241ad663a911af52babcf5d702c5239b4350
get_ranges([],[]).
get_ranges(K.Ks, Range.Rs) :- !,
skolem(K,Domain),
length(Domain,Range),
get_ranges(Ks, Rs).
skolem(K,Domain),
length(Domain,Range),
get_ranges(Ks, Rs).
gen_table(Table, Phi) :-
( is_list(Table)
->
Phi = Table
;
call(user:Table, Phi)
).
( is_list(Table)
->
Phi = Table
;
call(user:Table, Phi)
).
all_tuples(Constraints, Tuple, Tuples) :-
setof(Tuple, Constraints^run(Constraints), Tuples).
setof(Tuple, Constraints^run(Constraints), Tuples).
run([]).
@ -107,45 +120,41 @@ get_dist_ids(pf(Id, _, _, _, _).Parfactors, Id.DistIds) :-
get_observed_vars([], []).
get_observed_vars(V.AllAttVars, [K:E|ObservedVars]) :-
writeln('checking ev for':V),
clpbn:get_atts(V,[key(K)]),
( clpbn:get_atts(V,[evidence(E)]) ; pfl:evidence(K,E) ), !,
writeln('evidence!!!':K:E),
get_observed_vars(AllAttVars, ObservedVars).
clpbn:get_atts(V,[key(K)]),
( clpbn:get_atts(V,[evidence(E)]) ; pfl:evidence(K,E) ), !,
get_observed_vars(AllAttVars, ObservedVars).
get_observed_vars(V.AllAttVars, ObservedVars) :-
clpbn:get_atts(V,[key(K)]), !,
writeln('no evidence for':V:K),
get_observed_vars(AllAttVars, ObservedVars).
clpbn:get_atts(V,[key(K)]), !,
get_observed_vars(AllAttVars, ObservedVars).
get_query_vars([], []).
get_query_vars(E1.L1, E2.L2) :-
get_query_vars_2(E1,E2),
get_query_vars(L1, L2).
get_query_vars(L1, L2).
get_query_vars_2([], []).
get_query_vars_2(V.AttVars, [RV|RVs]) :-
clpbn:get_atts(V,[key(RV)]), !,
get_query_vars_2(AttVars, RVs).
clpbn:get_atts(V,[key(RV)]), !,
get_query_vars_2(AttVars, RVs).
get_dists_parameters([], []).
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
get_pfl_parameters(Id, Params),
get_dists_parameters(Ids, DistsInfo).
get_pfl_parameters(Id, Params),
get_dists_parameters(Ids, DistsInfo).
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorGraph, DistIds)) :-
% TODO set_parfactor_graph_params
writeln(distIds:DistIds),
%get_dists_parameters(DistIds, DistParams),
%writeln(distParams:DistParams),
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
get_dists_parameters(DistIds, DistsParams),
writeln(distParams:DistsParams),
set_parfactors_params(ParfactorList, DistsParams),
get_query_vars(QueryVarsAtts, QueryVars),
writeln(queryVars:QueryVars),
run_lifted_solver(ParfactorGraph, QueryVars, Solutions).
writeln(queryVars:QueryVars), writeln(''),
run_lifted_solver(ParfactorList, QueryVars, Solutions).
finalize_fove_solver(fove(ParfactorGraph, _)) :-
free_parfactor_graph(ParfactorGraph).
finalize_fove_solver(fove(ParfactorList, _)) :-
free_parfactors(ParfactorList).

View File

@ -1,17 +1,23 @@
/*******************************************************
Interface with C++
********************************************************/
:- module(clpbn_horus,
[
create_lifted_network/3,
create_ground_network/2,
set_parfactor_graph_params/2,
set_bayes_net_params/2,
run_lifted_solver/3,
run_other_solvers/3,
set_extra_vars_info/2,
set_horus_flag/2,
free_bayesian_network/1,
free_parfactor_graph/1
]).
[create_lifted_network/3,
create_ground_network/2,
set_parfactors_params/2,
set_bayes_net_params/2,
run_lifted_solver/3,
run_ground_solver/3,
set_extra_vars_info/2,
set_horus_flag/2,
free_parfactors/1,
free_bayesian_network/1
]).
patch_things_up :-
assert_static(clpbn_horus:set_horus_flag(_,_)).
@ -23,3 +29,24 @@ warning :-
%:- set_horus_flag(inf_alg, ve).
:- set_horus_flag(inf_alg, bn_bp).
%:- set_horus_flag(inf_alg, fg_bp).
%: -set_horus_flag(inf_alg, cbp).
:- set_horus_flag(schedule, seq_fixed).
%:- set_horus_flag(schedule, seq_random).
%:- set_horus_flag(schedule, parallel).
%:- set_horus_flag(schedule, max_residual).
:- set_horus_flag(accuracy, 0.0001).
:- set_horus_flag(max_iter, 1000).
:- set_horus_flag(order_factor_variables, false).
%:- set_horus_flag(order_factor_variables, true).
:- set_horus_flag(use_logarithms, false).
% :- set_horus_flag(use_logarithms, true).