This commit is contained in:
Vítor Santos Costa 2012-04-12 17:07:19 +01:00
commit 28ce2da3dc
91 changed files with 2208 additions and 1351364 deletions

View File

@ -31,54 +31,62 @@
[check_for_agg_vars/2]).
:- use_module(library(clpbn/horus)).
:- use_module(library(charsio),
[term_to_atom/2]).
:- use_module(library(pfl),
[skolem/2,
get_pfl_parameters/2
]).
:- use_module(library(lists)).
:- use_module(library(atts)).
:- attribute id/1.
%:- set_horus_flag(inf_alg, ve).
:- set_horus_flag(inf_alg, bn_bp).
%:- set_horus_flag(inf_alg, fg_bp).
%: -set_horus_flag(inf_alg, cbp).
:- set_horus_flag(schedule, seq_fixed).
%:- set_horus_flag(schedule, seq_random).
%:- set_horus_flag(schedule, parallel).
%:- set_horus_flag(schedule, max_residual).
:- set_horus_flag(accuracy, 0.0001).
:- use_module(library(charsio),
[term_to_atom/2]).
:- use_module(library(bhash)).
:- use_module(horus,
[create_ground_network/2,
set_bayes_net_params/2,
[create_ground_network/4,
set_factors_params/2,
run_ground_solver/3,
set_extra_vars_info/2,
free_bayesian_network/1
set_vars_information/2,
free_ground_network/1
]).
:- attribute id/1.
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Solutions) :-
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Output) :-
b_hash_new(Hash0),
keys_to_ids(AllKeys, 0, Hash0, Hash),
InvMap =.. [view|AllKeys],
list_of_keys_to_ids(QueryKeys, Hash, QueryVarsIds),
evidence_to_ids(Evidence, Hash, EvIds, EvIdNames),
get_factors_type(Factors, Type),
evidence_to_ids(Evidence, Hash, EvidenceIds),
factors_to_ids(Factors, Hash, FactorIds),
init_graphical_model(FactorIds, Network, InvMap, EvIdNames),
run_ground_solver(Network, QueryVarsIds, EvIds, Solutions),
free_graphical_model(Network).
writeln(type:Type), writeln(''),
writeln(allKeys:AllKeys), writeln(''),
writeln(factors:Factors), writeln(''),
writeln(factorIds:FactorIds), writeln(''),
writeln(evidence:Evidence), writeln(''),
writeln(evidenceIds:EvidenceIds), writeln(''),
create_ground_network(Type, FactorIds, EvidenceIds, Network),
%get_vars_information(AllKeys, StatesNames),
%set_vars_information(AllKeys, StatesNames),
run_solver(ground(Network,Hash), QueryKeys, Solutions),
writeln(answer:Solutions),
%clpbn_bind_vals([QueryKeys], Solutions, Output).
free_ground_network(Network).
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
%get_dists_parameters(DistIds, DistsParams),
%set_factors_params(Network, DistsParams),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
run_ground_solver(Network, [QueryIds], Solutions).
keys_to_ids([], _, Hash, Hash).
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
@ -86,27 +94,48 @@ keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
I is I0+1,
keys_to_ids(AllKeys, I, HashI, Hash).
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
get_factors_type([f(markov, _, _)|_], markov) :- ! .
list_of_keys_to_ids([], _, []).
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
evidence_to_ids([], _, [], []).
evidence_to_ids([Key=V|QueryKeys], Hash, [Id=V|QueryIds], [Id=Name|QueryNames]) :-
b_hash_lookup(Key, Id, Hash),
pfl:skolem(Key,Dom),
nth0(V, Dom, Name),
evidence_to_ids(QueryKeys, Hash, QueryIds, QueryNames).
factors_to_ids([], _, []).
factors_to_ids([f(markov, Keys, CPT)|Fs], Hash, [markov(Ids, CPT)|NFs]) :-
list_of_keys_to_ids(Keys, Hash, Ids),
factors_to_ids(Fs, Hash, NFs).
factors_to_ids([f(bayes, Keys, CPT)|Fs], Hash, [bayes(Ids, CPT)|NFs]) :-
factors_to_ids([f(_, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
list_of_keys_to_ids(Keys, Hash, Ids),
DistId = 0,
get_ranges(Keys, Ranges),
factors_to_ids(Fs, Hash, NFs).
get_ranges([],[]).
get_ranges(K.Ks, Range.Rs) :- !,
skolem(K,Domain),
length(Domain,Range),
get_ranges(Ks, Rs).
evidence_to_ids([], _, []).
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
evidence_to_ids(QueryKeys, Hash, QueryIds).
get_vars_information([], []).
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
pfl:skolem(Key, Domain),
get_vars_information(QueryKeys, StatesNames).
finalize_bp_solver(bp(Network, _)) :-
free_ground_network(Network).
bp([[]],_,_) :- !.
bp([QueryVars], AllVars, Output) :-
init_bp_solver(_, AllVars, _, Network),
@ -116,102 +145,22 @@ bp([QueryVars], AllVars, Output) :-
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
%writeln('init_bp_solver'),
check_for_agg_vars(AllVars0, AllVars),
%writeln('clpbn_vars:'), print_clpbn_vars(AllVars),
assign_ids(AllVars, 0),
%check_for_agg_vars(AllVars0, AllVars),
get_vars_info(AllVars, VarsInfo, DistIds0),
sort(DistIds0, DistIds),
create_ground_network(VarsInfo, BayesNet),
%get_extra_vars_info(AllVars, ExtraVarsInfo),
%set_extra_vars_info(BayesNet, ExtraVarsInfo),
%writeln(extravarsinfo:ExtraVarsInfo),
true.
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
%writeln('-> run_bp_solver'),
get_dists_parameters(DistIds, DistsParams),
set_bayes_net_params(Network, DistsParams),
set_factors_params(Network, DistsParams),
vars_to_ids(QueryVars, QueryVarsIds),
run_ground_solver(Network, QueryVarsIds, Solutions).
finalize_bp_solver(bp(Network, _)) :-
free_bayesian_network(Network).
assign_ids([], _).
assign_ids([V|Vs], Count) :-
put_atts(V, [id(Count)]),
Count1 is Count + 1,
assign_ids(Vs, Count1).
get_vars_info([], [], []).
get_vars_info(V.Vs,
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
DistId.DistIds) :-
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
get_atts(V, [id(VarId)]),
get_dist_domain_size(DistId, DS),
get_evidence(V, Ev),
vars_to_ids(Parents, PIds),
get_vars_info(Vs, VarsInfo, DistIds).
get_evidence(V, Ev) :-
clpbn:get_atts(V, [evidence(Ev)]), !.
get_evidence(_V, -1). % no evidence !!!
vars_to_ids([], []).
vars_to_ids([L|Vars], [LIds|Ids]) :-
is_list(L), !,
vars_to_ids(L, LIds),
vars_to_ids(Vars, Ids).
vars_to_ids([V|Vars], [VarId|Ids]) :-
get_atts(V, [id(VarId)]),
vars_to_ids(Vars, Ids).
get_extra_vars_info([], []).
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
get_atts(V, [id(VarId)]), !,
clpbn:get_atts(V, [key(Key), dist(DistId, _)]),
term_to_atom(Key, Label),
get_dist_domain(DistId, Domain0),
numbers_to_atoms(Domain0, Domain),
get_extra_vars_info(Vs, VarsInfo).
get_extra_vars_info([_|Vs], VarsInfo) :-
get_extra_vars_info(Vs, VarsInfo).
get_dists_parameters([],[]).
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
get_dist_params(Id, Params),
get_dists_parameters(Ids, DistsInfo).
numbers_to_atoms([], []).
numbers_to_atoms([Atom|L0], [Atom|L]) :-
atom(Atom), !,
numbers_to_atoms(L0, L).
numbers_to_atoms([Number|L0], [Atom|L]) :-
number_atom(Number, Atom),
numbers_to_atoms(L0, L).
print_clpbn_vars(Var.AllVars) :-
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
parents_to_keys(Parents, ParentKeys),
writeln(Var:Key:ParentKeys:DistId),
print_clpbn_vars(AllVars).
print_clpbn_vars([]).
parents_to_keys([], []).
parents_to_keys(Var.Parents, Key.Keys) :-
clpbn:get_atts(Var, [key(Key)]),
parents_to_keys(Parents, Keys).

View File

@ -0,0 +1,77 @@
#include <cstdlib>
#include <cassert>
#include <iostream>
#include <fstream>
#include <sstream>
#include "BayesBall.h"
#include "Util.h"
FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{
assert (fg_.isFromBayesNetwork());
Scheduling scheduling;
for (unsigned i = 0; i < queryIds.size(); i++) {
assert (dag_.getNode (queryIds[i]));
DAGraphNode* n = dag_.getNode (queryIds[i]);
scheduling.push (ScheduleInfo (n, false, true));
}
while (!scheduling.empty()) {
ScheduleInfo& sch = scheduling.front();
DAGraphNode* n = sch.node;
n->setAsVisited();
if (n->hasEvidence() == false && sch.visitedFromChild) {
if (n->isMarkedOnTop() == false) {
n->markOnTop();
scheduleParents (n, scheduling);
}
if (n->isMarkedOnBottom() == false) {
n->markOnBottom();
scheduleChilds (n, scheduling);
}
}
if (sch.visitedFromParent) {
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
n->markOnTop();
scheduleParents (n, scheduling);
}
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
n->markOnBottom();
scheduleChilds (n, scheduling);
}
}
scheduling.pop();
}
FactorGraph* fg = new FactorGraph();
constructGraph (fg);
return fg;
}
void
BayesBall::constructGraph (FactorGraph* fg) const
{
const FacNodes& facNodes = fg_.facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
const DAGraphNode* n = dag_.getNode (
facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) {
fg->addFactor (Factor (facNodes[i]->factor()));
} else if (n->hasEvidence() && n->isVisited()) {
VarIds varIds = { facNodes[i]->factor().argument (0) };
Ranges ranges = { facNodes[i]->factor().range (0) };
Params params (ranges[0], LogAware::noEvidence());
params[n->getEvidence()] = LogAware::withEvidence();
fg->addFactor (Factor (varIds, ranges, params));
}
}
}

View File

@ -0,0 +1,85 @@
#ifndef HORUS_BAYESBALL_H
#define HORUS_BAYESBALL_H
#include <vector>
#include <queue>
#include <list>
#include <map>
#include "FactorGraph.h"
#include "BayesNet.h"
#include "Horus.h"
using namespace std;
struct ScheduleInfo
{
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
DAGraphNode* node;
bool visitedFromParent;
bool visitedFromChild;
};
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
class BayesBall
{
public:
BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
FactorGraph* getMinimalFactorGraph (const VarIds&);
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
private:
void constructGraph (FactorGraph* fg) const;
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
FactorGraph& fg_;
DAGraph& dag_;
};
inline void
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
{
const vector<DAGraphNode*>& ps = n->parents();
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
it != ps.end(); it++) {
sch.push (ScheduleInfo (*it, false, true));
}
}
inline void
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
{
const vector<DAGraphNode*>& cs = n->childs();
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
it != cs.end(); it++) {
sch.push (ScheduleInfo (*it, true, false));
}
}
#endif // HORUS_BAYESBALL_H

View File

@ -5,354 +5,57 @@
#include <fstream>
#include <sstream>
#include "xmlParser/xmlParser.h"
#include "BayesNet.h"
#include "Util.h"
BayesNet::~BayesNet (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
delete nodes_[i];
}
}
void
BayesNet::readFromBifFormat (const char* fileName)
DAGraph::addNode (DAGraphNode* n)
{
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
// only the first network is parsed, others are ignored
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
unsigned nVars = xNode.nChildNode ("VARIABLE");
for (unsigned i = 0; i < nVars; i++) {
XMLNode var = xNode.getChildNode ("VARIABLE", i);
if (string (var.getAttribute ("TYPE")) != "nature") {
cerr << "error: only \"nature\" variables are supported" << endl;
abort();
}
States states;
string label = var.getChildNode("NAME").getText();
unsigned nrStates = var.nChildNode ("OUTCOME");
for (unsigned j = 0; j < nrStates; j++) {
if (var.getChildNode("OUTCOME", j).getText() == 0) {
stringstream ss;
ss << j + 1;
states.push_back (ss.str());
} else {
states.push_back (var.getChildNode("OUTCOME", j).getText());
}
}
addNode (label, states);
}
unsigned nDefs = xNode.nChildNode ("DEFINITION");
if (nVars != nDefs) {
cerr << "error: different number of variables and definitions" << endl;
abort();
}
for (unsigned i = 0; i < nDefs; i++) {
XMLNode def = xNode.getChildNode ("DEFINITION", i);
string label = def.getChildNode("FOR").getText();
BayesNode* node = getBayesNode (label);
if (!node) {
cerr << "error: unknow variable `" << label << "'" << endl;
abort();
}
BnNodeSet parents;
unsigned nParams = node->nrStates();
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
string parentLabel = def.getChildNode("GIVEN", j).getText();
BayesNode* parentNode = getBayesNode (parentLabel);
if (!parentNode) {
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
abort();
}
nParams *= parentNode->nrStates();
parents.push_back (parentNode);
}
node->setParents (parents);
unsigned count = 0;
Params params (nParams);
stringstream s (def.getChildNode("TABLE").getText());
while (!s.eof() && count < nParams) {
s >> params[count];
count ++;
}
if (count != nParams) {
cerr << "error: invalid number of parameters " ;
cerr << "for variable `" << label << "'" << endl;
abort();
}
params = reorderParameters (params, node->nrStates());
if (Globals::logDomain) {
Util::toLog (params);
}
node->setParams (params);
}
setIndexes();
}
BayesNode*
BayesNet::addNode (BayesNode* n)
{
varMap_.insert (make_pair (n->varId(), nodes_.size()));
assert (Util::contains (varMap_, n->varId()) == false);
nodes_.push_back (n);
return nodes_.back();
}
BayesNode*
BayesNet::addNode (string label, const States& states)
{
VarId vid = nodes_.size();
varMap_.insert (make_pair (vid, nodes_.size()));
GraphicalModel::addVariableInformation (vid, label, states);
BayesNode* node = new BayesNode (VarNode (vid, states.size()));
nodes_.push_back (node);
return node;
}
BayesNode*
BayesNet::getBayesNode (VarId vid) const
{
IndexMap::const_iterator it = varMap_.find (vid);
if (it == varMap_.end()) {
return 0;
} else {
return nodes_[it->second];
}
}
BayesNode*
BayesNet::getBayesNode (string label) const
{
BayesNode* node = 0;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->label() == label) {
node = nodes_[i];
break;
}
}
return node;
}
VarNode*
BayesNet::getVariableNode (VarId vid) const
{
BayesNode* node = getBayesNode (vid);
assert (node);
return node;
}
VarNodes
BayesNet::getVariableNodes (void) const
{
VarNodes vars;
for (unsigned i = 0; i < nodes_.size(); i++) {
vars.push_back (nodes_[i]);
}
return vars;
}
const BnNodeSet&
BayesNet::getBayesNodes (void) const
{
return nodes_;
}
unsigned
BayesNet::nrNodes (void) const
{
return nodes_.size();
}
BnNodeSet
BayesNet::getRootNodes (void) const
{
BnNodeSet roots;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->isRoot()) {
roots.push_back (nodes_[i]);
}
}
return roots;
}
BnNodeSet
BayesNet::getLeafNodes (void) const
{
BnNodeSet leafs;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->isLeaf()) {
leafs.push_back (nodes_[i]);
}
}
return leafs;
}
BayesNet*
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
{
return getMinimalRequesiteNetwork (VarIds() = {vid});
}
BayesNet*
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
{
BnNodeSet queryVars;
Scheduling scheduling;
for (unsigned i = 0; i < queryVarIds.size(); i++) {
BayesNode* n = getBayesNode (queryVarIds[i]);
assert (n);
queryVars.push_back (n);
scheduling.push (ScheduleInfo (n, false, true));
}
vector<StateInfo*> states (nodes_.size(), 0);
while (!scheduling.empty()) {
ScheduleInfo& sch = scheduling.front();
StateInfo* state = states[sch.node->getIndex()];
if (!state) {
state = new StateInfo();
states[sch.node->getIndex()] = state;
} else {
state->visited = true;
}
if (!sch.node->hasEvidence() && sch.visitedFromChild) {
if (!state->markedOnTop) {
state->markedOnTop = true;
scheduleParents (sch.node, scheduling);
}
if (!state->markedOnBottom) {
state->markedOnBottom = true;
scheduleChilds (sch.node, scheduling);
}
}
if (sch.visitedFromParent) {
if (sch.node->hasEvidence() && !state->markedOnTop) {
state->markedOnTop = true;
scheduleParents (sch.node, scheduling);
}
if (!sch.node->hasEvidence() && !state->markedOnBottom) {
state->markedOnBottom = true;
scheduleChilds (sch.node, scheduling);
}
}
scheduling.pop();
}
/*
cout << "\t\ttop\tbottom" << endl;
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
Util::printDashedLine();
cout << endl;
for (unsigned i = 0; i < states.size(); i++) {
cout << nodes_[i]->label() << ":\t\t" ;
if (states[i]) {
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
states[i]->visited ? cout << "yes\t" : cout << "no\t" ;
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
cout << endl;
} else {
cout << "no\tno\tno\t" ;
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
cout << endl;
}
}
cout << endl;
*/
BayesNet* bn = new BayesNet();
constructGraph (bn, states);
for (unsigned i = 0; i < nodes_.size(); i++) {
delete states[i];
}
return bn;
varMap_[n->varId()] = n;
}
void
BayesNet::constructGraph (BayesNet* bn,
const vector<StateInfo*>& states) const
DAGraph::addEdge (VarId vid1, VarId vid2)
{
BnNodeSet mrnNodes;
vector<VarIds> parents;
for (unsigned i = 0; i < nodes_.size(); i++) {
bool isRequired = false;
if (states[i]) {
isRequired = (nodes_[i]->hasEvidence() && states[i]->visited)
||
states[i]->markedOnTop;
}
if (isRequired) {
parents.push_back (VarIds());
if (states[i]->markedOnTop) {
const BnNodeSet& ps = nodes_[i]->getParents();
for (unsigned j = 0; j < ps.size(); j++) {
parents.back().push_back (ps[j]->varId());
}
}
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
BayesNode* mrnNode = new BayesNode (nodes_[i]);
bn->addNode (mrnNode);
mrnNodes.push_back (mrnNode);
}
}
for (unsigned i = 0; i < mrnNodes.size(); i++) {
BnNodeSet ps;
for (unsigned j = 0; j < parents[i].size(); j++) {
assert (bn->getBayesNode (parents[i][j]) != 0);
ps.push_back (bn->getBayesNode (parents[i][j]));
}
mrnNodes[i]->setParents (ps);
}
bn->setIndexes();
unordered_map<VarId, DAGraphNode*>::iterator it1;
unordered_map<VarId, DAGraphNode*>::iterator it2;
it1 = varMap_.find (vid1);
it2 = varMap_.find (vid2);
assert (it1 != varMap_.end());
assert (it2 != varMap_.end());
it1->second->addChild (it2->second);
it2->second->addParent (it1->second);
}
bool
BayesNet::isPolyTree (void) const
const DAGraphNode*
DAGraph::getNode (VarId vid) const
{
return !containsUndirectedCycle();
unordered_map<VarId, DAGraphNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
DAGraphNode*
DAGraph::getNode (VarId vid)
{
unordered_map<VarId, DAGraphNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
void
BayesNet::setIndexes (void)
DAGraph::setIndexes (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
@ -362,213 +65,43 @@ BayesNet::setIndexes (void)
void
BayesNet::printGraphicalModel (void) const
DAGraph::clear (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << *nodes_[i];
nodes_[i]->clear();
}
}
void
BayesNet::exportToGraphViz (const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
DAGraph::exportToGraphViz (const char* fileName)
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToDotFile()" << endl;
cerr << "DAGraph::exportToDotFile()" << endl;
abort();
}
out << "digraph {" << endl;
out << "ranksep=1" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->hasNeighbors()) {
out << nodes_[i]->varId() ;
if (nodes_[i]->hasEvidence()) {
out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"," ;
out << "style=filled, fillcolor=yellow" ;
out << "]" ;
} else {
out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"" ;
out << "]" ;
if (nodes_[i]->hasEvidence()) {
out << ",style=filled, fillcolor=yellow" ;
}
out << endl;
out << "]" << endl;
}
}
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
BayesNode* node = getBayesNode (highlightVarIds[i]);
if (node) {
out << node->varId() ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
const BnNodeSet& childs = nodes_[i]->getChilds();
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) {
out << nodes_[i]->varId() << " -> " << childs[j]->varId() << " [style=bold]" << endl ;
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
out << " [style=bold]" << endl ;
}
}
out << "}" << endl;
out.close();
}
void
BayesNet::exportToBifFormat (const char* fileName) const
{
ofstream out (fileName);
if(!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToBifFile()" << endl;
abort();
}
out << "<?xml version=\"1.0\" encoding=\"US-ASCII\"?>" << endl;
out << "<BIF VERSION=\"0.3\">" << endl;
out << "<NETWORK>" << endl;
out << "<NAME>" << fileName << "</NAME>" << endl << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
out << "<VARIABLE TYPE=\"nature\">" << endl;
out << "\t<NAME>" << nodes_[i]->label() << "</NAME>" << endl;
const States& states = nodes_[i]->states();
for (unsigned j = 0; j < states.size(); j++) {
out << "\t<OUTCOME>" << states[j] << "</OUTCOME>" << endl;
}
out << "</VARIABLE>" << endl << endl;
}
for (unsigned i = 0; i < nodes_.size(); i++) {
out << "<DEFINITION>" << endl;
out << "\t<FOR>" << nodes_[i]->label() << "</FOR>" << endl;
const BnNodeSet& parents = nodes_[i]->getParents();
for (unsigned j = 0; j < parents.size(); j++) {
out << "\t<GIVEN>" << parents[j]->label();
out << "</GIVEN>" << endl;
}
Params params = revertParameterReorder (
nodes_[i]->params(), nodes_[i]->nrStates());
out << "\t<TABLE>" ;
for (unsigned j = 0; j < params.size(); j++) {
out << " " << params[j];
}
out << " </TABLE>" << endl;
out << "</DEFINITION>" << endl << endl;
}
out << "</NETWORK>" << endl;
out << "</BIF>" << endl << endl;
out.close();
}
bool
BayesNet::containsUndirectedCycle (void) const
{
vector<bool> visited (nodes_.size(), false);
for (unsigned i = 0; i < nodes_.size(); i++) {
int v = nodes_[i]->getIndex();
if (!visited[v]) {
if (containsUndirectedCycle (v, -1, visited)) {
return true;
}
}
}
return false;
}
bool
BayesNet::containsUndirectedCycle (int v, int p, vector<bool>& visited) const
{
visited[v] = true;
vector<int> adjacencies = getAdjacentNodes (v);
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i];
if (!visited[w]) {
if (containsUndirectedCycle (w, v, visited)) {
return true;
}
}
else if (visited[w] && w != p) {
return true;
}
}
return false; // no cycle detected in this component
}
vector<int>
BayesNet::getAdjacentNodes (int v) const
{
vector<int> adjacencies;
const BnNodeSet& parents = nodes_[v]->getParents();
const BnNodeSet& childs = nodes_[v]->getChilds();
for (unsigned i = 0; i < parents.size(); i++) {
adjacencies.push_back (parents[i]->getIndex());
}
for (unsigned i = 0; i < childs.size(); i++) {
adjacencies.push_back (childs[i]->getIndex());
}
return adjacencies;
}
Params
BayesNet::reorderParameters (const Params& params, unsigned dsize) const
{
// the interchange format for bayesian networks keeps the probabilities
// in the following order:
// p(a1|b1,c1) p(a2|b1,c1) p(a1|b1,c2) p(a2|b1,c2) p(a1|b2,c1) p(a2|b2,c1)
// p(a1|b2,c2) p(a2|b2,c2).
//
// however, in clpbn we keep the probabilities in this order:
// p(a1|b1,c1) p(a1|b1,c2) p(a1|b2,c1) p(a1|b2,c2) p(a2|b1,c1) p(a2|b1,c2)
// p(a2|b2,c1) p(a2|b2,c2).
unsigned count = 0;
unsigned rowSize = params.size() / dsize;
Params reordered;
while (reordered.size() < params.size()) {
unsigned idx = count;
for (unsigned i = 0; i < rowSize; i++) {
reordered.push_back (params[idx]);
idx += dsize ;
}
count++;
}
return reordered;
}
Params
BayesNet::revertParameterReorder (const Params& params, unsigned dsize) const
{
unsigned count = 0;
unsigned rowSize = params.size() / dsize;
Params reordered;
while (reordered.size() < params.size()) {
unsigned idx = count;
for (unsigned i = 0; i < dsize; i++) {
reordered.push_back (params[idx]);
idx += rowSize;
}
count ++;
}
return reordered;
}

View File

@ -6,127 +6,83 @@
#include <list>
#include <map>
#include "GraphicalModel.h"
#include "BayesNode.h"
#include "Var.h"
#include "Horus.h"
using namespace std;
struct ScheduleInfo
{
ScheduleInfo (BayesNode* n, bool vfp, bool vfc) :
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
BayesNode* node;
bool visitedFromParent;
bool visitedFromChild;
};
class Var;
struct StateInfo
{
StateInfo (void) : visited(false), markedOnTop(false),
markedOnBottom(false) { }
bool visited;
bool markedOnTop;
bool markedOnBottom;
};
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
class BayesNet : public GraphicalModel
class DAGraphNode : public Var
{
public:
BayesNet (void) { };
DAGraphNode (Var* v) : Var (v) , visited_(false),
markedOnTop_(false), markedOnBottom_(false) { }
~BayesNet (void);
const vector<DAGraphNode*>& childs (void) const { return childs_; }
void readFromBifFormat (const char*);
vector<DAGraphNode*>& childs (void) { return childs_; }
BayesNode* addNode (BayesNode*);
const vector<DAGraphNode*>& parents (void) const { return parents_; }
BayesNode* addNode (string, const States&);
vector<DAGraphNode*>& parents (void) { return parents_; }
BayesNode* getBayesNode (VarId) const;
void addParent (DAGraphNode* p) { parents_.push_back (p); }
BayesNode* getBayesNode (string) const;
void addChild (DAGraphNode* c) { childs_.push_back (c); }
VarNode* getVariableNode (VarId) const;
bool isVisited (void) const { return visited_; }
VarNodes getVariableNodes (void) const;
void setAsVisited (void) { visited_ = true; }
const BnNodeSet& getBayesNodes (void) const;
bool isMarkedOnTop (void) const { return markedOnTop_; }
unsigned nrNodes (void) const;
void markOnTop (void) { markedOnTop_ = true; }
BnNodeSet getRootNodes (void) const;
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
BnNodeSet getLeafNodes (void) const;
void markOnBottom (void) { markedOnBottom_ = true; }
BayesNet* getMinimalRequesiteNetwork (VarId) const;
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
private:
bool visited_;
bool markedOnTop_;
bool markedOnBottom_;
void constructGraph (BayesNet*, const vector<StateInfo*>&) const;
vector<DAGraphNode*> childs_;
vector<DAGraphNode*> parents_;
};
bool isPolyTree (void) const;
class DAGraph
{
public:
DAGraph (void) { }
void addNode (DAGraphNode* n);
void addEdge (VarId vid1, VarId vid2);
const DAGraphNode* getNode (VarId vid) const;
DAGraphNode* getNode (VarId vid);
bool empty (void) const { return nodes_.empty(); }
void setIndexes (void);
void printGraphicalModel (void) const;
void clear (void);
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
void exportToBifFormat (const char*) const;
void exportToGraphViz (const char*);
private:
DISALLOW_COPY_AND_ASSIGN (BayesNet);
vector<DAGraphNode*> nodes_;
bool containsUndirectedCycle (void) const;
bool containsUndirectedCycle (int, int, vector<bool>&)const;
vector<int> getAdjacentNodes (int) const;
Params reorderParameters (const Params&, unsigned) const;
Params revertParameterReorder (const Params&, unsigned) const;
void scheduleParents (const BayesNode*, Scheduling&) const;
void scheduleChilds (const BayesNode*, Scheduling&) const;
BnNodeSet nodes_;
typedef unordered_map<unsigned, unsigned> IndexMap;
IndexMap varMap_;
unordered_map<VarId, DAGraphNode*> varMap_;
};
inline void
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
{
const BnNodeSet& ps = n->getParents();
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
sch.push (ScheduleInfo (*it, false, true));
}
}
inline void
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
{
const BnNodeSet& cs = n->getChilds();
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
sch.push (ScheduleInfo (*it, true, false));
}
}
#endif // HORUS_BAYESNET_H

View File

@ -1,247 +0,0 @@
#include <cstdlib>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <sstream>
#include "BayesNode.h"
void
BayesNode::setParams (const Params& params)
{
params_ = params;
}
void
BayesNode::setParents (const BnNodeSet& parents)
{
parents_ = parents;
for (unsigned int i = 0; i < parents.size(); i++) {
parents[i]->addChild (this);
}
}
void
BayesNode::addChild (BayesNode* node)
{
childs_.push_back (node);
}
Params
BayesNode::getRow (int rowIndex) const
{
int rowSize = getRowSize();
int offset = rowSize * rowIndex;
Params row (rowSize);
for (int i = 0; i < rowSize; i++) {
row[i] = params_[offset + i] ;
}
return row;
}
bool
BayesNode::isRoot (void)
{
return getParents().empty();
}
bool
BayesNode::isLeaf (void)
{
return getChilds().empty();
}
bool
BayesNode::hasNeighbors (void) const
{
return childs_.size() != 0 || parents_.size() != 0;
}
int
BayesNode::getCptSize (void)
{
return params_.size();
}
int
BayesNode::indexOfParent (const BayesNode* parent) const
{
for (unsigned int i = 0; i < parents_.size(); i++) {
if (parents_[i] == parent) {
return i;
}
}
return -1;
}
string
BayesNode::cptEntryToString (
int row,
const vector<unsigned>& stateConf) const
{
stringstream ss;
ss << "p(" ;
ss << states()[row];
if (parents_.size() > 0) {
ss << "|" ;
for (unsigned int i = 0; i < stateConf.size(); i++) {
if (i != 0) {
ss << ",";
}
ss << parents_[i]->states()[stateConf[i]];
}
}
ss << ")" ;
return ss.str();
}
vector<string>
BayesNode::getDomainHeaders (void) const
{
unsigned nParents = parents_.size();
unsigned rowSize = getRowSize();
unsigned nReps = 1;
vector<string> headers (rowSize);
for (int i = nParents - 1; i >= 0; i--) {
States states = parents_[i]->states();
unsigned index = 0;
while (index < rowSize) {
for (unsigned j = 0; j < parents_[i]->nrStates(); j++) {
for (unsigned r = 0; r < nReps; r++) {
if (headers[index] != "") {
headers[index] = states[j] + "," + headers[index];
} else {
headers[index] = states[j];
}
index++;
}
}
}
nReps *= parents_[i]->nrStates();
}
return headers;
}
ostream&
operator << (ostream& o, const BayesNode& node)
{
o << "variable " << node.getIndex() << endl;
o << "Var Id: " << node.varId() << endl;
o << "Label: " << node.label() << endl;
o << "Evidence: " ;
if (node.hasEvidence()) {
o << node.getEvidence();
}
else {
o << "no" ;
}
o << endl;
o << "Parents: " ;
const BnNodeSet& parents = node.getParents();
if (parents.size() != 0) {
for (unsigned int i = 0; i < parents.size() - 1; i++) {
o << parents[i]->label() << ", " ;
}
o << parents[parents.size() - 1]->label();
}
o << endl;
o << "Childs: " ;
const BnNodeSet& childs = node.getChilds();
if (childs.size() != 0) {
for (unsigned int i = 0; i < childs.size() - 1; i++) {
o << childs[i]->label() << ", " ;
}
o << childs[childs.size() - 1]->label();
}
o << endl;
o << "Domain: " ;
States states = node.states();
for (unsigned int i = 0; i < states.size() - 1; i++) {
o << states[i] << ", " ;
}
if (states.size() != 0) {
o << states[states.size() - 1];
}
o << endl;
// min width of first column
const unsigned int MIN_DOMAIN_WIDTH = 4;
// min width of following columns
const unsigned int MIN_COMBO_WIDTH = 12;
unsigned int domainWidth = states[0].length();
for (unsigned int i = 1; i < states.size(); i++) {
if (states[i].length() > domainWidth) {
domainWidth = states[i].length();
}
}
domainWidth = (domainWidth < MIN_DOMAIN_WIDTH)
? MIN_DOMAIN_WIDTH
: domainWidth;
o << left << setw (domainWidth) << "cpt" << right;
vector<int> widths;
int lineWidth = domainWidth;
vector<string> headers = node.getDomainHeaders();
if (!headers.empty()) {
for (unsigned int i = 0; i < headers.size(); i++) {
unsigned int len = headers[i].length();
int w = (len < MIN_COMBO_WIDTH) ? MIN_COMBO_WIDTH : len;
widths.push_back (w);
o << setw (w) << headers[i];
lineWidth += w;
}
o << endl;
} else {
cout << endl;
widths.push_back (domainWidth);
lineWidth += MIN_COMBO_WIDTH;
}
for (int i = 0; i < lineWidth; i++) {
o << "-" ;
}
o << endl;
for (unsigned int i = 0; i < states.size(); i++) {
Params row = node.getRow (i);
o << left << setw (domainWidth) << states[i] << right;
for (unsigned j = 0; j < node.getRowSize(); j++) {
o << setw (widths[j]) << row[j];
}
o << endl;
}
o << endl;
return o;
}

View File

@ -1,81 +0,0 @@
#ifndef HORUS_BAYESNODE_H
#define HORUS_BAYESNODE_H
#include <vector>
#include "VarNode.h"
#include "Horus.h"
using namespace std;
class BayesNode : public VarNode
{
public:
BayesNode (const VarNode& v) : VarNode (v) { }
BayesNode (const BayesNode* n) :
VarNode (n->varId(), n->nrStates(), n->getEvidence()),
params_(n->params()), distId_(n->distId()) { }
BayesNode (VarId vid, unsigned nrStates, int ev,
const Params& ps, unsigned id)
: VarNode (vid, nrStates, ev) , params_(ps), distId_(id) { }
const BnNodeSet& getParents (void) const { return parents_; }
const BnNodeSet& getChilds (void) const { return childs_; }
const Params& params (void) const { return params_; }
unsigned distId (void) const { return distId_; }
unsigned getRowSize (void) const
{
return params_.size() / nrStates();
}
double getProbability (int row, unsigned col)
{
int idx = (row * getRowSize()) + col;
return params_[idx];
}
void setParams (const Params& params);
void setParents (const BnNodeSet&);
void addChild (BayesNode*);
const Params& getParameters (void);
Params getRow (int) const;
bool isRoot (void);
bool isLeaf (void);
bool hasNeighbors (void) const;
int getCptSize (void);
int indexOfParent (const BayesNode*) const;
string cptEntryToString (int, const vector<unsigned>&) const;
friend ostream& operator << (ostream&, const BayesNode&);
private:
DISALLOW_COPY_AND_ASSIGN (BayesNode);
States getDomainHeaders (void) const;
BnNodeSet parents_;
BnNodeSet childs_;
Params params_;
unsigned distId_;
};
#endif // HORUS_BAYESNODE_H

View File

@ -1,790 +0,0 @@
#include <cstdlib>
#include <limits>
#include <time.h>
#include <algorithm>
#include <iostream>
#include <sstream>
#include <iomanip>
#include "BnBpSolver.h"
#include "Indexer.h"
BnBpSolver::BnBpSolver (const BayesNet& bn) : Solver (&bn)
{
bayesNet_ = &bn;
}
BnBpSolver::~BnBpSolver (void)
{
for (unsigned i = 0; i < nodesI_.size(); i++) {
delete nodesI_[i];
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
}
void
BnBpSolver::runSolver (void)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
initializeSolver();
runLoopySolver();
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = bayesNet_->nrNodes();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = bayesNet_->isPolyTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
}
Params
BnBpSolver::getPosterioriOf (VarId vid)
{
BayesNode* node = bayesNet_->getBayesNode (vid);
assert (node);
return nodesI_[node->getIndex()]->getBeliefs();
}
Params
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
if (Constants::DEBUG >= 2) {
cout << "calculating joint distribution on: " ;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
cout << var->label() << " " ;
}
cout << endl;
}
return getJointByConditioning (jointVarIds);
}
void
BnBpSolver::initializeSolver (void)
{
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
for (unsigned i = 0; i < nodesI_.size(); i++) {
delete nodesI_[i];
}
nodesI_.clear();
nodesI_.reserve (nodes.size());
links_.clear();
sortedOrder_.clear();
linkMap_.clear();
for (unsigned i = 0; i < nodes.size(); i++) {
nodesI_.push_back (new BpNodeInfo (nodes[i]));
}
BnNodeSet roots = bayesNet_->getRootNodes();
for (unsigned i = 0; i < roots.size(); i++) {
const Params& params = roots[i]->params();
Params& piVals = ninf(roots[i])->getPiValues();
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
piVals[ri] = params[ri];
}
}
for (unsigned i = 0; i < nodes.size(); i++) {
const BnNodeSet& parents = nodes[i]->getParents();
for (unsigned j = 0; j < parents.size(); j++) {
BpLink* newLink = new BpLink (
parents[j], nodes[i], LinkOrientation::DOWN);
links_.push_back (newLink);
ninf(nodes[i])->addIncomingParentLink (newLink);
ninf(parents[j])->addOutcomingChildLink (newLink);
}
const BnNodeSet& childs = nodes[i]->getChilds();
for (unsigned j = 0; j < childs.size(); j++) {
BpLink* newLink = new BpLink (
childs[j], nodes[i], LinkOrientation::UP);
links_.push_back (newLink);
ninf(nodes[i])->addIncomingChildLink (newLink);
ninf(childs[j])->addOutcomingParentLink (newLink);
}
}
for (unsigned i = 0; i < nodes.size(); i++) {
if (nodes[i]->hasEvidence()) {
Params& piVals = ninf(nodes[i])->getPiValues();
Params& ldVals = ninf(nodes[i])->getLambdaValues();
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
piVals[xi] = LogAware::noEvidence();
ldVals[xi] = LogAware::noEvidence();
}
piVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
ldVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
}
}
}
void
BnBpSolver::runLoopySolver()
{
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_++;
if (Constants::DEBUG >= 2) {
Util::printHeader ("Iteration " + nIters_);
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
updateValues (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage (links_[i]);
updateValues (links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
}
bool
BnBpSolver::converged (void) const
{
// this can happen if the graph is fully disconnected
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual < BpOptions::accuracy) {
converged = true;
} else {
converged = false;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual change = " ;
cout << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
break;
}
}
}
return converged;
}
void
BnBpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
}
return;
}
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
if (Constants::DEBUG >= 2) {
cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->getResidual() << endl;
}
}
SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it;
if (link->getResidual() < BpOptions::accuracy) {
sortedOrder_.erase (it);
it = sortedOrder_.begin();
return;
}
updateMessage (link);
updateValues (link);
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
const BpLinkSet& outParentLinks =
ninf(link->getDestination())->getOutcomingParentLinks();
for (unsigned i = 0; i < outParentLinks.size(); i++) {
if (outParentLinks[i]->getDestination() != link->getSource()
&& outParentLinks[i]->getDestination()->hasEvidence() == false) {
calculateMessage (outParentLinks[i]);
BpLinkMap::iterator iter = linkMap_.find (outParentLinks[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (outParentLinks[i]);
}
}
const BpLinkSet& outChildLinks =
ninf(link->getDestination())->getOutcomingChildLinks();
for (unsigned i = 0; i < outChildLinks.size(); i++) {
if (outChildLinks[i]->getDestination() != link->getSource()) {
calculateMessage (outChildLinks[i]);
BpLinkMap::iterator iter = linkMap_.find (outChildLinks[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (outChildLinks[i]);
}
}
if (Constants::DEBUG >= 2) {
Util::printDashedLine();
}
}
}
void
BnBpSolver::updatePiValues (BayesNode* x)
{
// π(Xi)
if (Constants::DEBUG >= 3) {
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
}
Params& piValues = ninf(x)->getPiValues();
const BpLinkSet& parentLinks = ninf(x)->getIncomingParentLinks();
const BnNodeSet& ps = x->getParents();
Ranges ranges;
for (unsigned i = 0; i < ps.size(); i++) {
ranges.push_back (ps[i]->nrStates());
}
StatesIndexer indexer (ranges, false);
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
Params messageProducts (indexer.size());
for (unsigned k = 0; k < indexer.size(); k++) {
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double messageProduct = LogAware::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < parentLinks.size(); i++) {
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
}
} else {
for (unsigned i = 0; i < parentLinks.size(); i++) {
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
if (Constants::DEBUG >= 5) {
if (i != 0) *calcs1 << " + " ;
if (i != 0) *calcs2 << " + " ;
*calcs1 << parentLinks[i]->toString (indexer[i]);
*calcs2 << parentLinks[i]->getMessage()[indexer[i]];
}
}
}
messageProducts[k] = messageProduct;
if (Constants::DEBUG >= 5) {
cout << " mp" << k;
cout << " = " << (*calcs1).str();
if (parentLinks.size() == 1) {
cout << " = " << messageProduct << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << messageProduct << endl;
}
delete calcs1;
delete calcs2;
}
++ indexer;
}
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
double sum = LogAware::addIdenty();
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
indexer.reset();
if (Globals::logDomain) {
for (unsigned k = 0; k < indexer.size(); k++) {
sum = Util::logSum (sum,
x->getProbability(xi, indexer) + messageProducts[k]);
++ indexer;
}
} else {
for (unsigned k = 0; k < indexer.size(); k++) {
sum += x->getProbability (xi, indexer) * messageProducts[k];
if (Constants::DEBUG >= 5) {
if (k != 0) *calcs1 << " + " ;
if (k != 0) *calcs2 << " + " ;
*calcs1 << x->cptEntryToString (xi, indexer.indices());
*calcs1 << ".mp" << k;
*calcs2 << LogAware::fl (x->getProbability (xi, indexer));
*calcs2 << "*" << messageProducts[k];
}
++ indexer;
}
}
piValues[xi] = sum;
if (Constants::DEBUG >= 5) {
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
cout << " = " << (*calcs2).str();
cout << " = " << piValues[xi] << endl;
delete calcs1;
delete calcs2;
}
}
}
void
BnBpSolver::updateLambdaValues (BayesNode* x)
{
// λ(Xi)
if (Constants::DEBUG >= 3) {
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
}
Params& lambdaValues = ninf(x)->getLambdaValues();
const BpLinkSet& childLinks = ninf(x)->getIncomingChildLinks();
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double product = LogAware::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < childLinks.size(); i++) {
product += childLinks[i]->getMessage()[xi];
}
} else {
for (unsigned i = 0; i < childLinks.size(); i++) {
product *= childLinks[i]->getMessage()[xi];
if (Constants::DEBUG >= 5) {
if (i != 0) *calcs1 << "." ;
if (i != 0) *calcs2 << "*" ;
*calcs1 << childLinks[i]->toString (xi);
*calcs2 << childLinks[i]->getMessage()[xi];
}
}
}
lambdaValues[xi] = product;
if (Constants::DEBUG >= 5) {
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
if (childLinks.size() == 1) {
cout << " = " << product << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << lambdaValues[xi] << endl;
}
delete calcs1;
delete calcs2;
}
}
}
void
BnBpSolver::calculatePiMessage (BpLink* link)
{
// πX(Zi)
BayesNode* z = link->getSource();
BayesNode* x = link->getDestination();
Params& zxPiNextMessage = link->getNextMessage();
const BpLinkSet& zChildLinks = ninf(z)->getIncomingChildLinks();
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
const Params& zPiValues = ninf(z)->getPiValues();
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
double product = zPiValues[zi];
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
*calcs1 << "[" << z->states()[zi] << "]" ;
*calcs2 << product;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < zChildLinks.size(); i++) {
if (zChildLinks[i]->getSource() != x) {
product += zChildLinks[i]->getMessage()[zi];
}
}
} else {
for (unsigned i = 0; i < zChildLinks.size(); i++) {
if (zChildLinks[i]->getSource() != x) {
product *= zChildLinks[i]->getMessage()[zi];
if (Constants::DEBUG >= 5) {
*calcs1 << "." << zChildLinks[i]->toString (zi);
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
}
}
}
}
zxPiNextMessage[zi] = product;
if (Constants::DEBUG >= 5) {
cout << " " << link->toString();
cout << "[" << z->states()[zi] << "]" ;
cout << " = " << (*calcs1).str();
if (zChildLinks.size() == 1) {
cout << " = " << product << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << product << endl;
}
delete calcs1;
delete calcs2;
}
}
LogAware::normalize (zxPiNextMessage);
}
void
BnBpSolver::calculateLambdaMessage (BpLink* link)
{
// λY(Xi)
BayesNode* y = link->getSource();
BayesNode* x = link->getDestination();
if (x->hasEvidence()) {
return;
}
Params& yxLambdaNextMessage = link->getNextMessage();
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
const Params& yLambdaValues = ninf(y)->getLambdaValues();
int parentIndex = y->indexOfParent (x);
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
const BnNodeSet& ps = y->getParents();
Ranges ranges;
for (unsigned i = 0; i < ps.size(); i++) {
ranges.push_back (ps[i]->nrStates());
}
StatesIndexer indexer (ranges, false);
unsigned N = indexer.size() / x->nrStates();
Params messageProducts (N);
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != 0) {
++ indexer;
}
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double messageProduct = LogAware::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < yParentLinks.size(); i++) {
if (yParentLinks[i]->getSource() != x) {
messageProduct += yParentLinks[i]->getMessage()[indexer[i]];
}
}
} else {
for (unsigned i = 0; i < yParentLinks.size(); i++) {
if (yParentLinks[i]->getSource() != x) {
if (Constants::DEBUG >= 5) {
if (messageProduct != LogAware::multIdenty()) *calcs1 << "*" ;
if (messageProduct != LogAware::multIdenty()) *calcs2 << "*" ;
*calcs1 << yParentLinks[i]->toString (indexer[i]);
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
}
messageProduct *= yParentLinks[i]->getMessage()[indexer[i]];
}
}
}
messageProducts[k] = messageProduct;
++ indexer;
if (Constants::DEBUG >= 5) {
cout << " mp" << k;
cout << " = " << (*calcs1).str();
if (yParentLinks.size() == 1) {
cout << 1 << endl;
} else if (yParentLinks.size() == 2) {
cout << " = " << messageProduct << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << messageProduct << endl;
}
delete calcs1;
delete calcs2;
}
}
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
if (Constants::DEBUG >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double outerSum = LogAware::addIdenty();
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
if (Constants::DEBUG >= 5) {
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
}
double innerSum = LogAware::addIdenty();
indexer.reset();
if (Globals::logDomain) {
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != xi) {
++ indexer;
}
innerSum = Util::logSum (innerSum,
y->getProbability (yi, indexer) + messageProducts[k]);
++ indexer;
}
outerSum = Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
} else {
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != xi) {
++ indexer;
}
if (Constants::DEBUG >= 5) {
if (k != 0) *calcs1 << " + " ;
if (k != 0) *calcs2 << " + " ;
*calcs1 << y->cptEntryToString (yi, indexer.indices());
*calcs1 << ".mp" << k;
*calcs2 << y->getProbability (yi, indexer);
*calcs2 << "*" << messageProducts[k];
}
innerSum += y->getProbability (yi, indexer) * messageProducts[k];
++ indexer;
}
outerSum += innerSum * yLambdaValues[yi];
}
if (Constants::DEBUG >= 5) {
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
*calcs1 << "[" << y->states()[yi] << "]";
*calcs2 << "}*" << yLambdaValues[yi];
}
}
yxLambdaNextMessage[xi] = outerSum;
if (Constants::DEBUG >= 5) {
cout << " " << link->toString();
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
cout << " = " << (*calcs2).str();
cout << " = " << yxLambdaNextMessage[xi] << endl;
delete calcs1;
delete calcs2;
}
}
LogAware::normalize (yxLambdaNextMessage);
}
Params
BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{
BnNodeSet jointVars;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
assert (bayesNet_->getBayesNode (jointVarIds[i]));
jointVars.push_back (bayesNet_->getBayesNode (jointVarIds[i]));
}
BayesNet* mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[0]);
BnBpSolver solver (*mrn);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
delete mrn;
VarIds observedVids = {jointVars[0]->varId()};
for (unsigned i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
VarIds reqVars = {jointVarIds[i]};
Util::addToVector (reqVars, observedVids);
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
Params newBeliefs;
VarNodes observedVars;
for (unsigned j = 0; j < observedVids.size(); j++) {
observedVars.push_back (mrn->getBayesNode (observedVids[j]));
}
StatesIndexer idx (observedVars, false);
while (idx.valid()) {
for (unsigned j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (idx[j]);
}
BnBpSolver solver (*mrn);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (unsigned k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ idx;
}
int count = -1;
for (unsigned j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->nrStates() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
delete mrn;
}
return prevBeliefs;
}
void
BnBpSolver::printPiLambdaValues (const BayesNode* var) const
{
cout << left;
cout << setw (10) << "states" ;
cout << setw (20) << PI_SYMBOL << "(" + var->label() + ")" ;
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
cout << setw (16) << "belief" ;
cout << endl;
Util::printDashedLine();
cout << endl;
const States& states = var->states();
const Params& piVals = ninf(var)->getPiValues();
const Params& ldVals = ninf(var)->getLambdaValues();
const Params& beliefs = ninf(var)->getBeliefs();
for (unsigned xi = 0; xi < var->nrStates(); xi++) {
cout << setw (10) << states[xi];
cout << setw (19) << piVals[xi];
cout << setw (19) << ldVals[xi];
cout.precision (Constants::PRECISION);
cout << setw (16) << beliefs[xi];
cout << endl;
}
cout << endl;
}
void
BnBpSolver::printAllMessageStatus (void) const
{
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
printPiLambdaValues (nodes[i]);
}
}
BpNodeInfo::BpNodeInfo (BayesNode* node)
{
node_ = node;
piVals_.resize (node->nrStates(), LogAware::one());
ldVals_.resize (node->nrStates(), LogAware::one());
}
Params
BpNodeInfo::getBeliefs (void) const
{
double sum = 0.0;
Params beliefs (node_->nrStates());
if (Globals::logDomain) {
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]);
sum += beliefs[xi];
}
} else {
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
beliefs[xi] = piVals_[xi] * ldVals_[xi];
sum += beliefs[xi];
}
}
assert (sum);
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
beliefs[xi] /= sum;
}
return beliefs;
}
bool
BpNodeInfo::receivedBottomInfluence (void) const
{
// if all lambda values are equal, then neither
// this node neither its descendents have evidence,
// we can use this to don't send lambda messages his parents
bool childInfluenced = false;
for (unsigned xi = 1; xi < node_->nrStates(); xi++) {
if (ldVals_[xi] != ldVals_[0]) {
childInfluenced = true;
break;
}
}
return childInfluenced;
}

View File

@ -1,271 +0,0 @@
#ifndef HORUS_BNBPSOLVER_H
#define HORUS_BNBPSOLVER_H
#include <vector>
#include <set>
#include "Solver.h"
#include "BayesNet.h"
#include "Horus.h"
#include "Util.h"
using namespace std;
class BpNodeInfo;
static const string PI_SYMBOL = "pi" ;
static const string LD_SYMBOL = "ld" ;
enum LinkOrientation {UP, DOWN};
class BpLink
{
public:
BpLink (BayesNode* s, BayesNode* d, LinkOrientation o)
{
source_ = s;
destin_ = d;
orientation_ = o;
if (orientation_ == LinkOrientation::DOWN) {
v1_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
v2_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
} else {
v1_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
v2_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
}
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0;
msgSended_ = false;
}
BayesNode* getSource (void) const { return source_; }
BayesNode* getDestination (void) const { return destin_; }
LinkOrientation getOrientation (void) const { return orientation_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_;}
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0;}
void updateMessage (void)
{
swap (currMsg_, nextMsg_);
msgSended_ = true;
}
void updateResidual (void)
{
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
string toString (void) const
{
stringstream ss;
if (orientation_ == LinkOrientation::DOWN) {
ss << PI_SYMBOL;
} else {
ss << LD_SYMBOL;
}
ss << "(" << source_->label();
ss << " --> " << destin_->label() << ")" ;
return ss.str();
}
string toString (unsigned stateIndex) const
{
stringstream ss;
ss << toString() << "[" ;
if (orientation_ == LinkOrientation::DOWN) {
ss << source_->states()[stateIndex] << "]" ;
} else {
ss << destin_->states()[stateIndex] << "]" ;
}
return ss.str();
}
private:
BayesNode* source_;
BayesNode* destin_;
LinkOrientation orientation_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
};
typedef vector<BpLink*> BpLinkSet;
class BpNodeInfo
{
public:
BpNodeInfo (BayesNode*);
Params& getPiValues (void) { return piVals_; }
Params& getLambdaValues (void) { return ldVals_; }
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
Params getBeliefs (void) const;
bool receivedBottomInfluence (void) const;
private:
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
const BayesNode* node_;
Params piVals_;
Params ldVals_;
BpLinkSet inParentLinks_;
BpLinkSet inChildLinks_;
BpLinkSet outParentLinks_;
BpLinkSet outChildLinks_;
};
class BnBpSolver : public Solver
{
public:
BnBpSolver (const BayesNet&);
~BnBpSolver (void);
void runSolver (void);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private:
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
void initializeSolver (void);
void runLoopySolver (void);
void maxResidualSchedule (void);
bool converged (void) const;
void updatePiValues (BayesNode*);
void updateLambdaValues (BayesNode*);
void calculateLambdaMessage (BpLink*);
void calculatePiMessage (BpLink*);
Params getJointByJunctionNode (const VarIds&);
Params getJointByConditioning (const VarIds&) const;
void printPiLambdaValues (const BayesNode*) const;
void printAllMessageStatus (void) const;
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
if (link->getOrientation() == LinkOrientation::DOWN) {
calculatePiMessage (link);
} else if (link->getOrientation() == LinkOrientation::UP) {
calculateLambdaMessage (link);
}
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (BpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
cout << "calculating " << link->toString() << endl;
}
if (link->getOrientation() == LinkOrientation::DOWN) {
calculatePiMessage (link);
} else if (link->getOrientation() == LinkOrientation::UP) {
calculateLambdaMessage (link);
}
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (BpLink* link)
{
if (Constants::DEBUG >= 3) {
cout << "updating " << link->toString() << endl;
}
link->updateMessage();
}
void updateValues (BpLink* link)
{
if (!link->getDestination()->hasEvidence()) {
if (link->getOrientation() == LinkOrientation::DOWN) {
updatePiValues (link->getDestination());
} else if (link->getOrientation() == LinkOrientation::UP) {
updateLambdaValues (link->getDestination());
}
}
}
BpNodeInfo* ninf (const BayesNode* node) const
{
assert (node);
assert (node == bayesNet_->getBayesNode (node->varId()));
assert (node->getIndex() < nodesI_.size());
return nodesI_[node->getIndex()];
}
const BayesNet* bayesNet_;
vector<BpLink*> links_;
vector<BpNodeInfo*> nodesI_;
unsigned nIters_;
struct compare
{
inline bool operator() (const BpLink* e1, const BpLink* e2)
{
return e1->getResidual() > e2->getResidual();
}
};
typedef multiset<BpLink*, compare> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
BpLinkMap linkMap_;
};
#endif // HORUS_BNBPSOLVER_H

View File

@ -5,21 +5,22 @@
#include <iostream>
#include "FgBpSolver.h"
#include "BpSolver.h"
#include "FactorGraph.h"
#include "Factor.h"
#include "Indexer.h"
#include "Horus.h"
FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg)
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{
factorGraph_ = &fg;
fg_ = &fg;
runned_ = false;
}
FgBpSolver::~FgBpSolver (void)
BpSolver::~BpSolver (void)
{
for (unsigned i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
@ -34,47 +35,33 @@ FgBpSolver::~FgBpSolver (void)
void
FgBpSolver::runSolver (void)
Params
BpSolver::solveQuery (VarIds queryVids)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
runLoopySolver();
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
assert (queryVids.empty() == false);
if (queryVids.size() == 1) {
return getPosterioriOf (queryVids[0]);
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = factorGraph_->getVarNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
return getJointDistributionOf (queryVids);
}
}
Params
FgBpSolver::getPosterioriOf (VarId vid)
BpSolver::getPosterioriOf (VarId vid)
{
assert (factorGraph_->getFgVarNode (vid));
FgVarNode* var = factorGraph_->getFgVarNode (vid);
if (runned_ == false) {
runSolver();
}
assert (fg_->getVarNode (vid));
VarNode* var = fg_->getVarNode (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->nrStates(), LogAware::noEvidence());
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->nrStates(), LogAware::multIdenty());
probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
@ -95,13 +82,16 @@ FgBpSolver::getPosterioriOf (VarId vid)
Params
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
int idx = -1;
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
const FgFacSet& factorNodes = vn->neighbors();
for (unsigned i = 0; i < factorNodes.size(); i++) {
if (factorNodes[i]->factor()->contains (jointVarIds)) {
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors();
for (unsigned i = 0; i < facNodes.size(); i++) {
if (facNodes[i]->factor().contains (jointVarIds)) {
idx = i;
break;
}
@ -109,11 +99,11 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == -1) {
return getJointByConditioning (jointVarIds);
} else {
Factor res (*factorNodes[idx]->factor());
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
Factor res (facNodes[idx]->factor());
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
Factor msg (links[i]->getVariable()->varId(),
links[i]->getVariable()->nrStates(),
Factor msg ({links[i]->getVariable()->varId()},
{links[i]->getVariable()->range()},
getVar2FactorMsg (links[i]));
res.multiply (msg);
}
@ -131,30 +121,29 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
void
FgBpSolver::runLoopySolver (void)
BpSolver::runSolver (void)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (" Iteration " + nIters_);
cout << endl;
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
// cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
@ -163,7 +152,6 @@ FgBpSolver::runLoopySolver (void)
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
@ -172,52 +160,35 @@ FgBpSolver::runLoopySolver (void)
cout << endl;
}
}
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = fg_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = fg_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
runned_ = true;
}
void
FgBpSolver::initializeSolver (void)
BpSolver::createLinks (void)
{
const FgVarSet& varNodes = factorGraph_->getVarNodes();
for (unsigned i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
}
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
for (unsigned i = 0; i < facsI_.size(); i++) {
delete facsI_[i];
}
facsI_.reserve (facNodes.size());
const FacNodes& facNodes = fg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
FgFacNode* src = links_[i]->getFactor();
FgVarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
}
}
void
FgBpSolver::createLinks (void)
{
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
const FgVarSet& neighbors = facNodes[i]->neighbors();
const VarNodes& neighbors = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighbors.size(); j++) {
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
}
@ -226,42 +197,8 @@ FgBpSolver::createLinks (void)
bool
FgBpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Constants::DEBUG == 0) break;
}
}
}
return converged;
}
void
FgBpSolver::maxResidualSchedule (void)
BpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (unsigned i = 0; i < links_.size(); i++) {
@ -293,7 +230,7 @@ FgBpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->getFactor()) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
@ -316,16 +253,16 @@ FgBpSolver::maxResidualSchedule (void)
void
FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
BpSolver::calculateFactor2VariableMsg (SpLink* link)
{
const FgFacNode* src = link->getFactor();
const FgVarNode* dst = link->getVariable();
FacNode* src = link->getFactor();
const VarNode* dst = link->getVariable();
const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned msgSize = 1;
for (unsigned i = 0; i < links.size(); i++) {
msgSize *= links[i]->getVariable()->nrStates();
msgSize *= links[i]->getVariable()->range();
}
unsigned repetitions = 1;
Params msgProduct (msgSize, LogAware::multIdenty());
@ -333,9 +270,9 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) {
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->nrStates();
repetitions *= links[i]->getVariable()->range();
} else {
unsigned ds = links[i]->getVariable()->nrStates();
unsigned ds = links[i]->getVariable()->range();
Util::add (msgProduct, Params (ds, 1.0), repetitions);
repetitions *= ds;
}
@ -348,22 +285,21 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
cout << ": " << endl;
}
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->nrStates();
repetitions *= links[i]->getVariable()->range();
} else {
unsigned ds = links[i]->getVariable()->nrStates();
unsigned ds = links[i]->getVariable()->range();
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
repetitions *= ds;
}
}
}
Factor result (src->factor()->arguments(),
src->factor()->ranges(),
msgProduct);
result.multiply (*(src->factor()));
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
result.multiply (src->factor());
if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->params() << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExcept (dst->varId());
@ -386,19 +322,19 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
Params
FgBpSolver::getVar2FactorMsg (const SpLink* link) const
BpSolver::getVar2FactorMsg (const SpLink* link) const
{
const FgVarNode* src = link->getVariable();
const FgFacNode* dst = link->getFactor();
const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor();
Params msg;
if (src->hasEvidence()) {
msg.resize (src->nrStates(), LogAware::noEvidence());
msg.resize (src->range(), LogAware::noEvidence());
msg[src->getEvidence()] = LogAware::withEvidence();
if (Constants::DEBUG >= 5) {
cout << msg;
}
} else {
msg.resize (src->nrStates(), LogAware::one());
msg.resize (src->range(), LogAware::one());
}
if (Constants::DEBUG >= 5) {
cout << msg;
@ -429,16 +365,16 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
Params
FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{
FgVarSet jointVars;
VarNodes jointVars;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
assert (factorGraph_->getFgVarNode (jointVarIds[i]));
jointVars.push_back (factorGraph_->getFgVarNode (jointVarIds[i]));
assert (fg_->getVarNode (jointVarIds[i]));
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
}
FactorGraph* fg = new FactorGraph (*factorGraph_);
FgBpSolver solver (*fg);
FactorGraph* fg = new FactorGraph (*fg_);
BpSolver solver (*fg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
@ -447,9 +383,9 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
for (unsigned i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
VarNodes observedVars;
Vars observedVars;
for (unsigned j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg->getFgVarNode (observedVids[j]));
observedVars.push_back (fg->getVarNode (observedVids[j]));
}
StatesIndexer idx (observedVars, false);
while (idx.valid()) {
@ -457,7 +393,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
observedVars[j]->setEvidence (idx[j]);
}
++ idx;
FgBpSolver solver (*fg);
BpSolver solver (*fg);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (unsigned k = 0; k < beliefs.size(); k++) {
@ -467,7 +403,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
int count = -1;
for (unsigned j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->nrStates() == 0) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
@ -481,7 +417,68 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
void
FgBpSolver::printLinkInformation (void) const
BpSolver::initializeSolver (void)
{
const VarNodes& varNodes = fg_->varNodes();
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FacNodes& facNodes = fg_->facNodes();
facsI_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->getFactor();
VarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
}
}
bool
BpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ <= 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Constants::DEBUG == 0) break;
}
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
return converged;
}
void
BpSolver::printLinkInformation (void) const
{
for (unsigned i = 0; i < links_.size(); i++) {
SpLink* l = links_[i];

View File

@ -1,5 +1,5 @@
#ifndef HORUS_FGBPSOLVER_H
#define HORUS_FGBPSOLVER_H
#ifndef HORUS_BPSOLVER_H
#define HORUS_BPSOLVER_H
#include <set>
#include <vector>
@ -16,12 +16,12 @@ using namespace std;
class SpLink
{
public:
SpLink (FgFacNode* fn, FgVarNode* vn)
SpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
v2_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
msgSended_ = false;
@ -30,9 +30,9 @@ class SpLink
virtual ~SpLink (void) { };
FgFacNode* getFactor (void) const { return fac_; }
FacNode* getFactor (void) const { return fac_; }
FgVarNode* getVariable (void) const { return var_; }
VarNode* getVariable (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; }
@ -65,8 +65,8 @@ class SpLink
}
protected:
FgFacNode* fac_;
FgVarNode* var_;
FacNode* fac_;
VarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
@ -88,40 +88,38 @@ class SPNodeInfo
};
class FgBpSolver : public Solver
class BpSolver : public Solver
{
public:
FgBpSolver (const FactorGraph&);
BpSolver (const FactorGraph&);
virtual ~FgBpSolver (void);
virtual ~BpSolver (void);
void runSolver (void);
Params solveQuery (VarIds);
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
protected:
virtual void initializeSolver (void);
void runSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*) const;
virtual void calculateFactor2VariableMsg (SpLink*);
virtual Params getVar2FactorMsg (const SpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
virtual void printLinkInformation (void) const;
SPNodeInfo* ninf (const FgVarNode* var) const
SPNodeInfo* ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];
}
SPNodeInfo* ninf (const FgFacNode* fac) const
SPNodeInfo* ninf (const FacNode* fac) const
{
return facsI_[fac->getIndex()];
}
@ -169,7 +167,8 @@ class FgBpSolver : public Solver
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
const FactorGraph* factorGraph_;
bool runned_;
const FactorGraph* fg_;
typedef multiset<SpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
@ -178,9 +177,12 @@ class FgBpSolver : public Solver
SpLinkMap linkMap_;
private:
void runLoopySolver (void);
void initializeSolver (void);
bool converged (void);
void printLinkInformation (void) const;
};
#endif // HORUS_FGBPSOLVER_H
#endif // HORUS_BPSOLVER_H

View File

@ -10,22 +10,22 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
groundFg_ = &fg;
freeColor_ = 0;
const FgVarSet& varNodes = fg.getVarNodes();
const VarNodes& varNodes = fg.varNodes();
varSignatures_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
varSignatures_.push_back (Signature (c));
}
const FgFacSet& facNodes = fg.getFactorNodes();
factorSignatures_.reserve (facNodes.size());
const FacNodes& facNodes = fg.facNodes();
facSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1;
factorSignatures_.push_back (Signature (c));
facSignatures_.push_back (Signature (c));
}
varColors_.resize (varNodes.size());
factorColors_.resize (facNodes.size());
facColors_.resize (facNodes.size());
setInitialColors();
createGroups();
}
@ -49,9 +49,9 @@ CFactorGraph::setInitialColors (void)
{
// create the initial variable colors
VarColorMap colorMap;
const FgVarSet& varNodes = groundFg_->getVarNodes();
const VarNodes& varNodes = groundFg_->varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned dsize = varNodes[i]->nrStates();
unsigned dsize = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (dsize);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
@ -70,24 +70,28 @@ CFactorGraph::setInitialColors (void)
setColor (varNodes[i], stateColors[idx]);
}
const FgFacSet& facNodes = groundFg_->getFactorNodes();
if (checkForIdenticalFactors) {
const FacNodes& facNodes = groundFg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
// FIXME FIXME FIXME : pfl should give correct dist ids.
if (checkForIdenticalFactors || true) {
unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) {
Factor* f1 = facNodes[i]->factor();
if (f1->distId() != Util::maxUnsigned()) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1->setDistId (groupCount);
f1.setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor* f2 = facNodes[j]->factor();
if (f2->distId() != Util::maxUnsigned()) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1->size() == f2->size() &&
f1->ranges() == f2->ranges() &&
f1->params() == f2->params()) {
f2->setDistId (groupCount);
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
@ -96,7 +100,7 @@ CFactorGraph::setInitialColors (void)
// create the initial factor colors
DistColorMap distColors;
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor()->distId();
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getFreeColor())).first;
@ -111,30 +115,30 @@ void
CFactorGraph::createGroups (void)
{
VarSignMap varGroups;
FacSignMap factorGroups;
FacSignMap facGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const FgVarSet& varNodes = groundFg_->getVarNodes();
const FgFacSet& facNodes = groundFg_->getFactorNodes();
const VarNodes& varNodes = groundFg_->varNodes();
const FacNodes& facNodes = groundFg_->facNodes();
while (groupsHaveChanged || nIters == 1) {
nIters ++;
unsigned prevFactorGroupsSize = factorGroups.size();
factorGroups.clear();
unsigned prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) {
const Signature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = factorGroups.find (signature);
if (it == factorGroups.end()) {
it = factorGroups.insert (make_pair (signature, FgFacSet())).first;
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); it++) {
Color newColor = getFreeColor();
FgFacSet& groupMembers = it->second;
FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
@ -147,24 +151,24 @@ CFactorGraph::createGroups (void)
const Signature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, FgVarSet())).first;
it = varGroups.insert (make_pair (signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); it++) {
Color newColor = getFreeColor();
FgVarSet& groupMembers = it->second;
VarNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != factorGroups.size();
|| prevFactorGroupsSize != facGroups.size();
}
//printGroups (varGroups, factorGroups);
createClusters (varGroups, factorGroups);
printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
}
@ -172,12 +176,12 @@ CFactorGraph::createGroups (void)
void
CFactorGraph::createClusters (
const VarSignMap& varGroups,
const FacSignMap& factorGroups)
const FacSignMap& facGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
const FgVarSet& groupVars = it->second;
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (unsigned i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
@ -185,12 +189,12 @@ CFactorGraph::createClusters (
varClusters_.push_back (vc);
}
facClusters_.reserve (factorGroups.size());
for (FacSignMap::const_iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
FgFacNode* groupFactor = it->second[0];
const FgVarSet& neighs = groupFactor->neighbors();
VarClusterSet varClusters;
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (unsigned i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
@ -203,15 +207,15 @@ CFactorGraph::createClusters (
const Signature&
CFactorGraph::getSignature (const FgVarNode* varNode)
CFactorGraph::getSignature (const VarNode* varNode)
{
Signature& sign = varSignatures_[varNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin();
const FgFacSet& neighs = varNode->neighbors();
const FacNodes& neighs = varNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]);
it ++;
*it = neighs[i]->factor()->indexOf (varNode->varId());
*it = neighs[i]->factor().indexOf (varNode->varId());
it ++;
}
*it = getColor (varNode);
@ -221,11 +225,11 @@ CFactorGraph::getSignature (const FgVarNode* varNode)
const Signature&
CFactorGraph::getSignature (const FgFacNode* facNode)
CFactorGraph::getSignature (const FacNode* facNode)
{
Signature& sign = factorSignatures_[facNode->getIndex()];
Signature& sign = facSignatures_[facNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin();
const FgVarSet& neighs = facNode->neighbors();
const VarNodes& neighs = facNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]);
it ++;
@ -237,55 +241,53 @@ CFactorGraph::getSignature (const FgFacNode* facNode)
FactorGraph*
CFactorGraph::getCompressedFactorGraph (void)
CFactorGraph::getGroundFactorGraph (void) const
{
FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) {
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
FgVarNode* newVar = new FgVarNode (var);
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
VarNode* newVar = new VarNode (var);
varClusters_[i]->setRepresentativeVariable (newVar);
fg->addVariable (newVar);
fg->addVarNode (newVar);
}
for (unsigned i = 0; i < facClusters_.size(); i++) {
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
VarNodes myGroundVars;
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
Vars myGroundVars;
myGroundVars.reserve (myVarClusters.size());
for (unsigned j = 0; j < myVarClusters.size(); j++) {
FgVarNode* v = myVarClusters[j]->getRepresentativeVariable();
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
myGroundVars.push_back (v);
}
Factor* newFactor = new Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->params());
FgFacNode* fn = new FgFacNode (newFactor);
FacNode* fn = new FacNode (Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->factor().params()));
facClusters_[i]->setRepresentativeFactor (fn);
fg->addFactor (fn);
fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (fn, static_cast<FgVarNode*> (myGroundVars[j]));
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
}
}
fg->setIndexes();
return fg;
}
unsigned
CFactorGraph::getGroundEdgeCount (
CFactorGraph::getEdgeCount (
const FacCluster* fc,
const VarCluster* vc) const
{
const FgFacSet& clusterGroundFactors = fc->getGroundFactors();
FgVarNode* varNode = vc->getGroundFgVarNodes()[0];
unsigned count = 0;
VarId vid = vc->getGroundVarNodes().front()->varId();
const FacNodes& clusterGroundFactors = fc->getGroundFactors();
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
if (clusterGroundFactors[i]->factor().contains (vid)) {
count ++;
}
}
// CFgVarSet vars = vc->getGroundFgVarNodes();
// CVarNodes vars = vc->getGroundVarNodes();
// for (unsigned i = 1; i < vars.size(); i++) {
// FgVarNode* var = vc->getGroundFgVarNodes()[i];
// VarNode* var = vc->getGroundVarNodes()[i];
// unsigned count2 = 0;
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
@ -302,13 +304,13 @@ CFactorGraph::getGroundEdgeCount (
void
CFactorGraph::printGroups (
const VarSignMap& varGroups,
const FacSignMap& factorGroups) const
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
const FgVarSet& groupMembers = it->second;
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) {
@ -321,9 +323,9 @@ CFactorGraph::printGroups (
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
const FgFacSet& groupMembers = it->second;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) {

View File

@ -22,11 +22,11 @@ typedef unordered_map<unsigned, vector<Color>> VarColorMap;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef vector<VarCluster*> VarClusterSet;
typedef vector<FacCluster*> FacClusterSet;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
@ -87,7 +87,7 @@ struct SignatureHash
class VarCluster
{
public:
VarCluster (const FgVarSet& vs)
VarCluster (const VarNodes& vs)
{
for (unsigned i = 0; i < vs.size(); i++) {
groundVars_.push_back (vs[i]);
@ -99,26 +99,28 @@ class VarCluster
facClusters_.push_back (fc);
}
const FacClusterSet& getFacClusters (void) const
const FacClusters& getFacClusters (void) const
{
return facClusters_;
}
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; }
VarNode* getRepresentativeVariable (void) const { return representVar_; }
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
private:
FgVarSet groundVars_;
FacClusterSet facClusters_;
FgVarNode* representVar_;
VarNodes groundVars_;
FacClusters facClusters_;
VarNode* representVar_;
};
class FacCluster
{
public:
FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs)
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
{
groundFactors_ = groundFactors;
varClusters_ = vcs;
@ -127,12 +129,12 @@ class FacCluster
}
}
const VarClusterSet& getVarClusters (void) const
const VarClusters& getVarClusters (void) const
{
return varClusters_;
}
bool containsGround (const FgFacNode* fn)
bool containsGround (const FacNode* fn)
{
for (unsigned i = 0; i < groundFactors_.size(); i++) {
if (groundFactors_[i] == fn) {
@ -142,26 +144,26 @@ class FacCluster
return false;
}
FgFacNode* getRepresentativeFactor (void) const
FacNode* getRepresentativeFactor (void) const
{
return representFactor_;
}
void setRepresentativeFactor (FgFacNode* fn)
void setRepresentativeFactor (FacNode* fn)
{
representFactor_ = fn;
}
const FgFacSet& getGroundFactors (void) const
const FacNodes& getGroundFactors (void) const
{
return groundFactors_;
}
private:
FgFacSet groundFactors_;
VarClusterSet varClusters_;
FgFacNode* representFactor_;
FacNodes groundFactors_;
VarClusters varClusters_;
FacNode* representFactor_;
};
@ -172,19 +174,19 @@ class CFactorGraph
~CFactorGraph (void);
const VarClusterSet& getVarClusters (void) { return varClusters_; }
const VarClusters& getVarClusters (void) { return varClusters_; }
const FacClusterSet& getFacClusters (void) { return facClusters_; }
const FacClusters& getFacClusters (void) { return facClusters_; }
FgVarNode* getEquivalentVariable (VarId vid)
VarNode* getEquivalentVariable (VarId vid)
{
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->getRepresentativeVariable();
}
FactorGraph* getCompressedFactorGraph (void);
FactorGraph* getGroundFactorGraph (void) const;
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
static bool checkForIdenticalFactors;
@ -195,22 +197,22 @@ class CFactorGraph
return freeColor_ - 1;
}
Color getColor (const FgVarNode* vn) const
Color getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
Color getColor (const FgFacNode* fn) const {
return factorColors_[fn->getIndex()];
Color getColor (const FacNode* fn) const {
return facColors_[fn->getIndex()];
}
void setColor (const FgVarNode* vn, Color c)
void setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
void setColor (const FgFacNode* fn, Color c)
void setColor (const FacNode* fn, Color c)
{
factorColors_[fn->getIndex()] = c;
facColors_[fn->getIndex()] = c;
}
VarCluster* getVariableCluster (VarId vid) const
@ -224,19 +226,19 @@ class CFactorGraph
void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const FgVarNode*);
const Signature& getSignature (const VarNode*);
const Signature& getSignature (const FgFacNode*);
const Signature& getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_;
vector<Color> varColors_;
vector<Color> factorColors_;
vector<Color> facColors_;
vector<Signature> varSignatures_;
vector<Signature> factorSignatures_;
VarClusterSet varClusters_;
FacClusterSet facClusters_;
vector<Signature> facSignatures_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
};

View File

@ -1,10 +1,41 @@
#include "CbpSolver.h"
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) {
nGroundVars = fg_->varNodes().size();
nGroundFacs = fg_->facNodes().size();
const VarNodes& vars = fg_->varNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FacNodes& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
cfg_ = new CFactorGraph (fg);
fg_ = cfg_->getGroundFactorGraph();
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = fg_->varNodes().size();
unsigned nClusterFacs = fg_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars,
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
}
Util::printHeader ("Uncompressed Factor Graph");
fg.print();
Util::printHeader ("Compressed Factor Graph");
fg_->print();
}
CbpSolver::~CbpSolver (void)
{
delete lfg_;
delete factorGraph_;
delete cfg_;
delete fg_;
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
@ -16,26 +47,29 @@ CbpSolver::~CbpSolver (void)
Params
CbpSolver::getPosterioriOf (VarId vid)
{
assert (lfg_->getEquivalentVariable (vid));
FgVarNode* var = lfg_->getEquivalentVariable (vid);
if (runned_ == false) {
runSolver();
}
assert (cfg_->getEquivalentVariable (vid));
VarNode* var = cfg_->getEquivalentVariable (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->nrStates(), LogAware::noEvidence());
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->nrStates(), LogAware::multIdenty());
probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->getPoweredMessage());
Util::add (probs, l->poweredMessage());
}
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (probs, l->getPoweredMessage());
Util::multiply (probs, l->poweredMessage());
}
LogAware::normalize (probs);
}
@ -46,55 +80,14 @@ CbpSolver::getPosterioriOf (VarId vid)
Params
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{
VarIds eqVarIds;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
for (unsigned i = 0; i < jointVids.size(); i++) {
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
eqVarIds.push_back (vn->varId());
}
return FgBpSolver::getJointDistributionOf (eqVarIds);
}
void
CbpSolver::initializeSolver (void)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) {
nGroundVars = factorGraph_->getVarNodes().size();
nGroundFacs = factorGraph_->getFactorNodes().size();
const FgVarSet& vars = factorGraph_->getVarNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FgFacSet& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
lfg_ = new CFactorGraph (*factorGraph_);
// cout << "Uncompressed Factor Graph" << endl;
// factorGraph_->printGraphicalModel();
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
factorGraph_ = lfg_->getCompressedFactorGraph();
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = factorGraph_->getVarNodes().size();
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
nClusterVars, nClusterFacs,
nWithoutNeighs);
}
// cout << "Compressed Factor Graph" << endl;
// factorGraph_->printGraphicalModel();
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
// abort();
FgBpSolver::initializeSolver();
return BpSolver::getJointDistributionOf (eqVarIds);
}
@ -102,12 +95,13 @@ CbpSolver::initializeSolver (void)
void
CbpSolver::createLinks (void)
{
const FacClusterSet fcs = lfg_->getFacClusters();
const FacClusters& fcs = cfg_->getFacClusters();
for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusterSet vcs = fcs[i]->getVarClusters();
const VarClusters& vcs = fcs[i]->getVarClusters();
for (unsigned j = 0; j < vcs.size(); j++) {
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(),
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
links_.push_back (new CbpSolverLink (
fcs[i]->getRepresentativeFactor(),
vcs[j]->getRepresentativeVariable(), c));
}
}
@ -154,7 +148,7 @@ CbpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
@ -192,16 +186,16 @@ Params
CbpSolver::getVar2FactorMsg (const SpLink* link) const
{
Params msg;
const FgVarNode* src = link->getVariable();
const FgFacNode* dst = link->getFactor();
const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor();
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
if (src->hasEvidence()) {
msg.resize (src->nrStates(), LogAware::noEvidence());
msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()];
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1);
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
} else {
msg = link->getMessage();
LogAware::pow (msg, l->getNumberOfEdges() - 1);
LogAware::pow (msg, l->nrEdges() - 1);
}
if (Constants::DEBUG >= 5) {
cout << " " << "init: " << msg << endl;
@ -211,17 +205,17 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, l->getPoweredMessage());
Util::add (msg, l->poweredMessage());
}
}
} else {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (msg, l->getPoweredMessage());
Util::multiply (msg, l->poweredMessage());
if (Constants::DEBUG >= 5) {
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
cout << l->getPoweredMessage() << endl;
cout << l->poweredMessage() << endl;
}
}
}
@ -243,7 +237,7 @@ CbpSolver::printLinkInformation (void) const
cout << l->toString() << ":" << endl;
cout << " curr msg = " << l->getMessage() << endl;
cout << " next msg = " << l->getNextMessage() << endl;
cout << " powered = " << l->getPoweredMessage() << endl;
cout << " powered = " << l->poweredMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -1,7 +1,7 @@
#ifndef HORUS_CBP_H
#define HORUS_CBP_H
#include "FgBpSolver.h"
#include "BpSolver.h"
#include "CFactorGraph.h"
class Factor;
@ -9,35 +9,33 @@ class Factor;
class CbpSolverLink : public SpLink
{
public:
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
{
edgeCount_ = c;
poweredMsg_.resize (vn->nrStates(), LogAware::one());
}
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
: SpLink (fn, vn), nrEdges_(c),
pwdMsg_(vn->range(), LogAware::one()) { }
unsigned getNumberOfEdges (void) const { return edgeCount_; }
unsigned nrEdges (void) const { return nrEdges_; }
const Params& getPoweredMessage (void) const { return poweredMsg_; }
const Params& poweredMessage (void) const { return pwdMsg_; }
void updateMessage (void)
{
poweredMsg_ = *nextMsg_;
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
msgSended_ = true;
LogAware::pow (poweredMsg_, edgeCount_);
LogAware::pow (pwdMsg_, nrEdges_);
}
private:
Params poweredMsg_;
unsigned edgeCount_;
unsigned nrEdges_;
Params pwdMsg_;
};
class CbpSolver : public FgBpSolver
class CbpSolver : public BpSolver
{
public:
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
CbpSolver (const FactorGraph& fg);
~CbpSolver (void);
@ -46,14 +44,16 @@ class CbpSolver : public FgBpSolver
Params getJointDistributionOf (const VarIds&);
private:
void initializeSolver (void);
void createLinks (void);
void maxResidualSchedule (void);
Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;
CFactorGraph* lfg_;
CFactorGraph* cfg_;
};
#endif // HORUS_CBP_H

View File

@ -43,6 +43,14 @@ CTNode::removeChild (CTNode* child)
void
CTNode::removeChilds (void)
{
childs_.clear();
}
void
CTNode::removeAndDeleteChild (CTNode* child)
{
@ -897,19 +905,19 @@ ConstraintTree::getNodesAtLevel (unsigned level) const
void
ConstraintTree::swapLogVar (LogVar X)
{
TupleSet before = tupleSet();
LogVars::iterator it =
std::find (logVars_.begin(),logVars_.end(), X);
assert (it != logVars_.end());
unsigned pos = std::distance (logVars_.begin(), it);
const CTNodes& nodes = getNodesAtLevel (pos);
for (unsigned i = 0; i < nodes.size(); i++) {
const CTNodes childs = nodes[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) {
nodes[i]->removeChild (childs[j]);
const CTNodes grandsons = childs[j]->childs();
CTNodes childsCopy = nodes[i]->childs();
nodes[i]->removeChilds();
for (unsigned j = 0; j < childsCopy.size(); j++) {
const CTNodes grandsons = childsCopy[j]->childs();
for (unsigned k = 0; k < grandsons.size(); k++) {
CTNode* childCopy = new CTNode (*childs[j]);
CTNode* childCopy = new CTNode (*childsCopy[j]);
const CTNodes greatGrandsons = grandsons[k]->childs();
for (unsigned t = 0; t < greatGrandsons.size(); t++) {
grandsons[k]->removeChild (greatGrandsons[t]);
@ -920,10 +928,9 @@ ConstraintTree::swapLogVar (LogVar X)
grandsons[k]->setLevel (grandsons[k]->level() - 1);
nodes[i]->addChild (grandsons[k], false);
}
delete childs[j];
delete childsCopy[j];
}
}
std::swap (logVars_[pos], logVars_[pos + 1]);
}

View File

@ -50,6 +50,8 @@ class CTNode
void removeChild (CTNode*);
void removeChilds (void);
void removeAndDeleteChild (CTNode*);
void removeAndDeleteAllChilds (void);

View File

@ -3,54 +3,38 @@
#include <fstream>
#include "ElimGraph.h"
#include "BayesNet.h"
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
ElimGraph::ElimGraph (const BayesNet& bayesNet)
ElimGraph::ElimGraph (const vector<Factor*>& factors)
{
const BnNodeSet& bnNodes = bayesNet.getBayesNodes();
for (unsigned i = 0; i < bnNodes.size(); i++) {
if (bnNodes[i]->hasEvidence() == false) {
addNode (new EgNode (bnNodes[i]));
for (unsigned i = 0; i < factors.size(); i++) {
const VarIds& vids = factors[i]->arguments();
for (unsigned j = 0; j < vids.size() - 1; j++) {
EgNode* n1 = getEgNode (vids[j]);
if (n1 == 0) {
n1 = new EgNode (vids[j], factors[i]->range (j));
addNode (n1);
}
for (unsigned k = j + 1; k < vids.size(); k++) {
EgNode* n2 = getEgNode (vids[k]);
if (n2 == 0) {
n2 = new EgNode (vids[k], factors[i]->range (k));
addNode (n2);
}
for (unsigned i = 0; i < bnNodes.size(); i++) {
if (bnNodes[i]->hasEvidence() == false) {
EgNode* n = getEgNode (bnNodes[i]->varId());
const BnNodeSet& childs = bnNodes[i]->getChilds();
for (unsigned j = 0; j < childs.size(); j++) {
if (childs[j]->hasEvidence() == false) {
addEdge (n, getEgNode (childs[j]->varId()));
if (neighbors (n1, n2) == false) {
addEdge (n1, n2);
}
}
}
}
for (unsigned i = 0; i < bnNodes.size(); i++) {
vector<EgNode*> neighs;
const vector<BayesNode*>& parents = bnNodes[i]->getParents();
for (unsigned i = 0; i < parents.size(); i++) {
if (parents[i]->hasEvidence() == false) {
neighs.push_back (getEgNode (parents[i]->varId()));
}
}
if (neighs.size() > 0) {
for (unsigned i = 0; i < neighs.size() - 1; i++) {
for (unsigned j = i+1; j < neighs.size(); j++) {
if (!neighbors (neighs[i], neighs[j])) {
addEdge (neighs[i], neighs[j]);
if (vids.size() == 1) {
if (getEgNode (vids[0]) == 0) {
addNode (new EgNode (vids[0], factors[i]->range (0)));
}
}
}
}
}
setIndexes();
}
@ -63,40 +47,16 @@ ElimGraph::~ElimGraph (void)
void
ElimGraph::addNode (EgNode* n)
{
nodes_.push_back (n);
varMap_.insert (make_pair (n->varId(), n));
}
EgNode*
ElimGraph::getEgNode (VarId vid) const
{
unordered_map<VarId,EgNode*>::const_iterator it =varMap_.find (vid);
if (it ==varMap_.end()) {
return 0;
} else {
return it->second;
}
}
VarIds
ElimGraph::getEliminatingOrder (const VarIds& exclude)
{
VarIds elimOrder;
marked_.resize (nodes_.size(), false);
for (unsigned i = 0; i < exclude.size(); i++) {
assert (getEgNode (exclude[i]));
EgNode* node = getEgNode (exclude[i]);
assert (node);
marked_[*node] = true;
}
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
for (unsigned i = 0; i < nVarsToEliminate; i++) {
EgNode* node = getLowestCostNode();
@ -109,6 +69,100 @@ ElimGraph::getEliminatingOrder (const VarIds& exclude)
void
ElimGraph::print (void) const
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << "node " << nodes_[i]->label() << " neighs:" ;
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
cout << " " << neighs[j]->label();
}
cout << endl;
}
}
void
ElimGraph::exportToGraphViz (
const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Markov::exportToDotFile()" << endl;
abort();
}
out << "strict graph {" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
out << '"' << nodes_[i]->label() << '"' << endl;
}
}
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
EgNode* node =getEgNode (highlightVarIds[i]);
if (node) {
out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
VarIds
ElimGraph::getEliminationOrder (
const vector<Factor*> factors,
VarIds excludedVids)
{
ElimGraph graph (factors);
// graph.print();
// graph.exportToGraphViz ("_egg.dot");
return graph.getEliminatingOrder (excludedVids);
}
void
ElimGraph::addNode (EgNode* n)
{
nodes_.push_back (n);
n->setIndex (nodes_.size() - 1);
varMap_.insert (make_pair (n->varId(), n));
}
EgNode*
ElimGraph::getEgNode (VarId vid) const
{
unordered_map<VarId, EgNode*>::const_iterator it;
it = varMap_.find (vid);
return (it != varMap_.end()) ? it->second : 0;
}
EgNode*
ElimGraph::getLowestCostNode (void) const
{
@ -166,7 +220,7 @@ ElimGraph::getWeightCost (const EgNode* n) const
const vector<EgNode*>& neighs = n->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
if (marked_[*neighs[i]] == false) {
cost *= neighs[i]->nrStates();
cost *= neighs[i]->range();
}
}
return cost;
@ -206,7 +260,7 @@ ElimGraph::getWeightedFillCost (const EgNode* n) const
for (unsigned j = i+1; j < neighs.size(); j++) {
if (marked_[*neighs[j]] == true) continue;
if (!neighbors (neighs[i], neighs[j])) {
cost += neighs[i]->nrStates() * neighs[j]->nrStates();
cost += neighs[i]->range() * neighs[j]->range();
}
}
}
@ -247,78 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
return false;
}
void
ElimGraph::setIndexes (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
}
}
void
ElimGraph::printGraphicalModel (void) const
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << "node " << nodes_[i]->label() << " neighs:" ;
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
cout << " " << neighs[j]->label();
}
cout << endl;
}
}
void
ElimGraph::exportToGraphViz (const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Markov::exportToDotFile()" << endl;
abort();
}
out << "strict graph {" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
out << '"' << nodes_[i]->label() << '"' ;
if (nodes_[i]->hasEvidence()) {
out << " [style=filled, fillcolor=yellow]" << endl;
} else {
out << endl;
}
}
}
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
EgNode* node =getEgNode (highlightVarIds[i]);
if (node) {
out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}

View File

@ -17,10 +17,10 @@ enum ElimHeuristic
};
class EgNode : public VarNode
class EgNode : public Var
{
public:
EgNode (VarNode* var) : VarNode (var) { }
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
@ -34,10 +34,26 @@ class EgNode : public VarNode
class ElimGraph
{
public:
ElimGraph (const BayesNet&);
ElimGraph (const vector<Factor*>&); // TODO
~ElimGraph (void);
VarIds getEliminatingOrder (const VarIds&);
void print (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
static void setEliminationHeuristic (ElimHeuristic h)
{
elimHeuristic_ = h;
}
private:
void addEdge (EgNode* n1, EgNode* n2)
{
assert (n1 != n2);
@ -48,22 +64,6 @@ class ElimGraph
void addNode (EgNode*);
EgNode* getEgNode (VarId) const;
VarIds getEliminatingOrder (const VarIds&);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
void setIndexes();
static void setEliminationHeuristic (ElimHeuristic h)
{
elimHeuristic_ = h;
}
private:
EgNode* getLowestCostNode (void) const;
unsigned getNeighborsCost (const EgNode*) const;

View File

@ -18,56 +18,14 @@ Factor::Factor (const Factor& g)
Factor::Factor (VarId vid, unsigned nrStates)
{
args_.push_back (vid);
ranges_.push_back (nrStates);
params_.resize (nrStates, 1.0);
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (const VarNodes& vars)
{
int nrParams = 1;
for (unsigned i = 0; i < vars.size(); i++) {
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
nrParams *= vars[i]->nrStates();
}
double val = 1.0 / nrParams;
params_.resize (nrParams, val);
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (
VarId vid,
unsigned nrStates,
const Params& params)
{
args_.push_back (vid);
ranges_.push_back (nrStates);
params_ = params;
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (
const VarNodes& vars,
const VarIds& vids,
const Ranges& ranges,
const Params& params,
unsigned distId)
{
for (unsigned i = 0; i < vars.size(); i++) {
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
}
args_ = vids;
ranges_ = ranges;
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
@ -76,14 +34,16 @@ Factor::Factor (
Factor::Factor (
const VarIds& vids,
const Ranges& ranges,
const Params& params)
const Vars& vars,
const Params& params,
unsigned distId)
{
args_ = vids;
ranges_ = ranges;
for (unsigned i = 0; i < vars.size(); i++) {
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->range());
}
params_ = params;
distId_ = Util::maxUnsigned();
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
}
@ -185,8 +145,8 @@ Factor::sumOut (VarId vid)
void
Factor::sumOutFirstVariable (void)
{
unsigned nStates = ranges_.front();
unsigned sep = params_.size() / nStates;
unsigned range = ranges_.front();
unsigned sep = params_.size() / range;
if (Globals::logDomain) {
for (unsigned i = sep; i < params_.size(); i++) {
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
@ -206,14 +166,14 @@ Factor::sumOutFirstVariable (void)
void
Factor::sumOutLastVariable (void)
{
unsigned nStates = ranges_.back();
unsigned range = ranges_.back();
unsigned idx1 = 0;
unsigned idx2 = 0;
if (Globals::logDomain) {
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
idx1 ++;
for (unsigned j = 1; j < nStates; j++) {
for (unsigned j = 1; j < range; j++) {
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
idx1 ++;
}
@ -223,7 +183,7 @@ Factor::sumOutLastVariable (void)
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
idx1 ++;
for (unsigned j = 1; j < nStates; j++) {
for (unsigned j = 1; j < range; j++) {
params_[idx2] += params_[idx1];
idx1 ++;
}
@ -266,7 +226,7 @@ Factor::getLabel (void) const
ss << "f(" ;
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
ss << VarNode (args_[i], ranges_[i]).label();
ss << Var (args_[i], ranges_[i]).label();
}
ss << ")" ;
return ss.str();
@ -277,13 +237,13 @@ Factor::getLabel (void) const
void
Factor::print (void) const
{
VarNodes vars;
Vars vars;
for (unsigned i = 0; i < args_.size(); i++) {
vars.push_back (new VarNode (args_[i], ranges_[i]));
vars.push_back (new Var (args_[i], ranges_[i]));
}
vector<string> jointStrings = Util::getJointStateStrings (vars);
vector<string> jointStrings = Util::getStateLines (vars);
for (unsigned i = 0; i < params_.size(); i++) {
cout << "f(" << jointStrings[i] << ")" ;
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl;
}
cout << endl;

View File

@ -3,7 +3,7 @@
#include <vector>
#include "VarNode.h"
#include "Var.h"
#include "Indexer.h"
#include "Util.h"
@ -33,17 +33,14 @@ class TFactor
void setDistId (unsigned id) { distId_ = id; }
void normalize (void) { LogAware::normalize (params_); }
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_));
}
void normalize (void)
{
LogAware::normalize (params_);
}
int indexOf (const T& t) const
{
int idx = -1;
@ -258,16 +255,11 @@ class Factor : public TFactor<VarId>
Factor (const Factor&);
Factor (VarId, unsigned);
Factor (const VarNodes&);
Factor (VarId, unsigned, const Params&);
Factor (const VarNodes&, const Params&,
Factor (const VarIds&, const Ranges&, const Params&,
unsigned = Util::maxUnsigned());
Factor (const VarIds&, const Ranges&, const Params&);
Factor (const Vars&, const Params&,
unsigned = Util::maxUnsigned());
void sumOutAllExcept (VarId);

View File

@ -9,6 +9,7 @@
#include "FactorGraph.h"
#include "Factor.h"
#include "BayesNet.h"
#include "BayesBall.h"
#include "Util.h"
@ -17,140 +18,92 @@ bool FactorGraph::orderFactorVariables = false;
FactorGraph::FactorGraph (const FactorGraph& fg)
{
const FgVarSet& vars = fg.getVarNodes();
for (unsigned i = 0; i < vars.size(); i++) {
FgVarNode* varNode = new FgVarNode (vars[i]);
addVariable (varNode);
const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i]));
}
const FgFacSet& facs = fg.getFactorNodes();
for (unsigned i = 0; i < facs.size(); i++) {
FgFacNode* facNode = new FgFacNode (facs[i]);
addFactor (facNode);
const FgVarSet& neighs = facs[i]->neighbors();
const FacNodes& facNodes = fg.facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
FacNode* facNode = new FacNode (facNodes[i]->factor());
addFacNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
}
}
}
FactorGraph::FactorGraph (const BayesNet& bn)
{
const BnNodeSet& nodes = bn.getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
FgVarNode* varNode = new FgVarNode (nodes[i]);
addVariable (varNode);
}
for (unsigned i = 0; i < nodes.size(); i++) {
const BnNodeSet& parents = nodes[i]->getParents();
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
VarNodes neighs;
neighs.push_back (varNodes_[nodes[i]->getIndex()]);
for (unsigned j = 0; j < parents.size(); j++) {
neighs.push_back (varNodes_[parents[j]->getIndex()]);
}
FgFacNode* fn = new FgFacNode (
new Factor (neighs, nodes[i]->params(), nodes[i]->distId()));
if (orderFactorVariables) {
sort (neighs.begin(), neighs.end(), CompVarId());
fn->factor()->reorderAccordingVarIds();
}
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
}
}
}
setIndexes();
}
void
FactorGraph::readFromUaiFormat (const char* fileName)
{
ifstream is (fileName);
std::ifstream is (fileName);
if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl;
cerr << "error: cannot read from file " << fileName << endl;
abort();
}
ignoreLines (is);
string line;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
getline (is, line);
if (line != "MARKOV") {
cerr << "error: the network must be a MARKOV network " << endl;
abort();
}
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nVars;
is >> nVars;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
vector<int> domainSizes (nVars);
for (unsigned i = 0; i < nVars; i++) {
unsigned ds;
is >> ds;
domainSizes[i] = ds;
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
for (unsigned i = 0; i < nVars; i++) {
addVariable (new FgVarNode (i, domainSizes[i]));
}
unsigned nFactors;
is >> nFactors;
for (unsigned i = 0; i < nFactors; i++) {
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nFactorVars;
is >> nFactorVars;
VarNodes neighs;
for (unsigned j = 0; j < nFactorVars; j++) {
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
vector<VarIds> factorVarIds;
vector<Ranges> factorRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
factorVarIds.push_back ({ });
factorRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
FgVarNode* neigh = getFgVarNode (vid);
if (!neigh) {
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
if (vid >= ranges.size()) {
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
cerr << endl;
abort();
}
neighs.push_back (neigh);
}
FgFacNode* fn = new FgFacNode (new Factor (neighs));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
factorVarIds.back().push_back (vid);
factorRanges.back().push_back (ranges[vid]);
}
}
for (unsigned i = 0; i < nFactors; i++) {
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nParams;
is >> nParams;
if (facNodes_[i]->params().size() != nParams) {
cerr << "error: invalid number of parameters for factor " ;
cerr << facNodes_[i]->getLabel() ;
cerr << ", expected: " << facNodes_[i]->params().size();
cerr << ", given: " << nParams << endl;
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::expectedSize (factorRanges[i])) {
cerr << "error: invalid number of parameters for factor nº " << i ;
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
cerr << ", given: " << nrParams << endl;
abort();
}
Params params (nParams);
for (unsigned j = 0; j < nParams; j++) {
double param;
is >> param;
params[j] = param;
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::toLog (params);
}
facNodes_[i]->factor()->setParams (params);
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
}
is.close();
setIndexes();
}
@ -158,87 +111,58 @@ FactorGraph::readFromUaiFormat (const char* fileName)
void
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
ifstream is (fileName);
std::ifstream is (fileName);
if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl;
cerr << "error: cannot read from file " << fileName << endl;
abort();
}
string line;
unsigned nFactors;
while ((is.peek()) == '#') getline (is, line);
is >> nFactors;
if (is.fail()) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
getline (is, line);
if (is.fail() || line.size() > 0) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
for (unsigned i = 0; i < nFactors; i++) {
unsigned nVars;
while ((is.peek()) == '#') getline (is, line);
is >> nVars;
VarIds vids;
for (unsigned j = 0; j < nVars; j++) {
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
while ((is.peek()) == '#') getline (is, line);
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
VarNodes neighs;
unsigned nParams = 1;
for (unsigned j = 0; j < nVars; j++) {
unsigned dsize;
while ((is.peek()) == '#') getline (is, line);
is >> dsize;
FgVarNode* var = getFgVarNode (vids[j]);
if (var == 0) {
var = new FgVarNode (vids[j], dsize);
addVariable (var);
} else {
if (var->nrStates() != dsize) {
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = getVarNode (vids[j]);
if (var != 0 && ranges[j] != var->range()) {
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with different domain sizes" << endl;
cerr << "more factors with a different range" << endl;
}
}
neighs.push_back (var);
nParams *= var->nrStates();
}
Params params (nParams, 0);
// read parameters
ignoreLines (is);
unsigned nNonzeros;
while ((is.peek()) == '#') getline (is, line);
is >> nNonzeros;
Params params (Util::expectedSize (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
double val;
while ((is.peek()) == '#') getline (is, line);
is >> index;
while ((is.peek()) == '#') getline (is, line);
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
reverse (neighs.begin(), neighs.end());
reverse (vids.begin(), vids.end());
if (Globals::logDomain) {
Util::toLog (params);
}
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
}
addFactor (Factor (vids, ranges, params));
}
is.close();
setIndexes();
}
@ -256,17 +180,41 @@ FactorGraph::~FactorGraph (void)
void
FactorGraph::addVariable (FgVarNode* vn)
FactorGraph::addFactor (const Factor& factor)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
FacNode* fn = new FacNode (factor);
addFacNode (fn);
const VarIds& vids = factor.arguments();
for (unsigned i = 0; i < vids.size(); i++) {
bool found = false;
for (unsigned j = 0; j < varNodes_.size(); j++) {
if (varNodes_[j]->varId() == vids[i]) {
addEdge (varNodes_[j], fn);
found = true;
}
}
if (found == false) {
VarNode* vn = new VarNode (vids[i], factor.range (i));
addVarNode (vn);
addEdge (vn, fn);
}
}
}
void
FactorGraph::addFactor (FgFacNode* fn)
FactorGraph::addVarNode (VarNode* vn)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), vn));
}
void
FactorGraph::addFacNode (FacNode* fn)
{
facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1);
@ -275,7 +223,7 @@ FactorGraph::addFactor (FgFacNode* fn)
void
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
FactorGraph::addEdge (VarNode* vn, FacNode* fn)
{
vn->addNeighbor (fn);
fn->addNeighbor (vn);
@ -283,37 +231,6 @@ FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
void
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
{
fn->addNeighbor (vn);
vn->addNeighbor (fn);
}
VarNode*
FactorGraph::getVariableNode (VarId vid) const
{
FgVarNode* vn = getFgVarNode (vid);
assert (vn);
return vn;
}
VarNodes
FactorGraph::getVariableNodes (void) const
{
VarNodes vars;
for (unsigned i = 0; i < varNodes_.size(); i++) {
vars.push_back (varNodes_[i]);
}
return vars;
}
bool
FactorGraph::isTree (void) const
{
@ -322,36 +239,42 @@ FactorGraph::isTree (void) const
void
FactorGraph::setIndexes (void)
DAGraph&
FactorGraph::getStructure (void)
{
assert (fromBayesNet_);
if (structure_.empty()) {
for (unsigned i = 0; i < varNodes_.size(); i++) {
varNodes_[i]->setIndex (i);
structure_.addNode (new DAGraphNode (varNodes_[i]));
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->setIndex (i);
const VarIds& vids = facNodes_[i]->factor().arguments();
for (unsigned j = 1; j < vids.size(); j++) {
structure_.addEdge (vids[j], vids[0]);
}
}
}
return structure_;
}
void
FactorGraph::printGraphicalModel (void) const
FactorGraph::print (void) const
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
cout << "VarId = " << varNodes_[i]->varId() << endl;
cout << "Label = " << varNodes_[i]->label() << endl;
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
cout << "Factors = " ;
cout << "var id = " << varNodes_[i]->varId() << endl;
cout << "label = " << varNodes_[i]->label() << endl;
cout << "range = " << varNodes_[i]->range() << endl;
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
cout << "factors = " ;
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
}
cout << endl << endl;
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->factor()->print();
cout << endl;
facNodes_[i]->factor().print();
}
}
@ -366,31 +289,26 @@ FactorGraph::exportToGraphViz (const char* fileName) const
cerr << "FactorGraph::exportToDotFile()" << endl;
abort();
}
out << "graph \"" << fileName << "\" {" << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i]->hasEvidence()) {
out << '"' << varNodes_[i]->label() << '"' ;
out << " [style=filled, fillcolor=yellow]" << endl;
}
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " [label=\"" << facNodes_[i]->getLabel();
out << "\"" << ", shape=box]" << endl;
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& myVars = facNodes_[i]->neighbors();
const VarNodes& myVars = facNodes_[i]->neighbors();
for (unsigned j = 0; j < myVars.size(); j++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " -- " ;
out << '"' << myVars[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
@ -402,30 +320,26 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToUaiFormat()" << endl;
cerr << "error: cannot open file " << fileName << endl;
abort();
}
out << "MARKOV" << endl;
out << varNodes_.size() << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
out << varNodes_[i]->nrStates() << " " ;
out << varNodes_[i]->range() << " " ;
}
out << endl;
out << facNodes_.size() << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& factorVars = facNodes_[i]->neighbors();
const VarNodes& factorVars = facNodes_[i]->neighbors();
out << factorVars.size();
for (unsigned j = 0; j < factorVars.size(); j++) {
out << " " << factorVars[j]->getIndex();
}
out << endl;
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->params();
Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) {
Util::fromLog (params);
}
@ -435,7 +349,6 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
}
out << endl;
}
out.close();
}
@ -446,23 +359,22 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToLibDaiFormat()" << endl;
cerr << "error: cannot open file " << fileName << endl;
abort();
}
out << facNodes_.size() << endl << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& factorVars = facNodes_[i]->neighbors();
const VarNodes& factorVars = facNodes_[i]->neighbors();
out << factorVars.size() << endl;
for (int j = factorVars.size() - 1; j >= 0; j--) {
out << factorVars[j]->varId() << " " ;
}
out << endl;
for (unsigned j = 0; j < factorVars.size(); j++) {
out << factorVars[j]->nrStates() << " " ;
out << factorVars[j]->range() << " " ;
}
out << endl;
Params params = facNodes_[i]->factor()->params();
Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) {
Util::fromLog (params);
}
@ -477,6 +389,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
void
FactorGraph::ignoreLines (std::ifstream& is) const
{
string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
}
}
bool
FactorGraph::containsCycle (void) const
{
@ -496,13 +419,14 @@ FactorGraph::containsCycle (void) const
bool
FactorGraph::containsCycle (const FgVarNode* v,
const FgFacNode* p,
FactorGraph::containsCycle (
const VarNode* v,
const FacNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
{
visitedVars[v->getIndex()] = true;
const FgFacSet& adjacencies = v->neighbors();
const FacNodes& adjacencies = v->neighbors();
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex();
if (!visitedFactors[w]) {
@ -520,13 +444,14 @@ FactorGraph::containsCycle (const FgVarNode* v,
bool
FactorGraph::containsCycle (const FgFacNode* v,
const FgVarNode* p,
FactorGraph::containsCycle (
const FacNode* v,
const VarNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
{
visitedFactors[v->getIndex()] = true;
const FgVarSet& adjacencies = v->neighbors();
const VarNodes& adjacencies = v->neighbors();
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex();
if (!visitedVars[w]) {

View File

@ -3,136 +3,109 @@
#include <vector>
#include "GraphicalModel.h"
#include "Factor.h"
#include "BayesNet.h"
#include "Horus.h"
using namespace std;
class BayesNet;
class FgFacNode;
class FacNode;
class FgVarNode : public VarNode
class VarNode : public Var
{
public:
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
VarNode (VarId varId, unsigned nrStates)
: Var (varId, nrStates) { }
FgVarNode (const VarNode* v) : VarNode (v) { }
VarNode (const Var* v) : Var (v) { }
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
const FgFacSet& neighbors (void) const { return neighs_; }
const FacNodes& neighbors (void) const { return neighs_; }
private:
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
DISALLOW_COPY_AND_ASSIGN (VarNode);
FgFacSet neighs_;
FacNodes neighs_;
};
class FgFacNode
class FacNode
{
public:
FgFacNode (const FgFacNode* fn)
{
factor_ = new Factor (*fn->factor());
index_ = -1;
}
FacNode (const Factor& f) : factor_(f), index_(-1) { }
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
const Factor& factor (void) const { return factor_; }
Factor* factor() const { return factor_; }
Factor& factor (void) { return factor_; }
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
const FgVarSet& neighbors (void) const { return neighs_; }
const VarNodes& neighbors (void) const { return neighs_; }
int getIndex (void) const
{
assert (index_ != -1);
return index_;
}
int getIndex (void) const { return index_; }
void setIndex (int index)
{
index_ = index;
}
void setIndex (int index) { index_ = index; }
const Params& params (void) const
{
return factor_->params();
}
string getLabel (void)
{
return factor_->getLabel();
}
string getLabel (void) { return factor_.getLabel(); }
private:
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
DISALLOW_COPY_AND_ASSIGN (FacNode);
Factor* factor_;
FgVarSet neighs_;
VarNodes neighs_;
Factor factor_;
int index_;
};
struct CompVarId
{
bool operator() (const VarNode* vn1, const VarNode* vn2) const
bool operator() (const Var* v1, const Var* v2) const
{
return vn1->varId() < vn2->varId();
return v1->varId() < v2->varId();
}
};
class FactorGraph : public GraphicalModel
class FactorGraph
{
public:
FactorGraph (void) { };
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
FactorGraph (const FactorGraph&);
FactorGraph (const BayesNet&);
~FactorGraph (void);
const FgVarSet& getVarNodes (void) const { return varNodes_; }
const VarNodes& varNodes (void) const { return varNodes_; }
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
const FacNodes& facNodes (void) const { return facNodes_; }
FgVarNode* getFgVarNode (VarId vid) const
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
VarNode* getVarNode (VarId vid) const
{
IndexMap::const_iterator it = varMap_.find (vid);
if (it == varMap_.end()) {
return 0;
} else {
return varNodes_[it->second];
}
VarMap::const_iterator it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
void readFromUaiFormat (const char*);
void readFromLibDaiFormat (const char*);
void addVariable (FgVarNode*);
void addFactor (const Factor& factor);
void addFactor (FgFacNode*);
void addVarNode (VarNode*);
void addEdge (FgVarNode*, FgFacNode*);
void addFacNode (FacNode*);
void addEdge (FgFacNode*, FgVarNode*);
VarNode* getVariableNode (unsigned) const;
VarNodes getVariableNodes (void) const;
void addEdge (VarNode*, FacNode*);
bool isTree (void) const;
void setIndexes (void);
DAGraph& getStructure (void);
void printGraphicalModel (void) const;
void print (void) const;
void exportToGraphViz (const char*) const;
@ -145,19 +118,24 @@ class FactorGraph : public GraphicalModel
private:
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
void ignoreLines (std::ifstream&) const;
bool containsCycle (void) const;
bool containsCycle (const FgVarNode*, const FgFacNode*,
bool containsCycle (const VarNode*, const FacNode*,
vector<bool>&, vector<bool>&) const;
bool containsCycle (const FgFacNode*, const FgVarNode*,
bool containsCycle (const FacNode*, const VarNode*,
vector<bool>&, vector<bool>&) const;
FgVarSet varNodes_;
FgFacSet facNodes_;
VarNodes varNodes_;
FacNodes facNodes_;
typedef unordered_map<unsigned, unsigned> IndexMap;
IndexMap varMap_;
DAGraph structure_;
bool fromBayesNet_;
typedef unordered_map<unsigned, VarNode*> VarMap;
VarMap varMap_;
};
#endif // HORUS_FACTORGRAPH_H

View File

@ -455,7 +455,7 @@ FoveSolver::absorveEvidence (
}
pfList.add (newPfs);
}
if (Constants::DEBUG > 1 && obsFormulas.empty() == false) {
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
@ -493,7 +493,7 @@ FoveSolver::runSolver (const Grounds& query)
shatterAgainstQuery (query);
runWeakBayesBall (query);
while (true) {
if (Constants::DEBUG > 1) {
if (Constants::DEBUG >= 2) {
Util::printDashedLine();
pfList_.print();
LiftedOperator::printValidOps (pfList_, query);
@ -502,7 +502,7 @@ FoveSolver::runSolver (const Grounds& query)
if (op == 0) {
break;
}
if (Constants::DEBUG > 1) {
if (Constants::DEBUG >= 2) {
cout << "best operation: " << op->toString() << endl;
}
op->apply();
@ -594,7 +594,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
}
}
if (Constants::DEBUG > 1) {
if (Constants::DEBUG >= 2) {
Util::printHeader ("REQUIRED PARFACTORS");
pfList_.print();
}
@ -605,15 +605,16 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
void
FoveSolver::shatterAgainstQuery (const Grounds& query)
{
return ;
for (unsigned i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
bool found = false;
Parfactors newPfs;
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (query[i])) {
found = true;
std::pair<ConstraintTree*, ConstraintTree*> split =
(*it)->constr()->split (query[i].args(), query[i].arity());
ConstraintTree* commCt = split.first;
@ -629,9 +630,14 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
++ it;
}
}
if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList_.add (newPfs);
}
if (Constants::DEBUG > 1) {
if (Constants::DEBUG >= 2) {
cout << endl;
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;

View File

@ -1,64 +0,0 @@
#ifndef HORUS_GRAPHICALMODEL_H
#define HORUS_GRAPHICALMODEL_H
#include <cassert>
#include <unordered_map>
#include <sstream>
#include "VarNode.h"
#include "Util.h"
#include "Horus.h"
using namespace std;
struct VarInfo
{
VarInfo (string l, const States& sts) : label(l), states(sts) { }
string label;
States states;
};
class GraphicalModel
{
public:
virtual ~GraphicalModel (void) { };
virtual VarNode* getVariableNode (VarId) const = 0;
virtual VarNodes getVariableNodes (void) const = 0;
virtual void printGraphicalModel (void) const = 0;
static void addVariableInformation (
VarId vid, string label, const States& states)
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
static VarInfo getVarInformation (VarId vid)
{
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
static bool variablesHaveInformation (void)
{
return varsInfo_.size() != 0;
}
static void clearVariablesInformation (void)
{
varsInfo_.clear();
}
private:
static unordered_map<VarId,VarInfo> varsInfo_;
};
#endif // HORUS_GRAPHICALMODEL_H

View File

@ -11,20 +11,18 @@
using namespace std;
class VarNode;
class BayesNode;
class FgVarNode;
class FgFacNode;
class Var;
class Factor;
class VarNode;
class FacNode;
typedef vector<double> Params;
typedef unsigned VarId;
typedef vector<VarId> VarIds;
typedef vector<Var*> Vars;
typedef vector<VarNode*> VarNodes;
typedef vector<BayesNode*> BnNodeSet;
typedef vector<FgVarNode*> FgVarSet;
typedef vector<FgFacNode*> FgFacSet;
typedef vector<Factor*> FactorSet;
typedef vector<FacNode*> FacNodes;
typedef vector<Factor*> Factors;
typedef vector<string> States;
typedef vector<unsigned> Ranges;
@ -32,9 +30,8 @@ typedef vector<unsigned> Ranges;
enum InfAlgorithms
{
VE, // variable elimination
BN_BP, // bayesian network belief propagation
FG_BP, // factor graph belief propagation
CBP // counting bp solver
BP, // belief propagation
CBP // counting belief propagation
};
@ -50,7 +47,7 @@ extern InfAlgorithms infAlgorithm;
namespace Constants {
// level of debug information
const unsigned DEBUG = 2;
const unsigned DEBUG = 0;
const int NO_EVIDENCE = -1;

View File

@ -3,137 +3,70 @@
#include <iostream>
#include <sstream>
#include "BayesNet.h"
#include "FactorGraph.h"
#include "VarElimSolver.h"
#include "BnBpSolver.h"
#include "FgBpSolver.h"
#include "BpSolver.h"
#include "CbpSolver.h"
using namespace std;
void processArguments (BayesNet&, int, const char* []);
void processArguments (FactorGraph&, int, const char* []);
void runSolver (Solver*, const VarNodes&);
void runSolver (const FactorGraph&, const VarIds&);
const string USAGE = "usage: \
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
int
main (int argc, const char* argv[])
{
if (!argv[1]) {
if (argc <= 1) {
cerr << "error: no solver specified" << endl;
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
const string& fileName = argv[1];
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
if (extension == "xml") {
BayesNet bn;
bn.readFromBifFormat (argv[1]);
processArguments (bn, argc, argv);
} else if (extension == "uai") {
if (argc <= 2) {
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
string solver (argv[1]);
if (solver == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (solver == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (solver == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "error: unknow solver `" << solver << "'" << endl ;
cerr << USAGE << endl;
exit(0);
}
string fileName (argv[2]);
string extension = fileName.substr (
fileName.find_last_of ('.') + 1);
FactorGraph fg;
fg.readFromUaiFormat (argv[1]);
processArguments (fg, argc, argv);
if (extension == "uai") {
fg.readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") {
FactorGraph fg;
fg.readFromLibDaiFormat (argv[1]);
processArguments (fg, argc, argv);
fg.readFromLibDaiFormat (fileName.c_str());
} else {
cerr << "error: the graphical model must be defined either " ;
cerr << "in a xml, uai or libDAI file" << endl;
cerr << "in a UAI or libDAI file" << endl;
exit (0);
}
processArguments (fg, argc, argv);
return 0;
}
void
processArguments (BayesNet& bn, int argc, const char* argv[])
{
VarNodes queryVars;
for (int i = 2; i < argc; i++) {
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
BayesNode* queryVar = bn.getBayesNode (arg);
if (queryVar) {
queryVars.push_back (queryVar);
} else {
cerr << "error: there isn't a variable labeled of " ;
cerr << "`" << arg << "'" ;
cerr << endl;
exit (0);
}
} else {
size_t pos = arg.find ('=');
const string& label = arg.substr (0, pos);
const string& state = arg.substr (pos + 1);
if (label.empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
exit (0);
}
if (state.empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
exit (0);
}
BayesNode* node = bn.getBayesNode (label);
if (node) {
if (node->isValidState (state)) {
node->setEvidence (state);
} else {
cerr << "error: `" << state << "' " ;
cerr << "is not a valid state for " ;
cerr << "`" << node->label() << "'" ;
cerr << endl;
exit (0);
}
} else {
cerr << "error: there isn't a variable labeled of " ;
cerr << "`" << label << "'" ;
cerr << endl;
exit (0);
}
}
}
Solver* solver = 0;
FactorGraph* fg = 0;
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE:
fg = new FactorGraph (bn);
solver = new VarElimSolver (*fg);
break;
case InfAlgorithms::BN_BP:
solver = new BnBpSolver (bn);
break;
case InfAlgorithms::FG_BP:
fg = new FactorGraph (bn);
solver = new FgBpSolver (*fg);
break;
case InfAlgorithms::CBP:
fg = new FactorGraph (bn);
solver = new CbpSolver (*fg);
break;
default:
assert (false);
}
runSolver (solver, queryVars);
delete fg;
}
void
processArguments (FactorGraph& fg, int argc, const char* argv[])
{
VarNodes queryVars;
for (int i = 2; i < argc; i++) {
VarIds queryIds;
for (int i = 3; i < argc; i++) {
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
if (!Util::isInteger (arg)) {
@ -146,9 +79,9 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
stringstream ss;
ss << arg;
ss >> vid;
VarNode* queryVar = fg.getFgVarNode (vid);
VarNode* queryVar = fg.getVarNode (vid);
if (queryVar) {
queryVars.push_back (queryVar);
queryIds.push_back (vid);
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ;
@ -177,7 +110,7 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
stringstream ss;
ss << arg.substr (0, pos);
ss >> vid;
VarNode* var = fg.getFgVarNode (vid);
VarNode* var = fg.getVarNode (vid);
if (var) {
if (!Util::isInteger (arg.substr (pos + 1))) {
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
@ -206,14 +139,21 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
}
}
}
runSolver (fg, queryIds);
}
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
Solver* solver = 0;
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE:
solver = new VarElimSolver (fg);
break;
case InfAlgorithms::BN_BP:
case InfAlgorithms::FG_BP:
solver = new FgBpSolver (fg);
case InfAlgorithms::BP:
solver = new BpSolver (fg);
break;
case InfAlgorithms::CBP:
solver = new CbpSolver (fg);
@ -221,27 +161,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
default:
assert (false);
}
runSolver (solver, queryVars);
}
void
runSolver (Solver* solver, const VarNodes& queryVars)
{
VarIds vids;
for (unsigned i = 0; i < queryVars.size(); i++) {
vids.push_back (queryVars[i]->varId());
}
if (queryVars.size() == 0) {
solver->runSolver();
if (queryIds.size() == 0) {
solver->printAllPosterioris();
} else if (queryVars.size() == 1) {
solver->runSolver();
solver->printPosterioriOf (vids[0]);
} else {
solver->runSolver();
solver->printJointDistributionOf (vids);
solver->printAnswer (queryIds);
}
delete solver;
}

View File

@ -8,14 +8,13 @@
#include <YapInterface.h>
#include "ParfactorList.h"
#include "BayesNet.h"
#include "FactorGraph.h"
#include "FoveSolver.h"
#include "VarElimSolver.h"
#include "BnBpSolver.h"
#include "FgBpSolver.h"
#include "BpSolver.h"
#include "CbpSolver.h"
#include "ElimGraph.h"
#include "BayesBall.h"
using namespace std;
@ -24,10 +23,35 @@ using namespace std;
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
Params readParams (YAP_Term);
Params readParameters (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term);
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term);
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
vector<unsigned>
readUnsignedList (YAP_Term list)
{
vector<unsigned> vec;
while (list != YAP_TermNil()) {
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
list = YAP_TailOfTerm (list);
}
return vec;
}
int createLiftedNetwork (void)
{
@ -40,20 +64,17 @@ int createLiftedNetwork (void)
}
// LiftedUtils::printSymbolDictionary();
if (Constants::DEBUG > 1) {
if (Constants::DEBUG > 2) {
// Util::printHeader ("INITIAL PARFACTORS");
// for (unsigned i = 0; i < parfactors.size(); i++) {
// parfactors[i]->print();
// cout << endl;
// }
// parfactors[0]->countConvert (LogVar (0));
//parfactors[1]->fullExpand (LogVar (1));
Util::printHeader ("SHATTERED PARFACTORS");
}
ParfactorList* pfList = new ParfactorList (parfactors);
if (Constants::DEBUG > 1) {
if (Constants::DEBUG >= 2) {
Util::printHeader ("SHATTERED PARFACTORS");
pfList->print();
}
@ -117,7 +138,7 @@ Parfactor* readParfactor (YAP_Term pfTerm)
}
// read the parameters
const Params& params = readParams (YAP_ArgOfTerm (4, pfTerm));
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
// read the constraint
Tuples tuples;
@ -195,54 +216,45 @@ void readLiftedEvidence (
int
createGroundNetwork (void)
{
Statistics::incrementPrimaryNetworksCounting();
// cout << "creating network number " ;
// cout << Statistics::getPrimaryNetworksCounting() << endl;
// if (Statistics::getPrimaryNetworksCounting() > 98) {
// Statistics::writeStatisticsToFile ("../../compressing.stats");
// }
BayesNet* bn = new BayesNet();
YAP_Term varList = YAP_ARG1;
vector<VarIds> parents;
while (varList != YAP_TermNil()) {
YAP_Term var = YAP_HeadOfTerm (varList);
VarId vid = (VarId) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
YAP_Term parentL = YAP_ArgOfTerm (4, var);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
parents.push_back (VarIds());
while (parentL != YAP_TermNil()) {
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
parents.back().push_back (parentId);
parentL = YAP_TailOfTerm (parentL);
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
bool fromBayesNet = factorsType == "bayes";
FactorGraph* fg = new FactorGraph (fromBayesNet);
YAP_Term factorList = YAP_ARG2;
while (factorList != YAP_TermNil()) {
YAP_Term factor = YAP_HeadOfTerm (factorList);
// read the var ids
VarIds varIds = readUnsignedList (YAP_ArgOfTerm (1, factor));
// read the ranges
Ranges ranges = readUnsignedList (YAP_ArgOfTerm (2, factor));
// read the parameters
Params params = readParameters (YAP_ArgOfTerm (3, factor));
// read dist id
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (4, factor));
fg->addFactor (Factor (varIds, ranges, params, distId));
factorList = YAP_TailOfTerm (factorList);
}
assert (bn->getBayesNode (vid) == 0);
BayesNode* newNode = new BayesNode (
vid, dsize, evidence, Params(), distId);
bn->addNode (newNode);
varList = YAP_TailOfTerm (varList);
YAP_Term evidenceList = YAP_ARG3;
while (evidenceList != YAP_TermNil()) {
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
assert (fg->getVarNode (vid));
fg->getVarNode (vid)->setEvidence (ev);
evidenceList = YAP_TailOfTerm (evidenceList);
}
const BnNodeSet& nodes = bn->getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
BnNodeSet ps;
for (unsigned j = 0; j < parents[i].size(); j++) {
assert (bn->getBayesNode (parents[i][j]) != 0);
ps.push_back (bn->getBayesNode (parents[i][j]));
}
nodes[i]->setParents (ps);
}
bn->setIndexes();
YAP_Int p = (YAP_Int) (bn);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
YAP_Int p = (YAP_Int) (fg);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
}
Params
readParams (YAP_Term paramL)
readParameters (YAP_Term paramL)
{
Params params;
assert (YAP_IsPairTerm (paramL));
while (paramL != YAP_TermNil()) {
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
paramL = YAP_TailOfTerm (paramL);
@ -319,68 +331,21 @@ runLiftedSolver (void)
int
runGroundSolver (void)
{
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term taskList = YAP_ARG2;
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
vector<VarIds> tasks;
std::set<VarId> vids;
YAP_Term taskList = YAP_ARG2;
while (taskList != YAP_TermNil()) {
VarIds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) {
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
assert (bn->getBayesNode (vid));
queryVars.push_back (vid);
vids.insert (vid);
jointList = YAP_TailOfTerm (jointList);
}
tasks.push_back (queryVars);
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
taskList = YAP_TailOfTerm (taskList);
}
Solver* bpSolver = 0;
GraphicalModel* graphicalModel = 0;
CFactorGraph::checkForIdenticalFactors = false;
if (Globals::infAlgorithm != InfAlgorithms::VE) {
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
VarIds (vids.begin(), vids.end()));
if (Globals::infAlgorithm == InfAlgorithms::BN_BP) {
graphicalModel = mrn;
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
} else if (Globals::infAlgorithm == InfAlgorithms::FG_BP) {
graphicalModel = new FactorGraph (*mrn);
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
delete mrn;
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
graphicalModel = new FactorGraph (*mrn);
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
delete mrn;
}
bpSolver->runSolver();
}
vector<Params> results;
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
if (Globals::infAlgorithm == InfAlgorithms::VE) {
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
VarElimSolver* veSolver = new VarElimSolver (*mrn);
if (tasks[i].size() == 1) {
results.push_back (veSolver->getPosterioriOf (tasks[i][0]));
runVeSolver (fg, tasks, results);
} else {
results.push_back (veSolver->getJointDistributionOf (tasks[i]));
runBpSolver (fg, tasks, results);
}
delete mrn;
delete veSolver;
} else {
if (tasks[i].size() == 1) {
results.push_back (bpSolver->getPosterioriOf (tasks[i][0]));
} else {
results.push_back (bpSolver->getJointDistributionOf (tasks[i]));
}
}
}
delete bpSolver;
delete graphicalModel;
YAP_Term list = YAP_TermNil();
for (int i = results.size() - 1; i >= 0; i--) {
@ -395,12 +360,68 @@ runGroundSolver (void)
}
list = YAP_MkPairTerm (queryBeliefsL, list);
}
return YAP_Unify (list, YAP_ARG3);
}
void runVeSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
FactorGraph* mfg = fg;
if (fg->isFromBayesNetwork()) {
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
}
VarElimSolver solver (*mfg);
results.push_back (solver.solveQuery (tasks[i]));
if (fg->isFromBayesNetwork()) {
delete mfg;
}
}
}
void runBpSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
std::set<VarId> vids;
for (unsigned i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->isFromBayesNetwork()) {
mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(),vids.end()));
}
if (Globals::infAlgorithm == InfAlgorithms::BP) {
solver = new BpSolver (*mfg);
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
CFactorGraph::checkForIdenticalFactors = false;
solver = new CbpSolver (*mfg);
} else {
cerr << "error: unknow solver" << endl;
abort();
}
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
}
if (fg->isFromBayesNetwork()) {
delete mfg;
}
delete solver;
}
int
setParfactorsParams (void)
{
@ -412,7 +433,7 @@ setParfactorsParams (void)
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList);
}
ParfactorList::iterator it = pfList->begin();
@ -427,22 +448,24 @@ setParfactorsParams (void)
int
setBayesNetParams (void)
setFactorsParams (void)
{
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
return TRUE; // TODO
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
unordered_map<unsigned, Params> paramsMap;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList);
}
const BnNodeSet& nodes = bn->getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
assert (Util::contains (paramsMap, nodes[i]->distId()));
nodes[i]->setParams (paramsMap[nodes[i]->distId()]);
const FacNodes& facNodes = fg->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
assert (Util::contains (paramsMap, distId));
facNodes[i]->factor().setParams (paramsMap[distId]);
}
return TRUE;
}
@ -450,24 +473,29 @@ setBayesNetParams (void)
int
setExtraVarsInfo (void)
setVarsInformation (void)
{
GraphicalModel::clearVariablesInformation();
YAP_Term varsInfoL = YAP_ARG2;
while (varsInfoL != YAP_TermNil()) {
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
YAP_Term statesL = YAP_ArgOfTerm (3, head);
States states;
while (statesL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
states.push_back ((char*) YAP_AtomName (atom));
statesL = YAP_TailOfTerm (statesL);
Var::clearVarsInfo();
YAP_Term labelsL = YAP_ARG1;
vector<string> labels;
while (labelsL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
labels.push_back ((char*) YAP_AtomName (atom));
labelsL = YAP_TailOfTerm (labelsL);
}
GraphicalModel::addVariableInformation (vid,
(char*) YAP_AtomName (label), states);
varsInfoL = YAP_TailOfTerm (varsInfoL);
unsigned count = 0;
YAP_Term stateNamesL = YAP_ARG2;
while (stateNamesL != YAP_TermNil()) {
States states;
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
while (namesL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
states.push_back ((char*) YAP_AtomName (atom));
namesL = YAP_TailOfTerm (namesL);
}
Var::addVarInfo (count, labels[count], states);
count ++;
stateNamesL = YAP_TailOfTerm (stateNamesL);
}
return TRUE;
}
@ -482,10 +510,8 @@ setHorusFlag (void)
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
if ( value == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (value == "bn_bp") {
Globals::infAlgorithm = InfAlgorithms::BN_BP;
} else if (value == "fg_bp") {
Globals::infAlgorithm = InfAlgorithms::FG_BP;
} else if (value == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (value == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
@ -559,9 +585,9 @@ setHorusFlag (void)
int
freeBayesNetwork (void)
freeGroundNetwork (void)
{
delete (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
return TRUE;
}
@ -583,14 +609,14 @@ extern "C" void
init_predicates (void)
{
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 4);
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
}

View File

@ -8,7 +8,7 @@
#include <sstream>
#include <iomanip>
#include "VarNode.h"
#include "Var.h"
#include "Util.h"
@ -31,14 +31,14 @@ class StatesIndexer
}
}
StatesIndexer (const VarNodes& vars, bool calcOffsets = true)
StatesIndexer (const Vars& vars, bool calcOffsets = true)
{
size_ = 1;
indices_.resize (vars.size(), 0);
ranges_.reserve (vars.size());
for (unsigned i = 0; i < vars.size(); i++) {
ranges_.push_back (vars[i]->nrStates());
size_ *= vars[i]->nrStates();
ranges_.push_back (vars[i]->range());
size_ *= vars[i]->range();
}
li_ = 0;
if (calcOffsets) {

View File

@ -45,9 +45,8 @@ CWD=$(PWD)
HEADERS = \
$(srcdir)/GraphicalModel.h \
$(srcdir)/BayesNet.h \
$(srcdir)/BayesNode.h \
$(srcdir)/BayesBall.h \
$(srcdir)/ElimGraph.h \
$(srcdir)/FactorGraph.h \
$(srcdir)/Factor.h \
@ -55,11 +54,10 @@ HEADERS = \
$(srcdir)/ConstraintTree.h \
$(srcdir)/Solver.h \
$(srcdir)/VarElimSolver.h \
$(srcdir)/BnBpSolver.h \
$(srcdir)/FgBpSolver.h \
$(srcdir)/BpSolver.h \
$(srcdir)/CbpSolver.h \
$(srcdir)/FoveSolver.h \
$(srcdir)/VarNode.h \
$(srcdir)/Var.h \
$(srcdir)/Indexer.h \
$(srcdir)/Parfactor.h \
$(srcdir)/ProbFormula.h \
@ -68,22 +66,20 @@ HEADERS = \
$(srcdir)/LiftedUtils.h \
$(srcdir)/TinySet.h \
$(srcdir)/Util.h \
$(srcdir)/Horus.h \
$(srcdir)/xmlParser/xmlParser.h
$(srcdir)/Horus.h
CPP_SOURCES = \
$(srcdir)/BayesNet.cpp \
$(srcdir)/BayesNode.cpp \
$(srcdir)/BayesBall.cpp \
$(srcdir)/ElimGraph.cpp \
$(srcdir)/FactorGraph.cpp \
$(srcdir)/Factor.cpp \
$(srcdir)/CFactorGraph.cpp \
$(srcdir)/ConstraintTree.cpp \
$(srcdir)/VarNode.cpp \
$(srcdir)/Var.cpp \
$(srcdir)/Solver.cpp \
$(srcdir)/VarElimSolver.cpp \
$(srcdir)/BnBpSolver.cpp \
$(srcdir)/FgBpSolver.cpp \
$(srcdir)/BpSolver.cpp \
$(srcdir)/CbpSolver.cpp \
$(srcdir)/FoveSolver.cpp \
$(srcdir)/Parfactor.cpp \
@ -93,22 +89,20 @@ CPP_SOURCES = \
$(srcdir)/LiftedUtils.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/HorusYap.cpp \
$(srcdir)/HorusCli.cpp \
$(srcdir)/xmlParser/xmlParser.cpp
$(srcdir)/HorusCli.cpp
OBJS = \
BayesNet.o \
BayesNode.o \
BayesBall.o \
ElimGraph.o \
FactorGraph.o \
Factor.o \
CFactorGraph.o \
ConstraintTree.o \
VarNode.o \
Var.o \
Solver.o \
VarElimSolver.o \
BnBpSolver.o \
FgBpSolver.o \
BpSolver.o \
CbpSolver.o \
FoveSolver.o \
Parfactor.o \
@ -121,17 +115,16 @@ OBJS = \
HCLI_OBJS = \
BayesNet.o \
BayesNode.o \
BayesBall.o \
ElimGraph.o \
FactorGraph.o \
Factor.o \
CFactorGraph.o \
ConstraintTree.o \
VarNode.o \
Var.o \
Solver.o \
VarElimSolver.o \
BnBpSolver.o \
FgBpSolver.o \
BpSolver.o \
CbpSolver.o \
FoveSolver.o \
Parfactor.o \
@ -140,7 +133,6 @@ HCLI_OBJS = \
ParfactorList.o \
LiftedUtils.o \
Util.o \
xmlParser/xmlParser.o \
HorusCli.o
SOBJS=horus.@SO@
@ -153,10 +145,6 @@ all: $(SOBJS) hcli
$(CXX) -c $(CXXFLAGS) $< -o $@
xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
$(CXX) -c $(CXXFLAGS) $< -o $@
@DO_SECOND_LD@horus.@SO@: $(OBJS)
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
@ -170,7 +158,7 @@ install: all
clean:
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli xmlParser/*.o
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli
erase_dots:

View File

@ -528,11 +528,13 @@ Parfactor::print (bool printParams) const
cout << args_[i];
}
cout << endl;
if (args_[0].group() != Util::maxUnsigned()) {
vector<string> groups;
for (unsigned i = 0; i < args_.size(); i++) {
groups.push_back (string ("g") + Util::toString (args_[i].group()));
}
cout << "Groups: " << groups << endl;
}
cout << "LogVars: " << constr_->logVarSet() << endl;
cout << "Ranges: " << ranges_ << endl;
if (printParams == false) {
@ -570,6 +572,7 @@ Parfactor::print (bool printParams) const
cout << " = " << params_[i] << endl;
}
}
cout << endl;
}

View File

@ -116,7 +116,6 @@ ParfactorList::print (void) const
list<Parfactor*>::const_iterator it;
for (it = pfList_.begin(); it != pfList_.end(); ++it) {
(*it)->print();
cout << endl;
}
}
@ -219,7 +218,7 @@ ParfactorList::shatter (
ProbFormula& f1 = g1->argument (fIdx1);
ProbFormula& f2 = g2->argument (fIdx2);
// cout << endl;
// Util::printDashLine();
// Util::printDashedLine();
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
// g1->print();
// cout << "-> WITH" << endl;
@ -228,7 +227,7 @@ ParfactorList::shatter (
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
// cout << "-> ON: " << f2 << "|" ;
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
// Util::printDashLine();
// Util::printDashedLine();
if (f1.isAtom()) {
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
f1.setGroup (group);
@ -265,14 +264,15 @@ ParfactorList::shatter (
assert (commCt1->tupleSet (f1.arity()) ==
commCt2->tupleSet (f2.arity()));
// unsigned static count = 0; count ++;
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
// ct1->exportToGraphViz (ss1.str().c_str(), true);
// ct2->exportToGraphViz (ss2.str().c_str(), true);
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
// commCt2->exportToGraphViz (ss5.str().c_str(), true);

View File

@ -2,52 +2,36 @@
#include "Util.h"
void
Solver::printAnswer (const VarIds& vids)
{
Vars unobservedVars;
VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) {
VarNode* vn = fg.getVarNode (vids[i]);
if (vn->hasEvidence() == false) {
unobservedVars.push_back (vn);
unobservedVids.push_back (vids[i]);
}
}
Params res = solveQuery (unobservedVids);
vector<string> stateLines = Util::getStateLines (unobservedVars);
for (unsigned i = 0; i < res.size(); i++) {
cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl;
}
cout << endl;
}
void
Solver::printAllPosterioris (void)
{
const VarNodes& vars = gm_->getVariableNodes();
const VarNodes& vars = fg.varNodes();
for (unsigned i = 0; i < vars.size(); i++) {
printPosterioriOf (vars[i]->varId());
printAnswer ({vars[i]->varId()});
}
}
void
Solver::printPosterioriOf (VarId vid)
{
VarNode* var = gm_->getVariableNode (vid);
const Params& posterioriDist = getPosterioriOf (vid);
const States& states = var->states();
for (unsigned i = 0; i < states.size(); i++) {
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
cout << endl;
}
cout << endl;
}
void
Solver::printJointDistributionOf (const VarIds& vids)
{
VarNodes vars;
VarIds vidsWithoutEvidence;
for (unsigned i = 0; i < vids.size(); i++) {
VarNode* var = gm_->getVariableNode (vids[i]);
if (var->hasEvidence() == false) {
vars.push_back (var);
vidsWithoutEvidence.push_back (vids[i]);
}
}
const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence);
vector<string> jointStrings = Util::getJointStateStrings (vars);
for (unsigned i = 0; i < jointDist.size(); i++) {
cout << "P(" << jointStrings[i] << ") = " ;
cout << setprecision (Constants::PRECISION) << jointDist[i];
cout << endl;
}
cout << endl;
}

View File

@ -3,32 +3,27 @@
#include <iomanip>
#include "GraphicalModel.h"
#include "VarNode.h"
#include "Var.h"
#include "FactorGraph.h"
using namespace std;
class Solver
{
public:
Solver (const GraphicalModel* gm) : gm_(gm) { }
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
virtual ~Solver() { } // ensure that subclass destructor is called
virtual void runSolver (void) = 0;
virtual Params solveQuery (VarIds queryVids) = 0;
virtual Params getPosterioriOf (VarId) = 0;
virtual Params getJointDistributionOf (const VarIds&) = 0;
void printAnswer (const VarIds& vids);
void printAllPosterioris (void);
void printPosterioriOf (VarId vid);
void printJointDistributionOf (const VarIds& vids);
private:
const GraphicalModel* gm_;
protected:
const FactorGraph& fg;
};
#endif // HORUS_SOLVER_H

View File

@ -1,5 +1,4 @@
TODO
- add way to calculate combinations and factorials with large numbers
- add a way to calculate combinations and factorials with large numbers
- refactor sumOut in parfactor -> is really ugly code
- Indexer: start receiving ranges as constant reference

View File

@ -5,7 +5,6 @@
#include "Util.h"
#include "Indexer.h"
#include "GraphicalModel.h"
namespace Globals {
@ -30,7 +29,6 @@ unsigned maxIter = 1000;
}
unordered_map<VarId, VarInfo> GraphicalModel::varsInfo_;
vector<NetInfo> Statistics::netInfo_;
vector<CompressInfo> Statistics::compressInfo_;
@ -139,7 +137,7 @@ parametersToString (const Params& v, unsigned precision)
vector<string>
getJointStateStrings (const VarNodes& vars)
getStateLines (const Vars& vars)
{
StatesIndexer idx (vars);
vector<string> jointStrings;
@ -157,7 +155,8 @@ getJointStateStrings (const VarNodes& vars)
void printHeader (string header, std::ostream& os)
void
printHeader (string header, std::ostream& os)
{
printAsteriskLine (os);
os << header << endl;
@ -166,7 +165,8 @@ void printHeader (string header, std::ostream& os)
void printSubHeader (string header, std::ostream& os)
void
printSubHeader (string header, std::ostream& os)
{
printDashedLine (os);
os << header << endl;
@ -175,7 +175,8 @@ void printSubHeader (string header, std::ostream& os)
void printAsteriskLine (std::ostream& os)
void
printAsteriskLine (std::ostream& os)
{
os << "********************************" ;
os << "********************************" ;
@ -184,7 +185,8 @@ void printAsteriskLine (std::ostream& os)
void printDashedLine (std::ostream& os)
void
printDashedLine (std::ostream& os)
{
os << "--------------------------------" ;
os << "--------------------------------" ;
@ -368,12 +370,12 @@ Statistics::printStatistics (void)
void
Statistics::writeStatisticsToFile (const char* fileName)
Statistics::writeStatistics (const char* fileName)
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Statistics::writeStatisticsToFile()" << endl;
cerr << "Statistics::writeStats()" << endl;
abort();
}
out << getStatisticString();
@ -384,13 +386,13 @@ Statistics::writeStatisticsToFile (const char* fileName)
void
Statistics::updateCompressingStatistics (
unsigned nGroundVars,
unsigned nGroundFactors,
unsigned nClusterVars,
unsigned nClusterFactors,
unsigned nWithoutNeighs) {
compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors,
nClusterVars, nClusterFactors, nWithoutNeighs));
unsigned nrGroundVars,
unsigned nrGroundFactors,
unsigned nrClusterVars,
unsigned nrClusterFactors,
unsigned nrNeighborless) {
compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
nrClusterVars, nrClusterFactors, nrNeighborless));
}
@ -402,8 +404,7 @@ Statistics::getStatisticString (void)
ss1 << "running mode: " ;
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
}
ss1 << "message schedule: " ;
@ -463,17 +464,17 @@ Statistics::getStatisticString (void)
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
ss3 << "Vars Vars Factors Factors Vars" << endl;
for (unsigned i = 0; i < compressInfo_.size(); i++) {
ss3 << setw (9) << compressInfo_[i].nGroundVars;
ss3 << setw (10) << compressInfo_[i].nClusterVars;
ss3 << setw (10) << compressInfo_[i].nGroundFactors;
ss3 << setw (10) << compressInfo_[i].nClusterFactors;
ss3 << setw (10) << compressInfo_[i].nWithoutNeighs;
ss3 << setw (9) << compressInfo_[i].nrGroundVars;
ss3 << setw (10) << compressInfo_[i].nrClusterVars;
ss3 << setw (10) << compressInfo_[i].nrGroundFactors;
ss3 << setw (10) << compressInfo_[i].nrClusterFactors;
ss3 << setw (10) << compressInfo_[i].nrNeighborless;
ss3 << endl;
c1 += compressInfo_[i].nGroundVars - compressInfo_[i].nWithoutNeighs;
c2 += compressInfo_[i].nClusterVars;
c3 += compressInfo_[i].nGroundFactors - compressInfo_[i].nWithoutNeighs;
c4 += compressInfo_[i].nClusterFactors;
if (compressInfo_[i].nWithoutNeighs != 0) {
c1 += compressInfo_[i].nrGroundVars - compressInfo_[i].nrNeighborless;
c2 += compressInfo_[i].nrClusterVars;
c3 += compressInfo_[i].nrGroundFactors - compressInfo_[i].nrNeighborless;
c4 += compressInfo_[i].nrClusterFactors;
if (compressInfo_[i].nrNeighborless != 0) {
c2 --;
c4 --;
}

View File

@ -5,6 +5,7 @@
#include <cassert>
#include <limits>
#include <algorithm>
#include <vector>
#include <set>
#include <queue>
@ -22,6 +23,8 @@ namespace Util {
template <typename T> void addToVector (vector<T>&, const vector<T>&);
template <typename T> void addToSet (set<T>&, const vector<T>&);
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
template <typename T> bool contains (const vector<T>&, const T&);
@ -59,7 +62,7 @@ bool isInteger (const string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION);
vector<string> getJointStateStrings (const VarNodes&);
vector<string> getStateLines (const Vars&);
void printHeader (string, std::ostream& os = std::cout);
@ -83,6 +86,14 @@ Util::addToVector (vector<T>& v, const vector<T>& elements)
template <typename T> void
Util::addToSet (set<T>& s, const vector<T>& elements)
{
s.insert (elements.begin(), elements.end());
}
template <typename T> void
Util::addToQueue (queue<T>& q, const vector<T>& elements)
{
@ -110,8 +121,7 @@ Util::contains (const set<T>& s, const T& e)
template <typename K, typename V> bool
Util::contains (
const unordered_map<K, V>& m, const K& k)
Util::contains (const unordered_map<K, V>& m, const K& k)
{
return m.find (k) != m.end();
}
@ -322,17 +332,17 @@ struct CompressInfo
{
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
{
nGroundVars = a;
nGroundFactors = b;
nClusterVars = c;
nClusterFactors = d;
nWithoutNeighs = e;
nrGroundVars = a;
nrGroundFactors = b;
nrClusterVars = c;
nrClusterFactors = d;
nrNeighborless = e;
}
unsigned nGroundVars;
unsigned nGroundFactors;
unsigned nClusterVars;
unsigned nClusterFactors;
unsigned nWithoutNeighs;
unsigned nrGroundVars;
unsigned nrGroundFactors;
unsigned nrClusterVars;
unsigned nrClusterFactors;
unsigned nrNeighborless;
};
@ -349,7 +359,7 @@ class Statistics
static void printStatistics (void);
static void writeStatisticsToFile (const char*);
static void writeStatistics (const char*);
static void updateCompressingStatistics (
unsigned, unsigned, unsigned, unsigned, unsigned);

View File

@ -0,0 +1,102 @@
#include <algorithm>
#include <sstream>
#include "Var.h"
using namespace std;
unordered_map<VarId, VarInfo> Var::varsInfo_;
Var::Var (const Var* v)
{
varId_ = v->varId();
range_ = v->range();
evidence_ = v->getEvidence();
index_ = std::numeric_limits<unsigned>::max();
}
Var::Var (VarId varId, unsigned range, int evidence)
{
assert (range != 0);
assert (evidence < (int) range);
varId_ = varId;
range_ = range;
evidence_ = evidence;
index_ = std::numeric_limits<unsigned>::max();
}
bool
Var::isValidState (int stateIndex)
{
return stateIndex >= 0 && stateIndex < (int) range_;
}
bool
Var::isValidState (const string& stateName)
{
States states = Var::getVarInfo (varId_).states;
return Util::contains (states, stateName);
}
void
Var::setEvidence (int ev)
{
assert (ev < (int) range_);
evidence_ = ev;
}
void
Var::setEvidence (const string& ev)
{
States states = Var::getVarInfo (varId_).states;
for (unsigned i = 0; i < states.size(); i++) {
if (states[i] == ev) {
evidence_ = i;
return;
}
}
assert (false);
}
string
Var::label (void) const
{
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).label;
}
stringstream ss;
ss << "x" << varId_;
return ss.str();
}
States
Var::states (void) const
{
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).states;
}
States states;
for (unsigned i = 0; i < range_; i++) {
stringstream ss;
ss << i ;
states.push_back (ss.str());
}
return states;
}

View File

@ -0,0 +1,108 @@
#ifndef HORUS_Var_H
#define HORUS_Var_H
#include <cassert>
#include <iostream>
#include "Util.h"
#include "Horus.h"
using namespace std;
struct VarInfo
{
VarInfo (string l, const States& sts) : label(l), states(sts) { }
string label;
States states;
};
class Var
{
public:
Var (const Var*);
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
virtual ~Var (void) { };
unsigned varId (void) const { return varId_; }
unsigned range (void) const { return range_; }
int getEvidence (void) const { return evidence_; }
unsigned getIndex (void) const { return index_; }
void setIndex (unsigned idx) { index_ = idx; }
operator unsigned () const { return index_; }
bool hasEvidence (void) const
{
return evidence_ != Constants::NO_EVIDENCE;
}
bool operator== (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ == var.varId();
}
bool operator!= (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ != var.varId();
}
bool isValidState (int);
bool isValidState (const string&);
void setEvidence (int);
void setEvidence (const string&);
string label (void) const;
States states (void) const;
static void addVarInfo (
VarId vid, string label, const States& states)
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
static VarInfo getVarInfo (VarId vid)
{
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
static bool varsHaveInfo (void)
{
return varsInfo_.size() != 0;
}
static void clearVarsInfo (void)
{
varsInfo_.clear();
}
private:
VarId varId_;
unsigned range_;
int evidence_;
unsigned index_;
static unordered_map<VarId, VarInfo> varsInfo_;
};
#endif // BP_Var_H

View File

@ -6,61 +6,27 @@
#include "Util.h"
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
{
bayesNet_ = &bn;
factorGraph_ = new FactorGraph (bn);
}
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
{
bayesNet_ = 0;
factorGraph_ = &fg;
}
VarElimSolver::~VarElimSolver (void)
{
if (bayesNet_) {
delete factorGraph_;
}
delete factorList_.back();
}
Params
VarElimSolver::getPosterioriOf (VarId vid)
{
assert (factorGraph_->getFgVarNode (vid));
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
if (vn->hasEvidence()) {
Params params (vn->nrStates(), 0.0);
params[vn->getEvidence()] = 1.0;
return params;
}
return getJointDistributionOf (VarIds() = {vid});
}
Params
VarElimSolver::getJointDistributionOf (const VarIds& vids)
VarElimSolver::solveQuery (VarIds queryVids)
{
factorList_.clear();
varFactors_.clear();
elimOrder_.clear();
createFactorList();
introduceEvidence();
chooseEliminationOrder (vids);
processFactorList (vids);
absorveEvidence();
findEliminationOrder (queryVids);
processFactorList (queryVids);
Params params = factorList_.back()->params();
if (Globals::logDomain) {
Util::fromLog (params);
}
delete factorList_.back();
return params;
}
@ -69,11 +35,11 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
void
VarElimSolver::createFactorList (void)
{
const FgFacSet& factorNodes = factorGraph_->getFactorNodes();
factorList_.reserve (factorNodes.size() * 2);
for (unsigned i = 0; i < factorNodes.size(); i++) {
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
const FgVarSet& neighs = factorNodes[i]->neighbors();
const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2);
for (unsigned i = 0; i < facNodes.size(); i++) {
factorList_.push_back (new Factor (facNodes[i]->factor()));
const VarNodes& neighs = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
unordered_map<VarId,vector<unsigned> >::iterator it
= varFactors_.find (neighs[j]->varId());
@ -89,9 +55,9 @@ VarElimSolver::createFactorList (void)
void
VarElimSolver::introduceEvidence (void)
VarElimSolver::absorveEvidence (void)
{
const FgVarSet& varNodes = factorGraph_->getVarNodes();
const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) {
const vector<unsigned>& idxs =
@ -112,21 +78,9 @@ VarElimSolver::introduceEvidence (void)
void
VarElimSolver::chooseEliminationOrder (const VarIds& vids)
VarElimSolver::findEliminationOrder (const VarIds& vids)
{
if (bayesNet_) {
ElimGraph graph (*bayesNet_);
elimOrder_ = graph.getEliminatingOrder (vids);
} else {
const FgVarSet& varNodes = factorGraph_->getVarNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
VarId vid = varNodes[i]->varId();
if (Util::contains (vids, vid) == false &&
varNodes[i]->hasEvidence() == false) {
elimOrder_.push_back (vid);
}
}
}
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
}
@ -149,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) {
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
unobservedVids.push_back (vids[i]);
}
}
@ -165,7 +119,6 @@ void
VarElimSolver::eliminate (VarId elimVar)
{
Factor* result = 0;
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
for (unsigned i = 0; i < idxs.size(); i++) {
unsigned idx = idxs[i];
@ -180,7 +133,7 @@ VarElimSolver::eliminate (VarId elimVar)
}
}
if (result != 0 && result->nrArguments() != 1) {
result->sumOut (vn->varId());
result->sumOut (elimVar);
factorList_.push_back (result);
const VarIds& resultVarIds = result->arguments();
for (unsigned i = 0; i < resultVarIds.size(); i++) {
@ -199,7 +152,6 @@ VarElimSolver::printActiveFactors (void)
for (unsigned i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) {
factorList_[i]->print();
cout << endl;
}
}
}

View File

@ -5,7 +5,6 @@
#include "Solver.h"
#include "FactorGraph.h"
#include "BayesNet.h"
#include "Horus.h"
@ -15,24 +14,18 @@ using namespace std;
class VarElimSolver : public Solver
{
public:
VarElimSolver (const BayesNet&);
VarElimSolver (const FactorGraph&);
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
~VarElimSolver (void);
void runSolver (void) { }
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
Params solveQuery (VarIds);
private:
void createFactorList (void);
void introduceEvidence (void);
void absorveEvidence (void);
void chooseEliminationOrder (const VarIds&);
void findEliminationOrder (const VarIds&);
void processFactorList (const VarIds&);
@ -40,8 +33,6 @@ class VarElimSolver : public Solver
void printActiveFactors (void);
const BayesNet* bayesNet_;
const FactorGraph* factorGraph_;
vector<Factor*> factorList_;
VarIds elimOrder_;
unordered_map<VarId, vector<unsigned>> varFactors_;

View File

@ -1,100 +0,0 @@
#include <algorithm>
#include <sstream>
#include "VarNode.h"
#include "GraphicalModel.h"
using namespace std;
VarNode::VarNode (const VarNode* v)
{
varId_ = v->varId();
nrStates_ = v->nrStates();
evidence_ = v->getEvidence();
index_ = std::numeric_limits<unsigned>::max();
}
VarNode::VarNode (VarId varId, unsigned nrStates, int evidence)
{
assert (nrStates != 0);
assert (evidence < (int) nrStates);
varId_ = varId;
nrStates_ = nrStates;
evidence_ = evidence;
index_ = std::numeric_limits<unsigned>::max();
}
bool
VarNode::isValidState (int stateIndex)
{
return stateIndex >= 0 && stateIndex < (int) nrStates_;
}
bool
VarNode::isValidState (const string& stateName)
{
States states = GraphicalModel::getVarInformation (varId_).states;
return Util::contains (states, stateName);
}
void
VarNode::setEvidence (int ev)
{
assert (ev < (int) nrStates_);
evidence_ = ev;
}
void
VarNode::setEvidence (const string& ev)
{
States states = GraphicalModel::getVarInformation (varId_).states;
for (unsigned i = 0; i < states.size(); i++) {
if (states[i] == ev) {
evidence_ = i;
return;
}
}
assert (false);
}
string
VarNode::label (void) const
{
if (GraphicalModel::variablesHaveInformation()) {
return GraphicalModel::getVarInformation (varId_).label;
}
stringstream ss;
ss << "x" << varId_;
return ss.str();
}
States
VarNode::states (void) const
{
if (GraphicalModel::variablesHaveInformation()) {
return GraphicalModel::getVarInformation (varId_).states;
}
States states;
for (unsigned i = 0; i < nrStates_; i++) {
stringstream ss;
ss << i ;
states.push_back (ss.str());
}
return states;
}

View File

@ -1,73 +0,0 @@
#ifndef HORUS_VARNODE_H
#define HORUS_VARNODE_H
#include <cassert>
#include <iostream>
#include "Horus.h"
using namespace std;
class VarNode
{
public:
VarNode (const VarNode*);
VarNode (VarId, unsigned, int = Constants::NO_EVIDENCE);
virtual ~VarNode (void) { };
unsigned varId (void) const { return varId_; }
unsigned nrStates (void) const { return nrStates_; }
int getEvidence (void) const { return evidence_; }
unsigned getIndex (void) const { return index_; }
void setIndex (unsigned idx) { index_ = idx; }
operator unsigned () const { return index_; }
bool hasEvidence (void) const
{
return evidence_ != Constants::NO_EVIDENCE;
}
bool operator== (const VarNode& var) const
{
cout << "equal operator called" << endl;
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
return varId_ == var.varId();
}
bool operator!= (const VarNode& var) const
{
cout << "diff operator called" << endl;
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
return varId_ != var.varId();
}
bool isValidState (int);
bool isValidState (const string&);
void setEvidence (int);
void setEvidence (const string&);
string label (void) const;
States states (void) const;
private:
VarId varId_;
unsigned nrStates_;
int evidence_;
unsigned index_;
};
#endif // BP_VARNODE_H

View File

@ -0,0 +1,35 @@
if [ $1 ] && [ $1 == "clear" ]; then
rm *~
rm -f school/*.log school/*~
rm -f city/*.log city/*~
rm -f workshop_attrs/*.log workshop_attrs/*~
fi
function run_solver
{
constraint=$1
solver_flag=true
if [ -n "$2" ]; then
if [ $SOLVER = hve ]; then
extra_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
elif [ $SOLVER = bp ]; then
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
elif [ $SOLVER = cbp ]; then
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
else
echo "unknow flag $2"
fi
fi
/usr/bin/time -o $LOG_FILE -a -f "real:%E\tuser:%U\tsys:%S" \
$YAP << EOF >> $LOG_FILE 2>> ignore.$LOG_FILE
[$NETWORK].
[$constraint].
clpbn_horus:set_solver($SOLVER).
clpbn_horus:set_horus_flag(use_logarithms, true).
$solver_flag.
$QUERY.
open("$LOG_FILE", 'append', S), format(S, '$constraint: ~15+ ', []), close(S).
EOF
}

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_bnbp
YAP=~/bin/town_bnbp
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=bnbp.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_horus:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed

View File

@ -0,0 +1,17 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="bp"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -1,56 +1,17 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_cbp
YAP=~/bin/town_cbp
source city.sh
source ../benchs.sh
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=cbp.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
SOLVER="cbp"
YAP=~/bin/$SHORTNAME-$SOLVER
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_horus:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_3 $1 town_3 $3 $4 $5
return
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
run_solver town_2500000 $1 town_2500000 $3 $4 $5
run_solver town_5000000 $1 town_5000000 $3 $4 $5
run_solver town_7500000 $1 town_7500000 $3 $4 $5
run_solver town_10000000 $1 town_10000000 $3 $4 $5
}
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,25 @@
#!/bin/bash
NETWORK="'../../examples/city'"
SHORTNAME="city"
QUERY="is_joe_guilty(X)"
function run_all_graphs
{
cp ~/bin/yap $YAP
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
echo "results for solver $1" >> $LOG_FILE
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
run_solver city_5 $2
#run_solver city_1000 $2
#run_solver city_5000 $2
#run_solver city_10000 $2
#run_solver city_50000 $2
#run_solver city_100000 $2
#run_solver city_500000 $2
#run_solver city_1000000 $2
}

View File

@ -0,0 +1,37 @@
#!/home/tiago/bin/yap -L --
:- initialization(main).
main :-
unix(argv([H])),
generate_town(H).
generate_town(N) :-
atomic_concat(['city_', N, '.yap'], FileName),
open(FileName, 'write', S),
atom_number(N, N2),
generate_people(S, N2, 4),
write(S, '\n'),
generate_query(S, N2, 4),
write(S, '\n'),
close(S).
generate_people(S, N, Counting) :-
Counting > N, !.
generate_people(S, N, Counting) :-
format(S, 'people(p~w, nyc).~n', [Counting]),
Counting1 is Counting + 1,
generate_people(S, N, Counting1).
generate_query(S, N, Counting) :-
Counting > N, !.
generate_query(S, N, Counting) :- !,
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
Counting1 is Counting + 1,
generate_query(S, N, Counting1).

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_fgbp
YAP=~/bin/town_fgbp
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=fb_bp.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_horus:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
#run_solver town_5000 $1 town_5000 $3 $4 $5
#run_solver town_10000 $1 town_10000 $3 $4 $5
#run_solver town_50000 $1 town_50000 $3 $4 $5
#run_solver town_100000 $1 town_100000 $3 $4 $5
#run_solver town_500000 $1 town_500000 $3 $4 $5
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed

View File

@ -0,0 +1,17 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="fove"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILEE
run_all_graphs "fove "

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_gibbs
YAP=~/bin/town_gibbs
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=gibbs.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs gibbs "gibbs "

View File

@ -0,0 +1,17 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="hve"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_jt
YAP=~/bin/town_jt
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=jt.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs jt "jt "

View File

@ -1,65 +0,0 @@
conservative_city(City, Cons) :-
cons_table(City, ConsDist),
{ Cons = conservative_city(City) with p([y,n], ConsDist) }.
gender(X, Gender) :-
gender_table(X, GenderDist),
{ Gender = gender(X) with p([m,f], GenderDist) }.
hair_color(X, Color) :-
lives(X, City),
conservative_city(City, Cons),
hair_color_table(X,ColorTable),
{ Color = hair_color(X) with
p([t,f], ColorTable,[Cons]) }.
car_color(X, Color) :-
hair_color(X, HColor),
car_color_table(X,CColorTable),
{ Color = car_color(X) with
p([t,f], CColorTable,[HColor]) }.
height(X, Height) :-
gender(X, Gender),
height_table(X,HeightTable),
{ Height = height(X) with
p([t,f], HeightTable,[Gender]) }.
shoe_size(X, Shoesize) :-
height(X, Height),
shoe_size_table(X,ShoesizeTable),
{ Shoesize = shoe_size(X) with
p([t,f], ShoesizeTable,[Height]) }.
guilty(X, Guilt) :-
guilty_table(X, GuiltDist),
{ Guilt = guilty(X) with p([y,n], GuiltDist) }.
descn(X, Descn) :-
car_color(X, Car),
hair_color(X, Hair),
height(X, Height),
guilty(X, Guilt),
descn_table(X, DescTable),
{ Descn = descn(X) with
p([t,f], DescTable,[Car,Hair,Height,Guilt]) }.
witness(City, Witness) :-
descn(joe, DescnJ),
descn(p2, Descn2),
wit_table(WitTable),
{ Witness = witness(City) with
p([t,f], WitTable,[DescnJ, Descn2]) }.
:- ensure_loaded(tables).

View File

@ -1,46 +0,0 @@
cons_table(amsterdam, [0.2, 0.8]) :- !.
cons_table(_, [0.8, 0.2]).
gender_table(_, [0.55, 0.44]).
hair_color_table(_,
/* conservative_city */
/* y n */
[ 0.05, 0.1,
0.95, 0.9 ]).
car_color_table(_,
/* t f */
[ 0.9, 0.2,
0.1, 0.8 ]).
height_table(_,
/* m f */
[ 0.6, 0.4,
0.4, 0.6 ]).
shoe_size_table(_,
/* t f */
[ 0.9, 0.1,
0.1, 0.9 ]).
guilty_table(_, [0.23, 0.77]).
descn_table(_,
/* color, hair, height, guilt */
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
wit_table([0.2, 0.45, 0.24, 0.34,
0.8, 0.55, 0.76, 0.66]).

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,29 +0,0 @@
:- source.
:- style_check(all).
:- yap_flag(unknown,error).
:- yap_flag(write_strings,on).
:- use_module(library(clpbn)).
:- set_clpbn_flag(solver, bp).
:- [-schema].
lives(_joe, nyc).
run_query(Guilty) :-
guilty(joe, Guilty),
witness(nyc, t),
runall(X, ev(X)).
runall(G, Wrapper) :-
findall(G, Wrapper, L),
execute_all(L).
execute_all([]).
execute_all(G.L) :-
call(G),
execute_all(L).
ev(descn(p2, t)).
ev(descn(p3, t)).

File diff suppressed because it is too large Load Diff

View File

@ -1,59 +0,0 @@
#!/home/tiago/bin/yap -L --
/*
Steps:
1. generate N facts lives(I, nyc), 0 <= I < N.
2. generate evidence on descn for N people, *** except for 1 ***
3. Run query ?- guilty(joe, Guilty), witness(joe, t), descn(2,t), descn(3, f), descn(4, f) ...
*/
:- initialization(main).
main :-
unix(argv([H])),
generate_town(H).
generate_town(N) :-
atomic_concat(['town_', N, '.yap'], FileName),
open(FileName, 'write', S),
write(S, ':- source.\n'),
write(S, ':- style_check(all).\n'),
write(S, ':- yap_flag(unknown,error).\n'),
write(S, ':- yap_flag(write_strings,on).\n'),
write(S, ':- use_module(library(clpbn)).\n'),
write(S, ':- set_clpbn_flag(solver, bp).\n'),
write(S, ':- [-schema].\n\n'),
write(S, 'lives(_joe, nyc).\n'),
atom_number(N, N2),
generate_people(S, N2, 2),
write(S, '\nrun_query(Guilty) :- \n'),
write(S, '\tguilty(joe, Guilty),\n'),
write(S, '\twitness(nyc, t),\n'),
write(S, '\trunall(X, ev(X)).\n\n\n'),
write(S, 'runall(G, Wrapper) :-\n'),
write(S, '\tfindall(G, Wrapper, L),\n'),
write(S, '\texecute_all(L).\n\n\n'),
write(S, 'execute_all([]).\n'),
write(S, 'execute_all(G.L) :-\n'),
write(S, '\tcall(G),\n'),
write(S, '\texecute_all(L).\n\n\n'),
generate_query(S, N2, 2),
close(S).
generate_people(_, N, Counting1) :- !.
generate_people(S, N, Counting) :-
format(S, 'lives(p~w, nyc).~n', [Counting]),
Counting1 is Counting + 1,
generate_people(S, N, Counting1).
generate_query(S, N, Counting) :-
Counting > N, !.
generate_query(S, N, Counting) :- !,
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
Counting1 is Counting + 1,
generate_query(S, N, Counting1).

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_ve
YAP=~/bin/town_ve
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=ve.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
#run_solver town_5000 $1 town_5000 $3 $4 $5
#run_solver town_10000 $1 town_10000 $3 $4 $5
#run_solver town_50000 $1 town_50000 $3 $4 $5
#run_solver town_100000 $1 town_100000 $3 $4 $5
#run_solver town_500000 $1 town_500000 $3 $4 $5
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs ve "ve "

View File

@ -9,7 +9,7 @@ OUT_FILE_NAME=results.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg, bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
function run_solver
{
@ -59,8 +59,7 @@ function run_all_graphs
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
#run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
#run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
exit
@ -69,10 +68,8 @@ run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
run_all_graphs bp "hve(min_weight) " ve min_weight
run_all_graphs bp "hve(min_fill) " ve min_fill
run_all_graphs bp "hve(w_min_fill) " ve weighted_min_fill
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
run_all_graphs bp "bn_bp(max_residual) " bn_bp max_residual
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
run_all_graphs bp "fg_bp(max_residual) " fg_bp max_residual
run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
run_all_graphs bp "bp(max_residual) " bp max_residual
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
run_all_graphs bp "cbp(max_residual) " cbp max_residual
run_all_graphs gibbs "gibbs "

View File

@ -0,0 +1,11 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="bp"
YAP=~/bin/$SHORTNAME-$SOLVER
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,11 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="cbp"
YAP=~/bin/$SHORTNAME-$SOLVER
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,12 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="fove"
YAP=~/bin/$SHORTNAME-$SOLVER
run_all_graphs "fove "

View File

@ -0,0 +1,11 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="hve"
YAP=~/bin/$SHORTNAME-$SOLVER
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors

View File

@ -0,0 +1,27 @@
#!/home/tiago/bin/yap -L --
:- initialization(main).
main :-
unix(argv([H])),
generate_town(H).
generate_town(N) :-
atomic_concat(['pop_', N, '.yap'], FileName),
open(FileName, 'write', S),
atom_number(N, N2),
generate_people(S, N2, 4),
write(S, '\n'),
close(S).
generate_people(S, N, Counting) :-
Counting > N, !.
generate_people(S, N, Counting) :-
format(S, 'people(p~w).~n', [Counting]),
Counting1 is Counting + 1,
generate_people(S, N, Counting1).

View File

@ -0,0 +1,33 @@
#!/bin/bash
NETWORK="'../../examples/workshop_attrs'"
SHORTNAME="wa"
QUERY="series(X)"
function run_all_graphs
{
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
cp ~/bin/yap $YAP
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
echo "results for solver $1" >> $LOG_FILE
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
run_solver pop_10 $2
#run_solver pop_1000 $2
#run_solver pop_5000 $2
#run_solver pop_10000 $2
#run_solver pop_50000 $2
#run_solver pop_100000 $2
#run_solver pop_500000 $2
#run_solver pop_1000000 $2
}

View File

@ -1,18 +0,0 @@
:- use_module(library(pfl)).
:- set_clpbn_flag(solver,fove).
c(x1,y1,z1).
c(x1,y1,z2).
c(x2,y2,z1).
c(x3,y2,z1).
bayes p(X)::[t,f] ; [0.2, 0.4] ; [c(X,_,_)].
bayes q(Y)::[t,f] ; [0.5, 0.6] ; [c(_,Y,_)].
bayes s(Z)::[t,f] , p(X) , q(Y) ; [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] ; [c(X,Y,Z)].
% bayes series::[t,f] , attends(X) ; [0.5, 0.6, 0.7, 0.8] ; [c(X,_)].

View File

@ -1,81 +0,0 @@
<?xml version="1.0" encoding="US-ASCII"?>
<!--
B E
\ /
\ /
A
/ \
/ \
J M
-->
<BIF VERSION="0.3">
<NETWORK>
<NAME>Simple Loop</NAME>
<VARIABLE TYPE="nature">
<NAME>B</NAME>
<OUTCOME>b1</OUTCOME>
<OUTCOME>b2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>E</NAME>
<OUTCOME>e1</OUTCOME>
<OUTCOME>e2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>A</NAME>
<OUTCOME>a1</OUTCOME>
<OUTCOME>a2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>J</NAME>
<OUTCOME>j1</OUTCOME>
<OUTCOME>j2</OUTCOME>
</VARIABLE>
<VARIABLE TYPE="nature">
<NAME>M</NAME>
<OUTCOME>m1</OUTCOME>
<OUTCOME>m2</OUTCOME>
</VARIABLE>
<DEFINITION>
<FOR>B</FOR>
<TABLE> .001 .999 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>E</FOR>
<TABLE> .002 .998 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>A</FOR>
<GIVEN>B</GIVEN>
<GIVEN>E</GIVEN>
<TABLE> .95 .05 .94 .06 .29 .71 .001 .999 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>J</FOR>
<GIVEN>A</GIVEN>
<TABLE> .9 .1 .05 .95 </TABLE>
</DEFINITION>
<DEFINITION>
<FOR>M</FOR>
<GIVEN>A</GIVEN>
<TABLE> .7 .3 .01 .99 </TABLE>
</DEFINITION>
</NETWORK>
</BIF>

View File

@ -1,29 +1,41 @@
:- use_module(library(clpbn)).
:- use_module(library(pfl)).
:- set_clpbn_flag(solver, bp).
%:- set_pfl_flag(solver,ve).
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
%:- set_pfl_flag(solver,fove).
r(R) :- r_cpt(RCpt),
{ R = r with p([r1, r2], RCpt) }.
% :- yap_flag(write_strings, off).
t(T) :- t_cpt(TCpt),
{ T = t with p([t1, t2], TCpt) }.
a(A) :- r(R), t(T), a_cpt(ACpt),
{ A = a with p([a1, a2], ACpt, [R, T]) }.
bayes burglary::[b1,b3] ; [0.001, 0.999] ; [].
j(J) :- a(A), j_cpt(JCpt),
{ J = j with p([j1, j2], JCpt, [A]) }.
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
m(M) :- a(A), m_cpt(MCpt),
{ M = m with p([m1, m2], MCpt, [A]) }.
bayes alarm::[a1,a2] , burglary, earthquake ; [0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ; [].
bayes john_calls::[j1,j2] , alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
bayes mary_calls::[m1,m2] , alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
b_cpt([0.001, 0.999]).
e_cpt([0.002, 0.998]).
r_cpt([0.001, 0.999]).
t_cpt([0.002, 0.998]).
a_cpt([0.95, 0.94, 0.29, 0.001,
0.05, 0.06, 0.71, 0.999]).
j_cpt([0.9, 0.05,
jc_cpt([0.9, 0.05,
0.1, 0.95]).
m_cpt([0.7, 0.01,
mc_cpt([0.7, 0.01,
0.3, 0.99]).
% ?- alarm(A).
?- john_calls(J), mary_calls(m1).
%?- john_calls(J), mary_calls(m1), alarm(a1).
%?- john_calls(J), alarm(a1).

View File

@ -1,3 +1,5 @@
# example in counting belief propagation paper
MARKOV
3
2 2 2
@ -5,7 +7,6 @@ MARKOV
2 0 1
2 2 1
4
1.2 1.4 2.0 0.4

View File

@ -0,0 +1,102 @@
:- use_module(library(pfl)).
:- clpbn_horus:set_solver(fove).
%:- clpbn_horus:set_solver(hve).
%:- clpbn_horus:set_solver(bp).
%:- clpbn_horus:set_solver(cbp).
people(joe,nyc).
people(p2, nyc).
people(p3, nyc).
ev(descn(p2, t)).
ev(descn(p3, t)).
% :- [city_7].
bayes city_conservativeness(C)::[y,n] ; cons_table(C) ; [people(_,C)].
bayes gender(P)::[m,f] ; gender_table(P) ; [people(P,_)].
bayes hair_color(P)::[t,f], city_conservativeness(C) ; hair_color_table(P) ; [people(P,C)].
bayes car_color(P)::[t,f], hair_color(P) ; car_color_table(P); [people(P,_)].
bayes height(P)::[t,f], gender(P) ; height_table(P) ; [people(P,_)].
bayes shoe_size(P):[t,f], height(P) ; shoe_size_table(P); [people(P,_)].
bayes guilty(P)::[y,n] ; guilty_table(P) ; [people(P,_)].
bayes descn(P)::[t,f], car_color(P), hair_color(P), height(P), guilty(P) ; descn_table(P) ; [people(P,_)].
bayes witness(C)::[t,f], descn(Joe), descn(P2) ; wit_table ; [people(_,C), Joe=joe, P2=p2].
cons_table(amsterdam, [0.2, 0.8]) :- !.
cons_table(_, [0.8, 0.2]).
gender_table(_, [0.55, 0.44]).
hair_color_table(_,
/* conservative_city */
/* y n */
[ 0.05, 0.1,
0.95, 0.9 ]).
car_color_table(_,
/* t f */
[ 0.9, 0.2,
0.1, 0.8 ]).
height_table(_,
/* m f */
[ 0.6, 0.4,
0.4, 0.6 ]).
shoe_size_table(_,
/* t f */
[ 0.9, 0.1,
0.1, 0.9 ]).
guilty_table(_, [0.23, 0.77]).
descn_table(_,
/* color, hair, height, guilt */
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
wit_table([0.2, 0.45, 0.24, 0.34,
0.8, 0.55, 0.76, 0.66]).
runall(G, Wrapper) :-
findall(G, Wrapper, L),
execute_all(L).
execute_all([]).
execute_all(G.L) :-
call(G),
execute_all(L).
is_joe_guilty(Guilty) :-
witness(nyc, t),
runall(X, ev(X)),
guilty(joe, Guilty).
% ?- is_joe_guilty(Guilty)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,11 @@
:- use_module(library(pfl)).
%:- set_clpbn_flag(solver,ve).
%:- set_clpbn_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
:- set_clpbn_flag(solver,fove).
:- clpbn_horus:set_solver(fove).
%:- clpbn_horus:set_solver(hve).
%:- clpbn_horus:set_solver(bp).
%:- clpbn_horus:set_solver(cbp).
:- yap_flag(write_strings, off).
c(p1,w1).
c(p1,w2).
@ -25,8 +27,5 @@ markov attends(P)::[t,f] , hot(W)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P,W)].
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
:- clpbn_horus:set_horus_flag(use_logarithms,true).
?- series(X).
% ?- series(X).

View File

@ -2,7 +2,9 @@
:- use_module(library(pfl)).
:- set_pfl_flag(solver,fove).
%:- set_pfl_flag(solver,fove).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp).
t(ann).
@ -10,10 +12,10 @@ t(dave).
% p(ann,t).
bayes p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
markov p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
% use standard Prolog queries: provide evidence first.
?- p(dave,t), p(ann,X).
?- p(ann,t), p(ann,X).
% ?- p(ann,X).

View File

@ -1,32 +1,24 @@
:- use_module(library(pfl)).
%:- set_pfl_flag(solver,ve).
:- set_pfl_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
% :- set_pfl_flag(solver,fove).
:- clpbn_horus:set_solver(fove).
%:- clpbn_horus:set_solver(hve).
%:- clpbn_horus:set_solver(bp).
%:- clpbn_horus:set_solver(cbp).
:- yap_flag(write_strings, off).
friendly(P1, P2) :-
person(P1),
person(P2),
P1 @> P2.
person(john).
person(maggie).
person(harry).
person(bill).
person(matt).
person(diana).
person(bob).
person(dick).
person(burr).
person(ann).
friends(P1, P2) :-
people(P1),
people(P2),
P1 \= P2.
person @ 2.
people @ 3.
markov smokes(P)::[t,f] , cancer(P)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [person(P)].
markov smokes(P)::[t,f], cancer(P)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [people(P)].
markov friend(P1,P2)::[t,f], smokes(P1)::[t,f], smokes(P2)::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [friendly(P1, P2)].
markov friend(P1,P2)::[t,f], smokes(P1)::[t,f], smokes(P2)::[t,f] ;
[0.5, 0.6, 0.7, 0.8, 0.5, 0.6, 0.7, 0.8] ; [friends(P1, P2)].
% ?- smokes(p1, t), smokes(p2, f), friend(p1, p2, X).
?- smokes(person_0, t), smokes(person_1, t), friend(person_0, person_1, F).

View File

@ -1,120 +0,0 @@
14
1
6
2
2
0 9.974182
1 1.000000
1
7
2
2
0 9.974182
1 1.000000
1
4
2
2
0 4.055200
1 1.000000
1
5
2
2
0 4.055200
1 1.000000
1
0
2
2
0 7.389056
1 1.000000
1
2
2
2
0 7.389056
1 1.000000
1
1
2
2
0 7.389056
1 1.000000
1
3
2
2
0 7.389056
1 1.000000
2
4 6
2 2
4
0 4.481689
1 1.000000
2 4.481689
3 4.481689
2
5 7
2 2
4
0 4.481689
1 1.000000
2 4.481689
3 4.481689
2
0 4
2 2
4
0 3.004166
1 3.004166
2 3.004166
3 3.004166
3
2 5 4
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
1 4 5
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
2
3 5
2 2
4
0 3.004166
1 3.004166
2 3.004166
3 3.004166

View File

@ -1,239 +0,0 @@
27
1
12
2
2
0 9.974182
1 1.000000
1
13
2
2
0 9.974182
1 1.000000
1
14
2
2
0 9.974182
1 1.000000
1
9
2
2
0 4.055200
1 1.000000
1
10
2
2
0 4.055200
1 1.000000
1
11
2
2
0 4.055200
1 1.000000
1
0
2
2
0 7.389056
1 1.000000
1
3
2
2
0 7.389056
1 1.000000
1
6
2
2
0 7.389056
1 1.000000
1
1
2
2
0 7.389056
1 1.000000
1
4
2
2
0 7.389056
1 1.000000
1
7
2
2
0 7.389056
1 1.000000
1
2
2
2
0 7.389056
1 1.000000
1
5
2
2
0 7.389056
1 1.000000
1
8
2
2
0 7.389056
1 1.000000
2
9 12
2 2
4
0 4.481689
1 1.000000
2 4.481689
3 4.481689
2
10 13
2 2
4
0 4.481689
1 1.000000
2 4.481689
3 4.481689
2
11 14
2 2
4
0 4.481689
1 1.000000
2 4.481689
3 4.481689
2
0 9
2 2
4
0 3.004166
1 3.004166
2 3.004166
3 3.004166
3
3 10 9
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
6 11 9
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
1 9 10
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
2
4 10
2 2
4
0 3.004166
1 3.004166
2 3.004166
3 3.004166
3
7 11 10
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
2 9 11
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
5 10 11
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
2
8 11
2 2
4
0 3.004166
1 3.004166
2 3.004166
3 3.004166

View File

@ -1,398 +0,0 @@
44
1
20
2
2
0 9.974182
1 1.000000
1
21
2
2
0 9.974182
1 1.000000
1
22
2
2
0 9.974182
1 1.000000
1
23
2
2
0 9.974182
1 1.000000
1
16
2
2
0 4.055200
1 1.000000
1
17
2
2
0 4.055200
1 1.000000
1
18
2
2
0 4.055200
1 1.000000
1
19
2
2
0 4.055200
1 1.000000
1
0
2
2
0 7.389056
1 1.000000
1
4
2
2
0 7.389056
1 1.000000
1
8
2
2
0 7.389056
1 1.000000
1
12
2
2
0 7.389056
1 1.000000
1
1
2
2
0 7.389056
1 1.000000
1
5
2
2
0 7.389056
1 1.000000
1
9
2
2
0 7.389056
1 1.000000
1
13
2
2
0 7.389056
1 1.000000
1
2
2
2
0 7.389056
1 1.000000
1
6
2
2
0 7.389056
1 1.000000
1
10
2
2
0 7.389056
1 1.000000
1
14
2
2
0 7.389056
1 1.000000
1
3
2
2
0 7.389056
1 1.000000
1
7
2
2
0 7.389056
1 1.000000
1
11
2
2
0 7.389056
1 1.000000
1
15
2
2
0 7.389056
1 1.000000
2
16 20
2 2
4
0 4.481689
1 1.000000
2 4.481689
3 4.481689
2
17 21
2 2
4
0 4.481689
1 1.000000
2 4.481689
3 4.481689
2
18 22
2 2
4
0 4.481689
1 1.000000
2 4.481689
3 4.481689
2
19 23
2 2
4
0 4.481689
1 1.000000
2 4.481689
3 4.481689
2
0 16
2 2
4
0 3.004166
1 3.004166
2 3.004166
3 3.004166
3
4 17 16
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
8 18 16
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
12 19 16
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
1 16 17
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
2
5 17
2 2
4
0 3.004166
1 3.004166
2 3.004166
3 3.004166
3
9 18 17
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
13 19 17
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
2 16 18
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
6 17 18
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
2
10 18
2 2
4
0 3.004166
1 3.004166
2 3.004166
3 3.004166
3
14 19 18
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
3 16 19
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
7 17 19
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
3
11 18 19
2 2 2
8
0 3.004166
1 3.004166
2 3.004166
3 1.000000
4 3.004166
5 1.000000
6 3.004166
7 3.004166
2
15 19
2 2
4
0 3.004166
1 3.004166
2 3.004166
3 3.004166

View File

@ -1,32 +1,27 @@
:- use_module(library(pfl)).
%:- set_clpbn_flag(solver,ve).
%:- set_clpbn_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
:- set_clpbn_flag(solver,fove).
%:- clpbn_horus:set_solver(fove).
%:- clpbn_horus:set_solver(hve).
:- clpbn_horus:set_solver(bp).
%:- clpbn_horus:set_solver(cbp).
c(p1).
c(p2).
c(p3).
c(p4).
c(p5).
:- yap_flag(write_strings, off).
people @ 3.
markov attends(P)::[t,f] , attr1::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f], attr1::[t,f] ; [0.11, 0.2, 0.3, 0.4] ; [people(P)].
markov attends(P)::[t,f] , attr2::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f], attr2::[t,f] ; [0.1, 0.22, 0.3, 0.4] ; [people(P)].
markov attends(P)::[t,f] , attr3::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f], attr3::[t,f] ; [0.1, 0.2, 0.33, 0.4] ; [people(P)].
markov attends(P)::[t,f] , attr4::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f], attr4::[t,f] ; [0.1, 0.2, 0.3, 0.44] ; [people(P)].
markov attends(P)::[t,f] , attr5::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f], attr5::[t,f] ; [0.1, 0.2, 0.3, 0.45] ; [people(P)].
markov attends(P)::[t,f] , attr6::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
markov attends(P)::[t,f], attr6::[t,f] ; [0.1, 0.2, 0.3, 0.46] ; [people(P)].
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P)].
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.87] ; [people(P)].
%:- clpbn_horus:set_horus_flag(use_logarithms,true).
?- series(X).
% ?- series(X).

View File

@ -1,241 +0,0 @@
Aladdin Free Public License
(Version 8, November 18, 1999)
Copyright (C) 1994, 1995, 1997, 1998, 1999 Aladdin Enterprises,
Menlo Park, California, U.S.A. All rights reserved.
*NOTE:* This License is not the same as any of the GNU Licenses
<http://www.gnu.org/copyleft/gpl.html> published by the Free
Software Foundation <http://www.gnu.org/>. Its terms are
substantially different from those of the GNU Licenses. If you are
familiar with the GNU Licenses, please read this license with extra
care.
Aladdin Enterprises hereby grants to anyone the permission to apply this
License to their own work, as long as the entire License (including the
above notices and this paragraph) is copied with no changes, additions,
or deletions except for changing the first paragraph of Section 0 to
include a suitable description of the work to which the license is being
applied and of the person or entity that holds the copyright in the
work, and, if the License is being applied to a work created in a
country other than the United States, replacing the first paragraph of
Section 6 with an appropriate reference to the laws of the appropriate
country.
0. Subject Matter
This License applies to the computer program known as "XMLParser library".
The "Program", below, refers to such program. The Program
is a copyrighted work whose copyright is held by Frank Vanden Berghen
(the "Licensor").
A "work based on the Program" means either the Program or any derivative
work of the Program, as defined in the United States Copyright Act of
1976, such as a translation or a modification.
* BY MODIFYING OR DISTRIBUTING THE PROGRAM (OR ANY WORK BASED ON THE
PROGRAM), YOU INDICATE YOUR ACCEPTANCE OF THIS LICENSE TO DO SO, AND ALL
ITS TERMS AND CONDITIONS FOR COPYING, DISTRIBUTING OR MODIFYING THE
PROGRAM OR WORKS BASED ON IT. NOTHING OTHER THAN THIS LICENSE GRANTS YOU
PERMISSION TO MODIFY OR DISTRIBUTE THE PROGRAM OR ITS DERIVATIVE WORKS.
THESE ACTIONS ARE PROHIBITED BY LAW. IF YOU DO NOT ACCEPT THESE TERMS
AND CONDITIONS, DO NOT MODIFY OR DISTRIBUTE THE PROGRAM. *
1. Licenses.
Licensor hereby grants you the following rights, provided that you
comply with all of the restrictions set forth in this License and
provided, further, that you distribute an unmodified copy of this
License with the Program:
(a)
You may copy and distribute literal (i.e., verbatim) copies of the
Program's source code as you receive it throughout the world, in any
medium.
(b)
You may modify the Program, create works based on the Program and
distribute copies of such throughout the world, in any medium.
2. Restrictions.
This license is subject to the following restrictions:
(a)
Distribution of the Program or any work based on the Program by a
commercial organization to any third party is prohibited if any
payment is made in connection with such distribution, whether
directly (as in payment for a copy of the Program) or indirectly (as
in payment for some service related to the Program, or payment for
some product or service that includes a copy of the Program "without
charge"; these are only examples, and not an exhaustive enumeration
of prohibited activities). The following methods of distribution
involving payment shall not in and of themselves be a violation of
this restriction:
(i)
Posting the Program on a public access information storage and
retrieval service for which a fee is received for retrieving
information (such as an on-line service), provided that the fee
is not content-dependent (i.e., the fee would be the same for
retrieving the same volume of information consisting of random
data) and that access to the service and to the Program is
available independent of any other product or service. An
example of a service that does not fall under this section is an
on-line service that is operated by a company and that is only
available to customers of that company. (This is not an
exhaustive enumeration.)
(ii)
Distributing the Program on removable computer-readable media,
provided that the files containing the Program are reproduced
entirely and verbatim on such media, that all information on
such media be redistributable for non-commercial purposes
without charge, and that such media are distributed by
themselves (except for accompanying documentation) independent
of any other product or service. Examples of such media include
CD-ROM, magnetic tape, and optical storage media. (This is not
intended to be an exhaustive list.) An example of a distribution
that does not fall under this section is a CD-ROM included in a
book or magazine. (This is not an exhaustive enumeration.)
(b)
Activities other than copying, distribution and modification of the
Program are not subject to this License and they are outside its
scope. Functional use (running) of the Program is not restricted,
and any output produced through the use of the Program is subject to
this license only if its contents constitute a work based on the
Program (independent of having been made by running the Program).
(c)
You must meet all of the following conditions with respect to any
work that you distribute or publish that in whole or in part
contains or is derived from the Program or any part thereof ("the
Work"):
(i)
If you have modified the Program, you must cause the Work to
carry prominent notices stating that you have modified the
Program's files and the date of any change. In each source file
that you have modified, you must include a prominent notice that
you have modified the file, including your name, your e-mail
address (if any), and the date and purpose of the change;
(ii)
You must cause the Work to be licensed as a whole and at no
charge to all third parties under the terms of this License;
(iii)
If the Work normally reads commands interactively when run, you
must cause it, at each time the Work commences operation, to
print or display an announcement including an appropriate
copyright notice and a notice that there is no warranty (or
else, saying that you provide a warranty). Such notice must also
state that users may redistribute the Work only under the
conditions of this License and tell the user how to view the
copy of this License included with the Work. (Exceptions: if the
Program is interactive but normally prints or displays such an
announcement only at the request of a user, such as in an "About
box", the Work is required to print or display the notice only
under the same circumstances; if the Program itself is
interactive but does not normally print such an announcement,
the Work is not required to print an announcement.);
(iv)
You must accompany the Work with the complete corresponding
machine-readable source code, delivered on a medium customarily
used for software interchange. The source code for a work means
the preferred form of the work for making modifications to it.
For an executable work, complete source code means all the
source code for all modules it contains, plus any associated
interface definition files, plus the scripts used to control
compilation and installation of the executable code. If you
distribute with the Work any component that is normally
distributed (in either source or binary form) with the major
components (compiler, kernel, and so on) of the operating system
on which the executable runs, you must also distribute the
source code of that component if you have it and are allowed to
do so;
(v)
If you distribute any written or printed material at all with
the Work, such material must include either a written copy of
this License, or a prominent written indication that the Work is
covered by this License and written instructions for printing
and/or displaying the copy of the License on the distribution
medium;
(vi)
You may not impose any further restrictions on the recipient's
exercise of the rights granted herein.
If distribution of executable or object code is made by offering the
equivalent ability to copy from a designated place, then offering
equivalent ability to copy the source code from the same place counts as
distribution of the source code, even though third parties are not
compelled to copy the source code along with the object code.
3. Reservation of Rights.
No rights are granted to the Program except as expressly set forth
herein. You may not copy, modify, sublicense, or distribute the Program
except as expressly provided under this License. Any attempt otherwise
to copy, modify, sublicense or distribute the Program is void, and will
automatically terminate your rights under this License. However, parties
who have received copies, or rights, from you under this License will
not have their licenses terminated so long as such parties remain in
full compliance.
4. Other Restrictions.
If the distribution and/or use of the Program is restricted in certain
countries for any reason, Licensor may add an explicit geographical
distribution limitation excluding those countries, so that distribution
is permitted only in or among countries not thus excluded. In such case,
this License incorporates the limitation as if written in the body of
this License.
5. Limitations.
* THE PROGRAM IS PROVIDED TO YOU "AS IS," WITHOUT WARRANTY. THERE IS NO
WARRANTY FOR THE PROGRAM, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT
NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. THE
ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH
YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL
NECESSARY SERVICING, REPAIR OR CORRECTION. *
* IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL LICENSOR, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE
PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS
OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR
THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGES. *
6. General.
This License is governed by the laws of Belgium., excluding choice of
law rules.
If any part of this License is found to be in conflict with the law,
that part shall be interpreted in its broadest meaning consistent with
the law, and no other parts of the License shall be affected.
For United States Government users, the Program is provided with
*RESTRICTED RIGHTS*. If you are a unit or agency of the United States
Government or are acquiring the Program for any such unit or agency, the
following apply:
If the unit or agency is the Department of Defense ("DOD"), the
Program and its documentation are classified as "commercial computer
software" and "commercial computer software documentation"
respectively and, pursuant to DFAR Section 227.7202, the Government
is acquiring the Program and its documentation in accordance with
the terms of this License. If the unit or agency is other than DOD,
the Program and its documentation are classified as "commercial
computer software" and "commercial computer software documentation"
respectively and, pursuant to FAR Section 12.212, the Government is
acquiring the Program and its documentation in accordance with the
terms of this License.

File diff suppressed because it is too large Load Diff

View File

@ -1,734 +0,0 @@
/****************************************************************************/
/*! \mainpage XMLParser library
* \section intro_sec Introduction
*
* This is a basic XML parser written in ANSI C++ for portability.
* It works by using recursion and a node tree for breaking
* down the elements of an XML document.
*
* @version V2.42
* @author Frank Vanden Berghen
*
* Copyright (c) 2002, Business-Insight
* <a href="http://www.Business-Insight.com">Business-Insight</a>
* All rights reserved.
* See the file <a href="../../AFPL-license.txt">AFPL-license.txt</a> about the licensing terms
*
* \section tutorial First Tutorial
* You can follow a simple <a href="../../xmlParser.html">Tutorial</a> to know the basics...
*
* \section usage General usage: How to include the XMLParser library inside your project.
*
* The library is composed of two files: <a href="../../xmlParser.cpp">xmlParser.cpp</a> and
* <a href="../../xmlParser.h">xmlParser.h</a>. These are the ONLY 2 files that you need when
* using the library inside your own projects.
*
* All the functions of the library are documented inside the comments of the file
* <a href="../../xmlParser.h">xmlParser.h</a>. These comments can be transformed in
* full-fledged HTML documentation using the DOXYGEN software: simply type: "doxygen doxy.cfg"
*
* By default, the XMLParser library uses (char*) for string representation.To use the (wchar_t*)
* version of the library, you need to define the "_UNICODE" preprocessor definition variable
* (this is usually done inside your project definition file) (This is done automatically for you
* when using Visual Studio).
*
* \section example Advanced Tutorial and Many Examples of usage.
*
* Some very small introductory examples are described inside the Tutorial file
* <a href="../../xmlParser.html">xmlParser.html</a>
*
* Some additional small examples are also inside the file <a href="../../xmlTest.cpp">xmlTest.cpp</a>
* (for the "char*" version of the library) and inside the file
* <a href="../../xmlTestUnicode.cpp">xmlTestUnicode.cpp</a> (for the "wchar_t*"
* version of the library). If you have a question, please review these additionnal examples
* before sending an e-mail to the author.
*
* To build the examples:
* - linux/unix: type "make"
* - solaris: type "make -f makefile.solaris"
* - windows: Visual Studio: double-click on xmlParser.dsw
* (under Visual Studio .NET, the .dsp and .dsw files will be automatically converted to .vcproj and .sln files)
*
* In order to build the examples you need some additional files:
* - linux/unix: makefile
* - solaris: makefile.solaris
* - windows: Visual Studio: *.dsp, xmlParser.dsw and also xmlParser.lib and xmlParser.dll
*
* \section debugging Debugging with the XMLParser library
*
* \subsection debugwin Debugging under WINDOWS
*
* Inside Visual C++, the "debug versions" of the memory allocation functions are
* very slow: Do not forget to compile in "release mode" to get maximum speed.
* When I had to debug a software that was using the XMLParser Library, it was usually
* a nightmare because the library was sooOOOoooo slow in debug mode (because of the
* slow memory allocations in Debug mode). To solve this
* problem, during all the debugging session, I am now using a very fast DLL version of the
* XMLParser Library (the DLL is compiled in release mode). Using the DLL version of
* the XMLParser Library allows me to have lightening XML parsing speed even in debug!
* Other than that, the DLL version is useless: In the release version of my tool,
* I always use the normal, ".cpp"-based, XMLParser Library (I simply include the
* <a href="../../xmlParser.cpp">xmlParser.cpp</a> and
* <a href="../../xmlParser.h">xmlParser.h</a> files into the project).
*
* The file <a href="../../XMLNodeAutoexp.txt">XMLNodeAutoexp.txt</a> contains some
* "tweaks" that improve substancially the display of the content of the XMLNode objects
* inside the Visual Studio Debugger. Believe me, once you have seen inside the debugger
* the "smooth" display of the XMLNode objects, you cannot live without it anymore!
*
* \subsection debuglinux Debugging under LINUX/UNIX
*
* The speed of the debug version of the XMLParser library is tolerable so no extra
* work.has been done.
*
****************************************************************************/
#ifndef __INCLUDE_XML_NODE__
#define __INCLUDE_XML_NODE__
#include <stdlib.h>
#ifdef _UNICODE
// If you comment the next "define" line then the library will never "switch to" _UNICODE (wchar_t*) mode (16/32 bits per characters).
// This is useful when you get error messages like:
// 'XMLNode::openFileHelper' : cannot convert parameter 2 from 'const char [5]' to 'const wchar_t *'
// The _XMLWIDECHAR preprocessor variable force the XMLParser library into either utf16/32-mode (the proprocessor variable
// must be defined) or utf8-mode(the pre-processor variable must be undefined).
#define _XMLWIDECHAR
#endif
#if defined(WIN32) || defined(UNDER_CE) || defined(_WIN32) || defined(WIN64) || defined(__BORLANDC__)
// comment the next line if you are under windows and the compiler is not Microsoft Visual Studio (6.0 or .NET) or Borland
#define _XMLWINDOWS
#endif
#ifdef XMLDLLENTRY
#undef XMLDLLENTRY
#endif
#ifdef _USE_XMLPARSER_DLL
#ifdef _DLL_EXPORTS_
#define XMLDLLENTRY __declspec(dllexport)
#else
#define XMLDLLENTRY __declspec(dllimport)
#endif
#else
#define XMLDLLENTRY
#endif
// uncomment the next line if you want no support for wchar_t* (no need for the <wchar.h> or <tchar.h> libraries anymore to compile)
//#define XML_NO_WIDE_CHAR
#ifdef XML_NO_WIDE_CHAR
#undef _XMLWINDOWS
#undef _XMLWIDECHAR
#endif
#ifdef _XMLWINDOWS
#include <tchar.h>
#else
#define XMLDLLENTRY
#ifndef XML_NO_WIDE_CHAR
#include <wchar.h> // to have 'wcsrtombs' for ANSI version
// to have 'mbsrtowcs' for WIDECHAR version
#endif
#endif
// Some common types for char set portable code
#ifdef _XMLWIDECHAR
#define _CXML(c) L ## c
#define XMLCSTR const wchar_t *
#define XMLSTR wchar_t *
#define XMLCHAR wchar_t
#else
#define _CXML(c) c
#define XMLCSTR const char *
#define XMLSTR char *
#define XMLCHAR char
#endif
#ifndef FALSE
#define FALSE 0
#endif /* FALSE */
#ifndef TRUE
#define TRUE 1
#endif /* TRUE */
/// Enumeration for XML parse errors.
typedef enum XMLError
{
eXMLErrorNone = 0,
eXMLErrorMissingEndTag,
eXMLErrorNoXMLTagFound,
eXMLErrorEmpty,
eXMLErrorMissingTagName,
eXMLErrorMissingEndTagName,
eXMLErrorUnmatchedEndTag,
eXMLErrorUnmatchedEndClearTag,
eXMLErrorUnexpectedToken,
eXMLErrorNoElements,
eXMLErrorFileNotFound,
eXMLErrorFirstTagNotFound,
eXMLErrorUnknownCharacterEntity,
eXMLErrorCharacterCodeAbove255,
eXMLErrorCharConversionError,
eXMLErrorCannotOpenWriteFile,
eXMLErrorCannotWriteFile,
eXMLErrorBase64DataSizeIsNotMultipleOf4,
eXMLErrorBase64DecodeIllegalCharacter,
eXMLErrorBase64DecodeTruncatedData,
eXMLErrorBase64DecodeBufferTooSmall
} XMLError;
/// Enumeration used to manage type of data. Use in conjunction with structure XMLNodeContents
typedef enum XMLElementType
{
eNodeChild=0,
eNodeAttribute=1,
eNodeText=2,
eNodeClear=3,
eNodeNULL=4
} XMLElementType;
/// Structure used to obtain error details if the parse fails.
typedef struct XMLResults
{
enum XMLError error;
int nLine,nColumn;
} XMLResults;
/// Structure for XML clear (unformatted) node (usually comments)
typedef struct XMLClear {
XMLCSTR lpszValue; XMLCSTR lpszOpenTag; XMLCSTR lpszCloseTag;
} XMLClear;
/// Structure for XML attribute.
typedef struct XMLAttribute {
XMLCSTR lpszName; XMLCSTR lpszValue;
} XMLAttribute;
/// XMLElementPosition are not interchangeable with simple indexes
typedef int XMLElementPosition;
struct XMLNodeContents;
/** @defgroup XMLParserGeneral The XML parser */
/// Main Class representing a XML node
/**
* All operations are performed using this class.
* \note The constructors of the XMLNode class are protected, so use instead one of these four methods to get your first instance of XMLNode:
* <ul>
* <li> XMLNode::parseString </li>
* <li> XMLNode::parseFile </li>
* <li> XMLNode::openFileHelper </li>
* <li> XMLNode::createXMLTopNode (or XMLNode::createXMLTopNode_WOSD)</li>
* </ul> */
typedef struct XMLDLLENTRY XMLNode
{
private:
struct XMLNodeDataTag;
/// Constructors are protected, so use instead one of: XMLNode::parseString, XMLNode::parseFile, XMLNode::openFileHelper, XMLNode::createXMLTopNode
XMLNode(struct XMLNodeDataTag *pParent, XMLSTR lpszName, char isDeclaration);
/// Constructors are protected, so use instead one of: XMLNode::parseString, XMLNode::parseFile, XMLNode::openFileHelper, XMLNode::createXMLTopNode
XMLNode(struct XMLNodeDataTag *p);
public:
static XMLCSTR getVersion();///< Return the XMLParser library version number
/** @defgroup conversions Parsing XML files/strings to an XMLNode structure and Rendering XMLNode's to files/string.
* @ingroup XMLParserGeneral
* @{ */
/// Parse an XML string and return the root of a XMLNode tree representing the string.
static XMLNode parseString (XMLCSTR lpXMLString, XMLCSTR tag=NULL, XMLResults *pResults=NULL);
/**< The "parseString" function parse an XML string and return the root of a XMLNode tree. The "opposite" of this function is
* the function "createXMLString" that re-creates an XML string from an XMLNode tree. If the XML document is corrupted, the
* "parseString" method will initialize the "pResults" variable with some information that can be used to trace the error.
* If you still want to parse the file, you can use the APPROXIMATE_PARSING option as explained inside the note at the
* beginning of the "xmlParser.cpp" file.
*
* @param lpXMLString the XML string to parse
* @param tag the name of the first tag inside the XML file. If the tag parameter is omitted, this function returns a node that represents the head of the xml document including the declaration term (<? ... ?>).
* @param pResults a pointer to a XMLResults variable that will contain some information that can be used to trace the XML parsing error. You can have a user-friendly explanation of the parsing error with the "getError" function.
*/
/// Parse an XML file and return the root of a XMLNode tree representing the file.
static XMLNode parseFile (XMLCSTR filename, XMLCSTR tag=NULL, XMLResults *pResults=NULL);
/**< The "parseFile" function parse an XML file and return the root of a XMLNode tree. The "opposite" of this function is
* the function "writeToFile" that re-creates an XML file from an XMLNode tree. If the XML document is corrupted, the
* "parseFile" method will initialize the "pResults" variable with some information that can be used to trace the error.
* If you still want to parse the file, you can use the APPROXIMATE_PARSING option as explained inside the note at the
* beginning of the "xmlParser.cpp" file.
*
* @param filename the path to the XML file to parse
* @param tag the name of the first tag inside the XML file. If the tag parameter is omitted, this function returns a node that represents the head of the xml document including the declaration term (<? ... ?>).
* @param pResults a pointer to a XMLResults variable that will contain some information that can be used to trace the XML parsing error. You can have a user-friendly explanation of the parsing error with the "getError" function.
*/
/// Parse an XML file and return the root of a XMLNode tree representing the file. A very crude error checking is made. An attempt to guess the Char Encoding used in the file is made.
static XMLNode openFileHelper(XMLCSTR filename, XMLCSTR tag=NULL);
/**< The "openFileHelper" function reports to the screen all the warnings and errors that occurred during parsing of the XML file.
* This function also tries to guess char Encoding (UTF-8, ASCII or SHIT-JIS) based on the first 200 bytes of the file. Since each
* application has its own way to report and deal with errors, you should rather use the "parseFile" function to parse XML files
* and program yourself thereafter an "error reporting" tailored for your needs (instead of using the very crude "error reporting"
* mechanism included inside the "openFileHelper" function).
*
* If the XML document is corrupted, the "openFileHelper" method will:
* - display an error message on the console (or inside a messageBox for windows).
* - stop execution (exit).
*
* I strongly suggest that you write your own "openFileHelper" method tailored to your needs. If you still want to parse
* the file, you can use the APPROXIMATE_PARSING option as explained inside the note at the beginning of the "xmlParser.cpp" file.
*
* @param filename the path of the XML file to parse.
* @param tag the name of the first tag inside the XML file. If the tag parameter is omitted, this function returns a node that represents the head of the xml document including the declaration term (<? ... ?>).
*/
static XMLCSTR getError(XMLError error); ///< this gives you a user-friendly explanation of the parsing error
/// Create an XML string starting from the current XMLNode.
XMLSTR createXMLString(int nFormat=1, int *pnSize=NULL) const;
/**< The returned string should be free'd using the "freeXMLString" function.
*
* If nFormat==0, no formatting is required otherwise this returns an user friendly XML string from a given element
* with appropriate white spaces and carriage returns. if pnSize is given it returns the size in character of the string. */
/// Save the content of an xmlNode inside a file
XMLError writeToFile(XMLCSTR filename,
const char *encoding=NULL,
char nFormat=1) const;
/**< If nFormat==0, no formatting is required otherwise this returns an user friendly XML string from a given element with appropriate white spaces and carriage returns.
* If the global parameter "characterEncoding==encoding_UTF8", then the "encoding" parameter is ignored and always set to "utf-8".
* If the global parameter "characterEncoding==encoding_ShiftJIS", then the "encoding" parameter is ignored and always set to "SHIFT-JIS".
* If "_XMLWIDECHAR=1", then the "encoding" parameter is ignored and always set to "utf-16".
* If no "encoding" parameter is given the "ISO-8859-1" encoding is used. */
/** @} */
/** @defgroup navigate Navigate the XMLNode structure
* @ingroup XMLParserGeneral
* @{ */
XMLCSTR getName() const; ///< name of the node
XMLCSTR getText(int i=0) const; ///< return ith text field
int nText() const; ///< nbr of text field
XMLNode getParentNode() const; ///< return the parent node
XMLNode getChildNode(int i=0) const; ///< return ith child node
XMLNode getChildNode(XMLCSTR name, int i) const; ///< return ith child node with specific name (return an empty node if failing). If i==-1, this returns the last XMLNode with the given name.
XMLNode getChildNode(XMLCSTR name, int *i=NULL) const; ///< return next child node with specific name (return an empty node if failing)
XMLNode getChildNodeWithAttribute(XMLCSTR tagName,
XMLCSTR attributeName,
XMLCSTR attributeValue=NULL,
int *i=NULL) const; ///< return child node with specific name/attribute (return an empty node if failing)
XMLNode getChildNodeByPath(XMLCSTR path, char createNodeIfMissing=0, XMLCHAR sep='/');
///< return the first child node with specific path
XMLNode getChildNodeByPathNonConst(XMLSTR path, char createNodeIfMissing=0, XMLCHAR sep='/');
///< return the first child node with specific path.
int nChildNode(XMLCSTR name) const; ///< return the number of child node with specific name
int nChildNode() const; ///< nbr of child node
XMLAttribute getAttribute(int i=0) const; ///< return ith attribute
XMLCSTR getAttributeName(int i=0) const; ///< return ith attribute name
XMLCSTR getAttributeValue(int i=0) const; ///< return ith attribute value
char isAttributeSet(XMLCSTR name) const; ///< test if an attribute with a specific name is given
XMLCSTR getAttribute(XMLCSTR name, int i) const; ///< return ith attribute content with specific name (return a NULL if failing)
XMLCSTR getAttribute(XMLCSTR name, int *i=NULL) const; ///< return next attribute content with specific name (return a NULL if failing)
int nAttribute() const; ///< nbr of attribute
XMLClear getClear(int i=0) const; ///< return ith clear field (comments)
int nClear() const; ///< nbr of clear field
XMLNodeContents enumContents(XMLElementPosition i) const; ///< enumerate all the different contents (attribute,child,text, clear) of the current XMLNode. The order is reflecting the order of the original file/string. NOTE: 0 <= i < nElement();
int nElement() const; ///< nbr of different contents for current node
char isEmpty() const; ///< is this node Empty?
char isDeclaration() const; ///< is this node a declaration <? .... ?>
XMLNode deepCopy() const; ///< deep copy (duplicate/clone) a XMLNode
static XMLNode emptyNode(); ///< return XMLNode::emptyXMLNode;
/** @} */
~XMLNode();
XMLNode(const XMLNode &A); ///< to allow shallow/fast copy:
XMLNode& operator=( const XMLNode& A ); ///< to allow shallow/fast copy:
XMLNode(): d(NULL){};
static XMLNode emptyXMLNode;
static XMLClear emptyXMLClear;
static XMLAttribute emptyXMLAttribute;
/** @defgroup xmlModify Create or Update the XMLNode structure
* @ingroup XMLParserGeneral
* The functions in this group allows you to create from scratch (or update) a XMLNode structure. Start by creating your top
* node with the "createXMLTopNode" function and then add new nodes with the "addChild" function. The parameter 'pos' gives
* the position where the childNode, the text or the XMLClearTag will be inserted. The default value (pos=-1) inserts at the
* end. The value (pos=0) insert at the beginning (Insertion at the beginning is slower than at the end). <br>
*
* REMARK: 0 <= pos < nChild()+nText()+nClear() <br>
*/
/** @defgroup creation Creating from scratch a XMLNode structure
* @ingroup xmlModify
* @{ */
static XMLNode createXMLTopNode(XMLCSTR lpszName, char isDeclaration=FALSE); ///< Create the top node of an XMLNode structure
XMLNode addChild(XMLCSTR lpszName, char isDeclaration=FALSE, XMLElementPosition pos=-1); ///< Add a new child node
XMLNode addChild(XMLNode nodeToAdd, XMLElementPosition pos=-1); ///< If the "nodeToAdd" has some parents, it will be detached from it's parents before being attached to the current XMLNode
XMLAttribute *addAttribute(XMLCSTR lpszName, XMLCSTR lpszValuev); ///< Add a new attribute
XMLCSTR addText(XMLCSTR lpszValue, XMLElementPosition pos=-1); ///< Add a new text content
XMLClear *addClear(XMLCSTR lpszValue, XMLCSTR lpszOpen=NULL, XMLCSTR lpszClose=NULL, XMLElementPosition pos=-1);
/**< Add a new clear tag
* @param lpszOpen default value "<![CDATA["
* @param lpszClose default value "]]>"
*/
/** @} */
/** @defgroup xmlUpdate Updating Nodes
* @ingroup xmlModify
* Some update functions:
* @{
*/
XMLCSTR updateName(XMLCSTR lpszName); ///< change node's name
XMLAttribute *updateAttribute(XMLAttribute *newAttribute, XMLAttribute *oldAttribute); ///< if the attribute to update is missing, a new one will be added
XMLAttribute *updateAttribute(XMLCSTR lpszNewValue, XMLCSTR lpszNewName=NULL,int i=0); ///< if the attribute to update is missing, a new one will be added
XMLAttribute *updateAttribute(XMLCSTR lpszNewValue, XMLCSTR lpszNewName,XMLCSTR lpszOldName);///< set lpszNewName=NULL if you don't want to change the name of the attribute if the attribute to update is missing, a new one will be added
XMLCSTR updateText(XMLCSTR lpszNewValue, int i=0); ///< if the text to update is missing, a new one will be added
XMLCSTR updateText(XMLCSTR lpszNewValue, XMLCSTR lpszOldValue); ///< if the text to update is missing, a new one will be added
XMLClear *updateClear(XMLCSTR lpszNewContent, int i=0); ///< if the clearTag to update is missing, a new one will be added
XMLClear *updateClear(XMLClear *newP,XMLClear *oldP); ///< if the clearTag to update is missing, a new one will be added
XMLClear *updateClear(XMLCSTR lpszNewValue, XMLCSTR lpszOldValue); ///< if the clearTag to update is missing, a new one will be added
/** @} */
/** @defgroup xmlDelete Deleting Nodes or Attributes
* @ingroup xmlModify
* Some deletion functions:
* @{
*/
/// The "deleteNodeContent" function forces the deletion of the content of this XMLNode and the subtree.
void deleteNodeContent();
/**< \note The XMLNode instances that are referring to the part of the subtree that has been deleted CANNOT be used anymore!!. Unexpected results will occur if you continue using them. */
void deleteAttribute(int i=0); ///< Delete the ith attribute of the current XMLNode
void deleteAttribute(XMLCSTR lpszName); ///< Delete the attribute with the given name (the "strcmp" function is used to find the right attribute)
void deleteAttribute(XMLAttribute *anAttribute); ///< Delete the attribute with the name "anAttribute->lpszName" (the "strcmp" function is used to find the right attribute)
void deleteText(int i=0); ///< Delete the Ith text content of the current XMLNode
void deleteText(XMLCSTR lpszValue); ///< Delete the text content "lpszValue" inside the current XMLNode (direct "pointer-to-pointer" comparison is used to find the right text)
void deleteClear(int i=0); ///< Delete the Ith clear tag inside the current XMLNode
void deleteClear(XMLCSTR lpszValue); ///< Delete the clear tag "lpszValue" inside the current XMLNode (direct "pointer-to-pointer" comparison is used to find the clear tag)
void deleteClear(XMLClear *p); ///< Delete the clear tag "p" inside the current XMLNode (direct "pointer-to-pointer" comparison on the lpszName of the clear tag is used to find the clear tag)
/** @} */
/** @defgroup xmlWOSD ???_WOSD functions.
* @ingroup xmlModify
* The strings given as parameters for the "add" and "update" methods that have a name with
* the postfix "_WOSD" (that means "WithOut String Duplication")(for example "addText_WOSD")
* will be free'd by the XMLNode class. For example, it means that this is incorrect:
* \code
* xNode.addText_WOSD("foo");
* xNode.updateAttribute_WOSD("#newcolor" ,NULL,"color");
* \endcode
* In opposition, this is correct:
* \code
* xNode.addText("foo");
* xNode.addText_WOSD(stringDup("foo"));
* xNode.updateAttribute("#newcolor" ,NULL,"color");
* xNode.updateAttribute_WOSD(stringDup("#newcolor"),NULL,"color");
* \endcode
* Typically, you will never do:
* \code
* char *b=(char*)malloc(...);
* xNode.addText(b);
* free(b);
* \endcode
* ... but rather:
* \code
* char *b=(char*)malloc(...);
* xNode.addText_WOSD(b);
* \endcode
* ('free(b)' is performed by the XMLNode class)
* @{ */
static XMLNode createXMLTopNode_WOSD(XMLSTR lpszName, char isDeclaration=FALSE); ///< Create the top node of an XMLNode structure
XMLNode addChild_WOSD(XMLSTR lpszName, char isDeclaration=FALSE, XMLElementPosition pos=-1); ///< Add a new child node
XMLAttribute *addAttribute_WOSD(XMLSTR lpszName, XMLSTR lpszValue); ///< Add a new attribute
XMLCSTR addText_WOSD(XMLSTR lpszValue, XMLElementPosition pos=-1); ///< Add a new text content
XMLClear *addClear_WOSD(XMLSTR lpszValue, XMLCSTR lpszOpen=NULL, XMLCSTR lpszClose=NULL, XMLElementPosition pos=-1); ///< Add a new clear Tag
XMLCSTR updateName_WOSD(XMLSTR lpszName); ///< change node's name
XMLAttribute *updateAttribute_WOSD(XMLAttribute *newAttribute, XMLAttribute *oldAttribute); ///< if the attribute to update is missing, a new one will be added
XMLAttribute *updateAttribute_WOSD(XMLSTR lpszNewValue, XMLSTR lpszNewName=NULL,int i=0); ///< if the attribute to update is missing, a new one will be added
XMLAttribute *updateAttribute_WOSD(XMLSTR lpszNewValue, XMLSTR lpszNewName,XMLCSTR lpszOldName); ///< set lpszNewName=NULL if you don't want to change the name of the attribute if the attribute to update is missing, a new one will be added
XMLCSTR updateText_WOSD(XMLSTR lpszNewValue, int i=0); ///< if the text to update is missing, a new one will be added
XMLCSTR updateText_WOSD(XMLSTR lpszNewValue, XMLCSTR lpszOldValue); ///< if the text to update is missing, a new one will be added
XMLClear *updateClear_WOSD(XMLSTR lpszNewContent, int i=0); ///< if the clearTag to update is missing, a new one will be added
XMLClear *updateClear_WOSD(XMLClear *newP,XMLClear *oldP); ///< if the clearTag to update is missing, a new one will be added
XMLClear *updateClear_WOSD(XMLSTR lpszNewValue, XMLCSTR lpszOldValue); ///< if the clearTag to update is missing, a new one will be added
/** @} */
/** @defgroup xmlPosition Position helper functions (use in conjunction with the update&add functions
* @ingroup xmlModify
* These are some useful functions when you want to insert a childNode, a text or a XMLClearTag in the
* middle (at a specified position) of a XMLNode tree already constructed. The value returned by these
* methods is to be used as last parameter (parameter 'pos') of addChild, addText or addClear.
* @{ */
XMLElementPosition positionOfText(int i=0) const;
XMLElementPosition positionOfText(XMLCSTR lpszValue) const;
XMLElementPosition positionOfClear(int i=0) const;
XMLElementPosition positionOfClear(XMLCSTR lpszValue) const;
XMLElementPosition positionOfClear(XMLClear *a) const;
XMLElementPosition positionOfChildNode(int i=0) const;
XMLElementPosition positionOfChildNode(XMLNode x) const;
XMLElementPosition positionOfChildNode(XMLCSTR name, int i=0) const; ///< return the position of the ith childNode with the specified name if (name==NULL) return the position of the ith childNode
/** @} */
/// Enumeration for XML character encoding.
typedef enum XMLCharEncoding
{
char_encoding_error=0,
char_encoding_UTF8=1,
char_encoding_legacy=2,
char_encoding_ShiftJIS=3,
char_encoding_GB2312=4,
char_encoding_Big5=5,
char_encoding_GBK=6 // this is actually the same as Big5
} XMLCharEncoding;
/** \addtogroup conversions
* @{ */
/// Sets the global options for the conversions
static char setGlobalOptions(XMLCharEncoding characterEncoding=XMLNode::char_encoding_UTF8, char guessWideCharChars=1,
char dropWhiteSpace=1, char removeCommentsInMiddleOfText=1);
/**< The "setGlobalOptions" function allows you to change four global parameters that affect string & file
* parsing. First of all, you most-probably will never have to change these 3 global parameters.
*
* @param guessWideCharChars If "guessWideCharChars"=1 and if this library is compiled in WideChar mode, then the
* XMLNode::parseFile and XMLNode::openFileHelper functions will test if the file contains ASCII
* characters. If this is the case, then the file will be loaded and converted in memory to
* WideChar before being parsed. If 0, no conversion will be performed.
*
* @param guessWideCharChars If "guessWideCharChars"=1 and if this library is compiled in ASCII/UTF8/char* mode, then the
* XMLNode::parseFile and XMLNode::openFileHelper functions will test if the file contains WideChar
* characters. If this is the case, then the file will be loaded and converted in memory to
* ASCII/UTF8/char* before being parsed. If 0, no conversion will be performed.
*
* @param characterEncoding This parameter is only meaningful when compiling in char* mode (multibyte character mode).
* In wchar_t* (wide char mode), this parameter is ignored. This parameter should be one of the
* three currently recognized encodings: XMLNode::encoding_UTF8, XMLNode::encoding_ascii,
* XMLNode::encoding_ShiftJIS.
*
* @param dropWhiteSpace In most situations, text fields containing only white spaces (and carriage returns)
* are useless. Even more, these "empty" text fields are annoying because they increase the
* complexity of the user's code for parsing. So, 99% of the time, it's better to drop
* the "empty" text fields. However The XML specification indicates that no white spaces
* should be lost when parsing the file. So to be perfectly XML-compliant, you should set
* dropWhiteSpace=0. A note of caution: if you set "dropWhiteSpace=0", the parser will be
* slower and your code will be more complex.
*
* @param removeCommentsInMiddleOfText To explain this parameter, let's consider this code:
* \code
* XMLNode x=XMLNode::parseString("<a>foo<!-- hello -->bar<!DOCTYPE world >chu</a>","a");
* \endcode
* If removeCommentsInMiddleOfText=0, then we will have:
* \code
* x.getText(0) -> "foo"
* x.getText(1) -> "bar"
* x.getText(2) -> "chu"
* x.getClear(0) --> "<!-- hello -->"
* x.getClear(1) --> "<!DOCTYPE world >"
* \endcode
* If removeCommentsInMiddleOfText=1, then we will have:
* \code
* x.getText(0) -> "foobar"
* x.getText(1) -> "chu"
* x.getClear(0) --> "<!DOCTYPE world >"
* \endcode
*
* \return "0" when there are no errors. If you try to set an unrecognized encoding then the return value will be "1" to signal an error.
*
* \note Sometime, it's useful to set "guessWideCharChars=0" to disable any conversion
* because the test to detect the file-type (ASCII/UTF8/char* or WideChar) may fail (rarely). */
/// Guess the character encoding of the string (ascii, utf8 or shift-JIS)
static XMLCharEncoding guessCharEncoding(void *buffer, int bufLen, char useXMLEncodingAttribute=1);
/**< The "guessCharEncoding" function try to guess the character encoding. You most-probably will never
* have to use this function. It then returns the appropriate value of the global parameter
* "characterEncoding" described in the XMLNode::setGlobalOptions. The guess is based on the content of a buffer of length
* "bufLen" bytes that contains the first bytes (minimum 25 bytes; 200 bytes is a good value) of the
* file to be parsed. The XMLNode::openFileHelper function is using this function to automatically compute
* the value of the "characterEncoding" global parameter. There are several heuristics used to do the
* guess. One of the heuristic is based on the "encoding" attribute. The original XML specifications
* forbids to use this attribute to do the guess but you can still use it if you set
* "useXMLEncodingAttribute" to 1 (this is the default behavior and the behavior of most parsers).
* If an inconsistency in the encoding is detected, then the return value is "0". */
/** @} */
private:
// these are functions and structures used internally by the XMLNode class (don't bother about them):
typedef struct XMLNodeDataTag // to allow shallow copy and "intelligent/smart" pointers (automatic delete):
{
XMLCSTR lpszName; // Element name (=NULL if root)
int nChild, // Number of child nodes
nText, // Number of text fields
nClear, // Number of Clear fields (comments)
nAttribute; // Number of attributes
char isDeclaration; // Whether node is an XML declaration - '<?xml ?>'
struct XMLNodeDataTag *pParent; // Pointer to parent element (=NULL if root)
XMLNode *pChild; // Array of child nodes
XMLCSTR *pText; // Array of text fields
XMLClear *pClear; // Array of clear fields
XMLAttribute *pAttribute; // Array of attributes
int *pOrder; // order of the child_nodes,text_fields,clear_fields
int ref_count; // for garbage collection (smart pointers)
} XMLNodeData;
XMLNodeData *d;
char parseClearTag(void *px, void *pa);
char maybeAddTxT(void *pa, XMLCSTR tokenPStr);
int ParseXMLElement(void *pXML);
void *addToOrder(int memInc, int *_pos, int nc, void *p, int size, XMLElementType xtype);
int indexText(XMLCSTR lpszValue) const;
int indexClear(XMLCSTR lpszValue) const;
XMLNode addChild_priv(int,XMLSTR,char,int);
XMLAttribute *addAttribute_priv(int,XMLSTR,XMLSTR);
XMLCSTR addText_priv(int,XMLSTR,int);
XMLClear *addClear_priv(int,XMLSTR,XMLCSTR,XMLCSTR,int);
void emptyTheNode(char force);
static inline XMLElementPosition findPosition(XMLNodeData *d, int index, XMLElementType xtype);
static int CreateXMLStringR(XMLNodeData *pEntry, XMLSTR lpszMarker, int nFormat);
static int removeOrderElement(XMLNodeData *d, XMLElementType t, int index);
static void exactMemory(XMLNodeData *d);
static int detachFromParent(XMLNodeData *d);
} XMLNode;
/// This structure is given by the function XMLNode::enumContents.
typedef struct XMLNodeContents
{
/// This dictates what's the content of the XMLNodeContent
enum XMLElementType etype;
/**< should be an union to access the appropriate data. Compiler does not allow union of object with constructor... too bad. */
XMLNode child;
XMLAttribute attrib;
XMLCSTR text;
XMLClear clear;
} XMLNodeContents;
/** @defgroup StringAlloc String Allocation/Free functions
* @ingroup xmlModify
* @{ */
/// Duplicate (copy in a new allocated buffer) the source string.
XMLDLLENTRY XMLSTR stringDup(XMLCSTR source, int cbData=-1);
/**< This is
* a very handy function when used with all the "XMLNode::*_WOSD" functions (\link xmlWOSD \endlink).
* @param cbData If !=0 then cbData is the number of chars to duplicate. New strings allocated with
* this function should be free'd using the "freeXMLString" function. */
/// to free the string allocated inside the "stringDup" function or the "createXMLString" function.
XMLDLLENTRY void freeXMLString(XMLSTR t); // {free(t);}
/** @} */
/** @defgroup atoX ato? like functions
* @ingroup XMLParserGeneral
* The "xmlto?" functions are equivalents to the atoi, atol, atof functions.
* The only difference is: If the variable "xmlString" is NULL, than the return value
* is "defautValue". These 6 functions are only here as "convenience" functions for the
* user (they are not used inside the XMLparser). If you don't need them, you can
* delete them without any trouble.
*
* @{ */
XMLDLLENTRY char xmltob(XMLCSTR xmlString,char defautValue=0);
XMLDLLENTRY int xmltoi(XMLCSTR xmlString,int defautValue=0);
XMLDLLENTRY long xmltol(XMLCSTR xmlString,long defautValue=0);
XMLDLLENTRY double xmltof(XMLCSTR xmlString,double defautValue=.0);
XMLDLLENTRY XMLCSTR xmltoa(XMLCSTR xmlString,XMLCSTR defautValue=_CXML(""));
XMLDLLENTRY XMLCHAR xmltoc(XMLCSTR xmlString,const XMLCHAR defautValue=_CXML('\0'));
/** @} */
/** @defgroup ToXMLStringTool Helper class to create XML files using "printf", "fprintf", "cout",... functions.
* @ingroup XMLParserGeneral
* @{ */
/// Helper class to create XML files using "printf", "fprintf", "cout",... functions.
/** The ToXMLStringTool class helps you creating XML files using "printf", "fprintf", "cout",... functions.
* The "ToXMLStringTool" class is processing strings so that all the characters
* &,",',<,> are replaced by their XML equivalent:
* \verbatim &amp;, &quot;, &apos;, &lt;, &gt; \endverbatim
* Using the "ToXMLStringTool class" and the "fprintf function" is THE most efficient
* way to produce VERY large XML documents VERY fast.
* \note If you are creating from scratch an XML file using the provided XMLNode class
* you must not use the "ToXMLStringTool" class (because the "XMLNode" class does the
* processing job for you during rendering).*/
typedef struct XMLDLLENTRY ToXMLStringTool
{
public:
ToXMLStringTool(): buf(NULL),buflen(0){}
~ToXMLStringTool();
void freeBuffer();///<call this function when you have finished using this object to release memory used by the internal buffer.
XMLSTR toXML(XMLCSTR source);///< returns a pointer to an internal buffer that contains a XML-encoded string based on the "source" parameter.
/** The "toXMLUnSafe" function is deprecated because there is a possibility of
* "destination-buffer-overflow". It converts the string
* "source" to the string "dest". */
static XMLSTR toXMLUnSafe(XMLSTR dest,XMLCSTR source); ///< deprecated: use "toXML" instead
static int lengthXMLString(XMLCSTR source); ///< deprecated: use "toXML" instead
private:
XMLSTR buf;
int buflen;
} ToXMLStringTool;
/** @} */
/** @defgroup XMLParserBase64Tool Helper class to include binary data inside XML strings using "Base64 encoding".
* @ingroup XMLParserGeneral
* @{ */
/// Helper class to include binary data inside XML strings using "Base64 encoding".
/** The "XMLParserBase64Tool" class allows you to include any binary data (images, sounds,...)
* into an XML document using "Base64 encoding". This class is completely
* separated from the rest of the xmlParser library and can be removed without any problem.
* To include some binary data into an XML file, you must convert the binary data into
* standard text (using "encode"). To retrieve the original binary data from the
* b64-encoded text included inside the XML file, use "decode". Alternatively, these
* functions can also be used to "encrypt/decrypt" some critical data contained inside
* the XML (it's not a strong encryption at all, but sometimes it can be useful). */
typedef struct XMLDLLENTRY XMLParserBase64Tool
{
public:
XMLParserBase64Tool(): buf(NULL),buflen(0){}
~XMLParserBase64Tool();
void freeBuffer();///< Call this function when you have finished using this object to release memory used by the internal buffer.
/**
* @param formatted If "formatted"=true, some space will be reserved for a carriage-return every 72 chars. */
static int encodeLength(int inBufLen, char formatted=0); ///< return the length of the base64 string that encodes a data buffer of size inBufLen bytes.
/**
* The "base64Encode" function returns a string containing the base64 encoding of "inByteLen" bytes
* from "inByteBuf". If "formatted" parameter is true, then there will be a carriage-return every 72 chars.
* The string will be free'd when the XMLParserBase64Tool object is deleted.
* All returned strings are sharing the same memory space. */
XMLSTR encode(unsigned char *inByteBuf, unsigned int inByteLen, char formatted=0); ///< returns a pointer to an internal buffer containing the base64 string containing the binary data encoded from "inByteBuf"
/// returns the number of bytes which will be decoded from "inString".
static unsigned int decodeSize(XMLCSTR inString, XMLError *xe=NULL);
/**
* The "decode" function returns a pointer to a buffer containing the binary data decoded from "inString"
* The output buffer will be free'd when the XMLParserBase64Tool object is deleted.
* All output buffer are sharing the same memory space.
* @param inString If "instring" is malformed, NULL will be returned */
unsigned char* decode(XMLCSTR inString, int *outByteLen=NULL, XMLError *xe=NULL); ///< returns a pointer to an internal buffer containing the binary data decoded from "inString"
/**
* decodes data from "inString" to "outByteBuf". You need to provide the size (in byte) of "outByteBuf"
* in "inMaxByteOutBuflen". If "outByteBuf" is not large enough or if data is malformed, then "FALSE"
* will be returned; otherwise "TRUE". */
static unsigned char decode(XMLCSTR inString, unsigned char *outByteBuf, int inMaxByteOutBuflen, XMLError *xe=NULL); ///< deprecated.
private:
void *buf;
int buflen;
void alloc(int newsize);
}XMLParserBase64Tool;
/** @} */
#undef XMLDLLENTRY
#endif

View File

@ -50,7 +50,7 @@ init_fove_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
get_dist_ids(Parfactors, DistIds0),
sort(DistIds0, DistIds),
get_observed_vars(AllAttVars, ObservedVars),
writeln(factors:Parfactors:'\n'),
writeln(parfactors:Parfactors:'\n'),
writeln(evidence:ObservedVars:'\n'),
create_lifted_network(Parfactors,ObservedVars,ParfactorList).
@ -139,11 +139,11 @@ get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
get_dists_parameters(DistIds, DistsParams),
writeln(distParams:DistsParams),
set_parfactors_params(ParfactorList, DistsParams),
get_query_vars(QueryVarsAtts, QueryVars),
writeln(queryVars:QueryVars), writeln(''),
get_dists_parameters(DistIds, DistsParams),
writeln(dists:DistsParams), writeln(''),
set_parfactors_params(ParfactorList, DistsParams),
run_lifted_solver(ParfactorList, QueryVars, Solutions).

View File

@ -6,19 +6,24 @@
********************************************************/
:- module(clpbn_horus,
[create_lifted_network/3,
create_ground_network/2,
[set_solver/1,
create_lifted_network/3,
create_ground_network/4,
set_parfactors_params/2,
set_bayes_net_params/2,
set_factors_params/2,
run_lifted_solver/3,
run_ground_solver/3,
set_extra_vars_info/2,
set_vars_information/2,
set_horus_flag/2,
free_parfactors/1,
free_bayesian_network/1
free_ground_network/1
]).
:- use_module(library(pfl),
[set_pfl_flag/2]).
patch_things_up :-
assert_static(clpbn_horus:set_horus_flag(_,_)).
@ -28,10 +33,18 @@ warning :-
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up) -> true ; warning.
set_solver(ve) :- set_pfl_flag(solver,ve).
set_solver(jt) :- set_pfl_flag(solver,jt).
set_solver(gibbs) :- set_pfl_flag(solver,gibbs).
set_solver(fove) :- set_pfl_flag(solver,fove).
set_solver(hve) :- set_pfl_flag(solver,bp), set_horus_flag(inf_alg, ve).
set_solver(bp) :- set_pfl_flag(solver,bp), set_horus_flag(inf_alg, bp).
set_solver(cbp) :- set_pfl_flag(solver,bp), set_horus_flag(inf_alg, cbp).
set_solver(S) :- throw(error('unknow solver ', S)).
%:- set_horus_flag(inf_alg, ve).
:- set_horus_flag(inf_alg, bn_bp).
%:- set_horus_flag(inf_alg, fg_bp).
%:- set_horus_flag(inf_alg, bp).
%: -set_horus_flag(inf_alg, cbp).
:- set_horus_flag(schedule, seq_fixed).
@ -46,7 +59,6 @@ warning :-
:- set_horus_flag(order_factor_variables, false).
%:- set_horus_flag(order_factor_variables, true).
:- set_horus_flag(use_logarithms, false).
% :- set_horus_flag(use_logarithms, true).

View File

@ -41,18 +41,19 @@
user:term_expansion( bayes((Formula ; Phi ; Constraints)), pfl:factor(bayes,Id,FList,FV,Phi,Constraints)) :-
!,
term_variables(Formula, FreeVars),
FV =.. [fv|FreeVars],
FV =.. [''|FreeVars],
new_id(Id),
process_args(Formula, Id, 0, _, FList, []).
user:term_expansion( markov((Formula ; Phi ; Constraints)), pfl:factor(markov,Id,FList,FV,Phi,Constraints)) :-
!,
term_variables(Formula, FreeVars),
FV =.. [fv|FreeVars],
FV =.. [''|FreeVars],
new_id(Id),
process_args(Formula, Id, 0, _, FList, []).
user:term_expansion( Id@N, L ) :-
atom(Id), number(N), !,
findall(G,generate_entity(0, N, Id, G), L).
N1 is N + 1,
findall(G,generate_entity(1, N1, Id, G), L).
user:term_expansion( Goal, [] ) :-
preprocess(Goal, Sk,Var), !,
(ground(Goal) -> true ; throw(error('non ground evidence',Goal))),
@ -78,7 +79,7 @@ defined_in_factor(Key, Factor) :-
generate_entity(N, N, _, _) :- !.
generate_entity(I0, _N, Id, T) :-
atomic_concat(person_, I0, P),
atomic_concat(p, I0, P),
T =.. [Id, P].
generate_entity(I0, N, Id, T) :-
I is I0+1,
@ -145,7 +146,7 @@ add_evidence(Sk,Var) :-
get_pfl_parameters(Id,Out) :-
factor(_Type,Id,_FList,_FV,Phi,_Constraints),
writeln(factor(_Type,Id,_FList,_FV,_Phi,_Constraints)),
%writeln(factor(_Type,Id,_FList,_FV,_Phi,_Constraints)),
( is_list(Phi) -> Out = Phi ; call(user:Phi, Out) ).