Merge branch 'master' of https://github.com/tacgomes/yap6.3
Conflicts: packages/CLPBN/clpbn/bp.yap packages/CLPBN/clpbn/fove.yap packages/CLPBN/clpbn/horus.yap
This commit is contained in:
commit
6ccd458ea5
@ -1,9 +1,9 @@
|
||||
|
||||
/************************************************
|
||||
/*******************************************************
|
||||
|
||||
Belief Propagation in CLP(BN)
|
||||
Belief Propagation and Variable Elimination Interface
|
||||
|
||||
**************************************************/
|
||||
********************************************************/
|
||||
|
||||
:- module(clpbn_bp,
|
||||
[bp/3,
|
||||
@ -23,18 +23,16 @@
|
||||
|
||||
|
||||
:- use_module(library('clpbn/display'),
|
||||
[clpbn_bind_vals/3]).
|
||||
[clpbn_bind_vals/3]).
|
||||
|
||||
|
||||
:- use_module(library('clpbn/aggregates'),
|
||||
[check_for_agg_vars/2]).
|
||||
[check_for_agg_vars/2]).
|
||||
|
||||
|
||||
:- use_module(library(clpbn/horus)).
|
||||
|
||||
:- use_module(library(atts)).
|
||||
:- use_module(library(lists)).
|
||||
:- use_module(library(charsio)).
|
||||
|
||||
:- attribute id/1.
|
||||
|
||||
@ -51,42 +49,50 @@
|
||||
|
||||
:- set_horus_flag(accuracy, 0.0001).
|
||||
|
||||
:- set_horus_flag(max_iter, 1000).
|
||||
:- use_module(library(charsio),
|
||||
[term_to_atom/2]).
|
||||
|
||||
:- set_horus_flag(use_logarithms, false).
|
||||
%:- set_horus_flag(use_logarithms, true).
|
||||
|
||||
:- set_horus_flag(order_factor_variables, false).
|
||||
%:- set_horus_flag(order_factor_variables, true).
|
||||
:- use_module(horus,
|
||||
[create_ground_network/2,
|
||||
set_bayes_net_params/2,
|
||||
run_ground_solver/3,
|
||||
set_extra_vars_info/2,
|
||||
free_bayesian_network/1
|
||||
]).
|
||||
|
||||
|
||||
:- attribute id/1.
|
||||
|
||||
|
||||
bp([[]],_,_) :- !.
|
||||
bp([QueryVars], AllVars, Output) :-
|
||||
init_bp_solver(_, AllVars, _, Network),
|
||||
run_bp_solver([QueryVars], LPs, Network),
|
||||
finalize_bp_solver(Network),
|
||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||
init_bp_solver(_, AllVars, _, Network),
|
||||
run_bp_solver([QueryVars], LPs, Network),
|
||||
finalize_bp_solver(Network),
|
||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||
|
||||
|
||||
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||
check_for_agg_vars(AllVars0, AllVars),
|
||||
writeln('clpbn_vars:'),
|
||||
print_clpbn_vars(AllVars),
|
||||
assign_ids(AllVars, 0),
|
||||
get_vars_info(AllVars, VarsInfo, DistIds0),
|
||||
sort(DistIds0, DistIds),
|
||||
create_ground_network(VarsInfo, BayesNet).
|
||||
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
||||
%set_extra_vars_info(BayesNet, ExtraVarsInfo).
|
||||
%writeln('init_bp_solver'),
|
||||
check_for_agg_vars(AllVars0, AllVars),
|
||||
%writeln('clpbn_vars:'), print_clpbn_vars(AllVars),
|
||||
assign_ids(AllVars, 0),
|
||||
get_vars_info(AllVars, VarsInfo, DistIds0),
|
||||
sort(DistIds0, DistIds),
|
||||
create_ground_network(VarsInfo, BayesNet),
|
||||
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
||||
%set_extra_vars_info(BayesNet, ExtraVarsInfo),
|
||||
%writeln(extravarsinfo:ExtraVarsInfo),
|
||||
true.
|
||||
|
||||
|
||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
%writeln('-> run_bp_solver'),
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
set_bayes_net_params(Network, DistsParams),
|
||||
flatten_1_element_sublists(QueryVars, QueryVars1),
|
||||
vars_to_ids(QueryVars1, QueryVarsIds),
|
||||
run_other_solvers(Network, QueryVarsIds, Solutions).
|
||||
vars_to_ids(QueryVars, QueryVarsIds),
|
||||
run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||
|
||||
|
||||
finalize_bp_solver(bp(Network, _)) :-
|
||||
@ -95,82 +101,75 @@ finalize_bp_solver(bp(Network, _)) :-
|
||||
|
||||
assign_ids([], _).
|
||||
assign_ids([V|Vs], Count) :-
|
||||
put_atts(V, [id(Count)]),
|
||||
Count1 is Count + 1,
|
||||
assign_ids(Vs, Count1).
|
||||
put_atts(V, [id(Count)]),
|
||||
Count1 is Count + 1,
|
||||
assign_ids(Vs, Count1).
|
||||
|
||||
|
||||
get_vars_info([], [], []).
|
||||
get_vars_info(V.Vs,
|
||||
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
|
||||
DistId.DistIds) :-
|
||||
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
|
||||
get_atts(V, [id(VarId)]),
|
||||
get_dist_domain_size(DistId, DS),
|
||||
get_evidence(V, Ev),
|
||||
vars_to_ids(Parents, PIds),
|
||||
get_vars_info(Vs, VarsInfo, DistIds).
|
||||
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
|
||||
DistId.DistIds) :-
|
||||
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
|
||||
get_atts(V, [id(VarId)]),
|
||||
get_dist_domain_size(DistId, DS),
|
||||
get_evidence(V, Ev),
|
||||
vars_to_ids(Parents, PIds),
|
||||
get_vars_info(Vs, VarsInfo, DistIds).
|
||||
|
||||
|
||||
get_evidence(V, Ev) :-
|
||||
clpbn:get_atts(V, [evidence(Ev)]), !.
|
||||
clpbn:get_atts(V, [evidence(Ev)]), !.
|
||||
get_evidence(_V, -1). % no evidence !!!
|
||||
|
||||
|
||||
vars_to_ids([], []).
|
||||
vars_to_ids([L|Vars], [LIds|Ids]) :-
|
||||
is_list(L), !,
|
||||
vars_to_ids(L, LIds),
|
||||
vars_to_ids(Vars, Ids).
|
||||
is_list(L), !,
|
||||
vars_to_ids(L, LIds),
|
||||
vars_to_ids(Vars, Ids).
|
||||
vars_to_ids([V|Vars], [VarId|Ids]) :-
|
||||
get_atts(V, [id(VarId)]),
|
||||
vars_to_ids(Vars, Ids).
|
||||
get_atts(V, [id(VarId)]),
|
||||
vars_to_ids(Vars, Ids).
|
||||
|
||||
|
||||
get_extra_vars_info([], []).
|
||||
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
|
||||
get_atts(V, [id(VarId)]), !,
|
||||
clpbn:get_atts(V, [key(Key),dist(DistId, _)]),
|
||||
term_to_atom(Key, Label),
|
||||
get_dist_domain(DistId, Domain0),
|
||||
numbers_to_atoms(Domain0, Domain),
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
get_atts(V, [id(VarId)]), !,
|
||||
clpbn:get_atts(V, [key(Key), dist(DistId, _)]),
|
||||
term_to_atom(Key, Label),
|
||||
get_dist_domain(DistId, Domain0),
|
||||
numbers_to_atoms(Domain0, Domain),
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
get_extra_vars_info([_|Vs], VarsInfo) :-
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
|
||||
|
||||
get_dists_parameters([],[]).
|
||||
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
||||
get_dist_params(Id, Params),
|
||||
get_dists_parameters(Ids, DistsInfo).
|
||||
get_dist_params(Id, Params),
|
||||
get_dists_parameters(Ids, DistsInfo).
|
||||
|
||||
|
||||
numbers_to_atoms([], []).
|
||||
numbers_to_atoms([Atom|L0], [Atom|L]) :-
|
||||
atom(Atom), !,
|
||||
numbers_to_atoms(L0, L).
|
||||
atom(Atom), !,
|
||||
numbers_to_atoms(L0, L).
|
||||
numbers_to_atoms([Number|L0], [Atom|L]) :-
|
||||
number_atom(Number, Atom),
|
||||
numbers_to_atoms(L0, L).
|
||||
|
||||
|
||||
flatten_1_element_sublists([],[]).
|
||||
flatten_1_element_sublists([[H|[]]|T],[H|R]) :- !,
|
||||
flatten_1_element_sublists(T,R).
|
||||
flatten_1_element_sublists([H|T],[H|R]) :-
|
||||
flatten_1_element_sublists(T,R).
|
||||
number_atom(Number, Atom),
|
||||
numbers_to_atoms(L0, L).
|
||||
|
||||
|
||||
print_clpbn_vars(Var.AllVars) :-
|
||||
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
|
||||
parents_to_keys(Parents, ParentKeys),
|
||||
writeln(Var:Key:ParentKeys:DistId),
|
||||
print_clpbn_vars(AllVars).
|
||||
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
|
||||
parents_to_keys(Parents, ParentKeys),
|
||||
writeln(Var:Key:ParentKeys:DistId),
|
||||
print_clpbn_vars(AllVars).
|
||||
print_clpbn_vars([]).
|
||||
|
||||
|
||||
parents_to_keys([], []).
|
||||
parents_to_keys(Var.Parents, Key.Keys) :-
|
||||
clpbn:get_atts(Var, [key(Key)]),
|
||||
parents_to_keys(Parents, Keys).
|
||||
clpbn:get_atts(Var, [key(Key)]),
|
||||
parents_to_keys(Parents, Keys).
|
||||
|
||||
|
@ -88,15 +88,22 @@ BayesNet::readFromBifFormat (const char* fileName)
|
||||
abort();
|
||||
}
|
||||
params = reorderParameters (params, node->nrStates());
|
||||
Distribution* dist = new Distribution (params);
|
||||
node->setDistribution (dist);
|
||||
addDistribution (dist);
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
node->setParams (params);
|
||||
}
|
||||
|
||||
setIndexes();
|
||||
if (Globals::logDomain) {
|
||||
distributionsToLogs();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (BayesNode* n)
|
||||
{
|
||||
varMap_.insert (make_pair (n->varId(), nodes_.size()));
|
||||
nodes_.push_back (n);
|
||||
return nodes_.back();
|
||||
}
|
||||
|
||||
|
||||
@ -114,15 +121,6 @@ BayesNet::addNode (string label, const States& states)
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (VarId vid, unsigned dsize, int evidence, Distribution* dist)
|
||||
{
|
||||
varMap_.insert (make_pair (vid, nodes_.size()));
|
||||
nodes_.push_back (new BayesNode (vid, dsize, evidence, dist));
|
||||
return nodes_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::getBayesNode (VarId vid) const
|
||||
@ -176,29 +174,6 @@ BayesNet::getVariableNodes (void) const
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::addDistribution (Distribution* dist)
|
||||
{
|
||||
dists_.push_back (dist);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Distribution*
|
||||
BayesNet::getDistribution (unsigned distId) const
|
||||
{
|
||||
Distribution* dist = 0;
|
||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
||||
if (dists_[i]->id == (int) distId) {
|
||||
dist = dists_[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const BnNodeSet&
|
||||
BayesNet::getBayesNodes (void) const
|
||||
{
|
||||
@ -299,7 +274,7 @@ BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
|
||||
/*
|
||||
cout << "\t\ttop\tbottom" << endl;
|
||||
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
|
||||
cout << "----------------------------------------------------------" ;
|
||||
Util::printDashedLine();
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << nodes_[i]->label() << ":\t\t" ;
|
||||
@ -350,10 +325,8 @@ BayesNet::constructGraph (BayesNet* bn,
|
||||
}
|
||||
}
|
||||
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
|
||||
BayesNode* mrnNode = bn->addNode (nodes_[i]->varId(),
|
||||
nodes_[i]->nrStates(),
|
||||
nodes_[i]->getEvidence(),
|
||||
nodes_[i]->getDistribution());
|
||||
BayesNode* mrnNode = new BayesNode (nodes_[i]);
|
||||
bn->addNode (mrnNode);
|
||||
mrnNodes.push_back (mrnNode);
|
||||
}
|
||||
}
|
||||
@ -388,26 +361,6 @@ BayesNet::setIndexes (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::distributionsToLogs (void)
|
||||
{
|
||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
||||
Util::toLog (dists_[i]->params);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::freeDistributions (void)
|
||||
{
|
||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
||||
delete dists_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::printGraphicalModel (void) const
|
||||
{
|
||||
@ -504,8 +457,8 @@ BayesNet::exportToBifFormat (const char* fileName) const
|
||||
out << "\t<GIVEN>" << parents[j]->label();
|
||||
out << "</GIVEN>" << endl;
|
||||
}
|
||||
Params params = revertParameterReorder (nodes_[i]->getParameters(),
|
||||
nodes_[i]->nrStates());
|
||||
Params params = revertParameterReorder (
|
||||
nodes_[i]->params(), nodes_[i]->nrStates());
|
||||
out << "\t<TABLE>" ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << " " << params[j];
|
||||
|
@ -13,16 +13,11 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Distribution;
|
||||
|
||||
struct ScheduleInfo
|
||||
{
|
||||
ScheduleInfo (BayesNode* n, bool vfp, bool vfc)
|
||||
{
|
||||
node = n;
|
||||
visitedFromParent = vfp;
|
||||
visitedFromChild = vfc;
|
||||
}
|
||||
ScheduleInfo (BayesNode* n, bool vfp, bool vfc) :
|
||||
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
BayesNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
@ -31,70 +26,84 @@ struct ScheduleInfo
|
||||
|
||||
struct StateInfo
|
||||
{
|
||||
StateInfo (void)
|
||||
{
|
||||
visited = true;
|
||||
markedOnTop = false;
|
||||
markedOnBottom = false;
|
||||
}
|
||||
StateInfo (void) : visited(false), markedOnTop(false),
|
||||
markedOnBottom(false) { }
|
||||
bool visited;
|
||||
bool markedOnTop;
|
||||
bool markedOnBottom;
|
||||
};
|
||||
|
||||
typedef vector<Distribution*> DistSet;
|
||||
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
||||
|
||||
|
||||
class BayesNet : public GraphicalModel
|
||||
{
|
||||
public:
|
||||
BayesNet (void) {};
|
||||
BayesNet (void) { };
|
||||
|
||||
~BayesNet (void);
|
||||
|
||||
void readFromBifFormat (const char*);
|
||||
BayesNode* addNode (string, const States&);
|
||||
// BayesNode* addNode (VarId, unsigned, int, BnNodeSet&, Distribution*);
|
||||
BayesNode* addNode (VarId, unsigned, int, Distribution*);
|
||||
BayesNode* getBayesNode (VarId) const;
|
||||
BayesNode* getBayesNode (string) const;
|
||||
VarNode* getVariableNode (VarId) const;
|
||||
VarNodes getVariableNodes (void) const;
|
||||
void addDistribution (Distribution*);
|
||||
Distribution* getDistribution (unsigned) const;
|
||||
const BnNodeSet& getBayesNodes (void) const;
|
||||
unsigned nrNodes (void) const;
|
||||
BnNodeSet getRootNodes (void) const;
|
||||
BnNodeSet getLeafNodes (void) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (VarId) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
|
||||
void constructGraph (
|
||||
BayesNet*, const vector<StateInfo*>&) const;
|
||||
bool isPolyTree (void) const;
|
||||
void setIndexes (void);
|
||||
void distributionsToLogs (void);
|
||||
void freeDistributions (void);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
void exportToBifFormat (const char*) const;
|
||||
void readFromBifFormat (const char*);
|
||||
|
||||
BayesNode* addNode (BayesNode*);
|
||||
|
||||
BayesNode* addNode (string, const States&);
|
||||
|
||||
BayesNode* getBayesNode (VarId) const;
|
||||
|
||||
BayesNode* getBayesNode (string) const;
|
||||
|
||||
VarNode* getVariableNode (VarId) const;
|
||||
|
||||
VarNodes getVariableNodes (void) const;
|
||||
|
||||
const BnNodeSet& getBayesNodes (void) const;
|
||||
|
||||
unsigned nrNodes (void) const;
|
||||
|
||||
BnNodeSet getRootNodes (void) const;
|
||||
|
||||
BnNodeSet getLeafNodes (void) const;
|
||||
|
||||
BayesNet* getMinimalRequesiteNetwork (VarId) const;
|
||||
|
||||
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
|
||||
|
||||
void constructGraph (BayesNet*, const vector<StateInfo*>&) const;
|
||||
|
||||
bool isPolyTree (void) const;
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void printGraphicalModel (void) const;
|
||||
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
|
||||
void exportToBifFormat (const char*) const;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BayesNet);
|
||||
|
||||
bool containsUndirectedCycle (void) const;
|
||||
bool containsUndirectedCycle (int, int, vector<bool>&)const;
|
||||
vector<int> getAdjacentNodes (int) const;
|
||||
Params reorderParameters (const Params&, unsigned) const;
|
||||
Params revertParameterReorder (const Params&, unsigned) const;
|
||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
||||
bool containsUndirectedCycle (void) const;
|
||||
|
||||
BnNodeSet nodes_;
|
||||
DistSet dists_;
|
||||
bool containsUndirectedCycle (int, int, vector<bool>&)const;
|
||||
|
||||
vector<int> getAdjacentNodes (int) const;
|
||||
|
||||
Params reorderParameters (const Params&, unsigned) const;
|
||||
|
||||
Params revertParameterReorder (const Params&, unsigned) const;
|
||||
|
||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
||||
|
||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
||||
|
||||
BnNodeSet nodes_;
|
||||
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
IndexMap varMap_;
|
||||
IndexMap varMap_;
|
||||
};
|
||||
|
||||
|
||||
|
@ -8,29 +8,10 @@
|
||||
#include "BayesNode.h"
|
||||
|
||||
|
||||
BayesNode::BayesNode (VarId vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
Distribution* dist)
|
||||
: VarNode (vid, dsize, evidence)
|
||||
void
|
||||
BayesNode::setParams (const Params& params)
|
||||
{
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode::BayesNode (VarId vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
const BnNodeSet& parents,
|
||||
Distribution* dist)
|
||||
: VarNode (vid, dsize, evidence)
|
||||
{
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
parents[i]->addChild (this);
|
||||
}
|
||||
params_ = params;
|
||||
}
|
||||
|
||||
|
||||
@ -54,31 +35,6 @@ BayesNode::addChild (BayesNode* node)
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNode::setDistribution (Distribution* dist)
|
||||
{
|
||||
assert (dist);
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Distribution*
|
||||
BayesNode::getDistribution (void)
|
||||
{
|
||||
return dist_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const Params&
|
||||
BayesNode::getParameters (void)
|
||||
{
|
||||
return dist_->params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BayesNode::getRow (int rowIndex) const
|
||||
{
|
||||
@ -86,7 +42,7 @@ BayesNode::getRow (int rowIndex) const
|
||||
int offset = rowSize * rowIndex;
|
||||
Params row (rowSize);
|
||||
for (int i = 0; i < rowSize; i++) {
|
||||
row[i] = dist_->params[offset + i] ;
|
||||
row[i] = params_[offset + i] ;
|
||||
}
|
||||
return row;
|
||||
}
|
||||
@ -119,13 +75,13 @@ BayesNode::hasNeighbors (void) const
|
||||
int
|
||||
BayesNode::getCptSize (void)
|
||||
{
|
||||
return dist_->params.size();
|
||||
return params_.size();
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
BayesNode::getIndexOfParent (const BayesNode* parent) const
|
||||
BayesNode::indexOfParent (const BayesNode* parent) const
|
||||
{
|
||||
for (unsigned int i = 0; i < parents_.size(); i++) {
|
||||
if (parents_[i] == parent) {
|
||||
|
@ -4,7 +4,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "Distribution.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
@ -13,49 +12,70 @@ using namespace std;
|
||||
class BayesNode : public VarNode
|
||||
{
|
||||
public:
|
||||
BayesNode (const VarNode& v) : VarNode (v) {}
|
||||
BayesNode (VarId, unsigned, int, Distribution*);
|
||||
BayesNode (VarId, unsigned, int, const BnNodeSet&, Distribution*);
|
||||
|
||||
void setParents (const BnNodeSet&);
|
||||
void addChild (BayesNode*);
|
||||
void setDistribution (Distribution*);
|
||||
Distribution* getDistribution (void);
|
||||
const Params& getParameters (void);
|
||||
Params getRow (int) const;
|
||||
bool isRoot (void);
|
||||
bool isLeaf (void);
|
||||
bool hasNeighbors (void) const;
|
||||
int getCptSize (void);
|
||||
int getIndexOfParent (const BayesNode*) const;
|
||||
string cptEntryToString (int, const vector<unsigned>&) const;
|
||||
BayesNode (const VarNode& v) : VarNode (v) { }
|
||||
|
||||
const BnNodeSet& getParents (void) const { return parents_; }
|
||||
const BnNodeSet& getChilds (void) const { return childs_; }
|
||||
BayesNode (const BayesNode* n) :
|
||||
VarNode (n->varId(), n->nrStates(), n->getEvidence()),
|
||||
params_(n->params()), distId_(n->distId()) { }
|
||||
|
||||
BayesNode (VarId vid, unsigned nrStates, int ev,
|
||||
const Params& ps, unsigned id)
|
||||
: VarNode (vid, nrStates, ev) , params_(ps), distId_(id) { }
|
||||
|
||||
const BnNodeSet& getParents (void) const { return parents_; }
|
||||
|
||||
const BnNodeSet& getChilds (void) const { return childs_; }
|
||||
|
||||
const Params& params (void) const { return params_; }
|
||||
|
||||
unsigned distId (void) const { return distId_; }
|
||||
|
||||
unsigned getRowSize (void) const
|
||||
{
|
||||
return dist_->params.size() / nrStates();
|
||||
return params_.size() / nrStates();
|
||||
}
|
||||
|
||||
double getProbability (int row, unsigned col)
|
||||
{
|
||||
int idx = (row * getRowSize()) + col;
|
||||
return dist_->params[idx];
|
||||
return params_[idx];
|
||||
}
|
||||
|
||||
void setParams (const Params& params);
|
||||
|
||||
void setParents (const BnNodeSet&);
|
||||
|
||||
void addChild (BayesNode*);
|
||||
|
||||
const Params& getParameters (void);
|
||||
|
||||
Params getRow (int) const;
|
||||
|
||||
bool isRoot (void);
|
||||
|
||||
bool isLeaf (void);
|
||||
|
||||
bool hasNeighbors (void) const;
|
||||
|
||||
int getCptSize (void);
|
||||
|
||||
int indexOfParent (const BayesNode*) const;
|
||||
|
||||
string cptEntryToString (int, const vector<unsigned>&) const;
|
||||
|
||||
friend ostream& operator << (ostream&, const BayesNode&);
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
||||
|
||||
States getDomainHeaders (void) const;
|
||||
friend ostream& operator << (ostream&, const BayesNode&);
|
||||
States getDomainHeaders (void) const;
|
||||
|
||||
BnNodeSet parents_;
|
||||
BnNodeSet childs_;
|
||||
Distribution* dist_;
|
||||
BnNodeSet parents_;
|
||||
BnNodeSet childs_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
};
|
||||
|
||||
ostream& operator << (ostream&, const BayesNode&);
|
||||
|
||||
#endif // HORUS_BAYESNODE_H
|
||||
|
||||
|
@ -34,12 +34,12 @@ void
|
||||
BnBpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (COLLECT_STATISTICS) {
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
initializeSolver();
|
||||
runLoopySolver();
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Belief propagation converged in " ;
|
||||
@ -51,18 +51,13 @@ BnBpSolver::runSolver (void)
|
||||
}
|
||||
|
||||
unsigned size = bayesNet_->nrNodes();
|
||||
if (COLLECT_STATISTICS) {
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = bayesNet_->isPolyTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
||||
stringstream ss;
|
||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
||||
bayesNet_->exportToGraphViz (ss.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -80,7 +75,7 @@ BnBpSolver::getPosterioriOf (VarId vid)
|
||||
Params
|
||||
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
{
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "calculating joint distribution on: " ;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
|
||||
@ -112,7 +107,7 @@ BnBpSolver::initializeSolver (void)
|
||||
|
||||
BnNodeSet roots = bayesNet_->getRootNodes();
|
||||
for (unsigned i = 0; i < roots.size(); i++) {
|
||||
const Params& params = roots[i]->getParameters();
|
||||
const Params& params = roots[i]->params();
|
||||
Params& piVals = ninf(roots[i])->getPiValues();
|
||||
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
|
||||
piVals[ri] = params[ri];
|
||||
@ -143,11 +138,11 @@ BnBpSolver::initializeSolver (void)
|
||||
Params& piVals = ninf(nodes[i])->getPiValues();
|
||||
Params& ldVals = ninf(nodes[i])->getLambdaValues();
|
||||
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
|
||||
piVals[xi] = Util::noEvidence();
|
||||
ldVals[xi] = Util::noEvidence();
|
||||
piVals[xi] = LogAware::noEvidence();
|
||||
ldVals[xi] = LogAware::noEvidence();
|
||||
}
|
||||
piVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
||||
ldVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
||||
piVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
|
||||
ldVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -161,13 +156,8 @@ BnBpSolver::runLoopySolver()
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_++;
|
||||
if (DL >= 2) {
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIters_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader ("Iteration " + nIters_);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
@ -199,7 +189,7 @@ BnBpSolver::runLoopySolver()
|
||||
break;
|
||||
|
||||
}
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
@ -228,7 +218,7 @@ BnBpSolver::converged (void) const
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << links_[i]->toString() + " residual change = " ;
|
||||
cout << residual << endl;
|
||||
}
|
||||
@ -256,7 +246,7 @@ BnBpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
@ -300,9 +290,8 @@ BnBpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
}
|
||||
|
||||
if (DL >= 2) {
|
||||
cout << "----------------------------------------" ;
|
||||
cout << "----------------------------------------" << endl;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printDashedLine();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -313,7 +302,7 @@ void
|
||||
BnBpSolver::updatePiValues (BayesNode* x)
|
||||
{
|
||||
// π(Xi)
|
||||
if (DL >= 3) {
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
|
||||
}
|
||||
Params& piValues = ninf(x)->getPiValues();
|
||||
@ -329,11 +318,11 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
||||
|
||||
Params messageProducts (indexer.size());
|
||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double messageProduct = Util::multIdenty();
|
||||
double messageProduct = LogAware::multIdenty();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
||||
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
|
||||
@ -341,7 +330,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
||||
} else {
|
||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
||||
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (i != 0) *calcs1 << " + " ;
|
||||
if (i != 0) *calcs2 << " + " ;
|
||||
*calcs1 << parentLinks[i]->toString (indexer[i]);
|
||||
@ -350,7 +339,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
||||
}
|
||||
}
|
||||
messageProducts[k] = messageProduct;
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " mp" << k;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (parentLinks.size() == 1) {
|
||||
@ -366,27 +355,27 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
double sum = Util::addIdenty();
|
||||
if (DL >= 5) {
|
||||
double sum = LogAware::addIdenty();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
indexer.reset();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
||||
Util::logSum (sum,
|
||||
x->getProbability(xi, indexer.linearIndex()) + messageProducts[k]);
|
||||
sum = Util::logSum (sum,
|
||||
x->getProbability(xi, indexer) + messageProducts[k]);
|
||||
++ indexer;
|
||||
}
|
||||
} else {
|
||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
||||
sum += x->getProbability (xi, indexer.linearIndex()) * messageProducts[k];
|
||||
if (DL >= 5) {
|
||||
sum += x->getProbability (xi, indexer) * messageProducts[k];
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (k != 0) *calcs1 << " + " ;
|
||||
if (k != 0) *calcs2 << " + " ;
|
||||
*calcs1 << x->cptEntryToString (xi, indexer.indices());
|
||||
*calcs1 << ".mp" << k;
|
||||
*calcs2 << Util::fl (x->getProbability (xi, indexer.linearIndex()));
|
||||
*calcs2 << LogAware::fl (x->getProbability (xi, indexer));
|
||||
*calcs2 << "*" << messageProducts[k];
|
||||
}
|
||||
++ indexer;
|
||||
@ -394,7 +383,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
||||
}
|
||||
|
||||
piValues[xi] = sum;
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
|
||||
cout << "[" << x->states()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
@ -412,7 +401,7 @@ void
|
||||
BnBpSolver::updateLambdaValues (BayesNode* x)
|
||||
{
|
||||
// λ(Xi)
|
||||
if (DL >= 3) {
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
|
||||
}
|
||||
Params& lambdaValues = ninf(x)->getLambdaValues();
|
||||
@ -421,11 +410,11 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double product = Util::multIdenty();
|
||||
double product = LogAware::multIdenty();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
||||
product += childLinks[i]->getMessage()[xi];
|
||||
@ -433,7 +422,7 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
|
||||
} else {
|
||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
||||
product *= childLinks[i]->getMessage()[xi];
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (i != 0) *calcs1 << "." ;
|
||||
if (i != 0) *calcs2 << "*" ;
|
||||
*calcs1 << childLinks[i]->toString (xi);
|
||||
@ -442,7 +431,7 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
|
||||
}
|
||||
}
|
||||
lambdaValues[xi] = product;
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
|
||||
cout << "[" << x->states()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
@ -474,7 +463,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
|
||||
const Params& zPiValues = ninf(z)->getPiValues();
|
||||
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
|
||||
double product = zPiValues[zi];
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
|
||||
@ -491,7 +480,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
|
||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
||||
if (zChildLinks[i]->getSource() != x) {
|
||||
product *= zChildLinks[i]->getMessage()[zi];
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
*calcs1 << "." << zChildLinks[i]->toString (zi);
|
||||
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
||||
}
|
||||
@ -499,7 +488,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
|
||||
}
|
||||
}
|
||||
zxPiNextMessage[zi] = product;
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " " << link->toString();
|
||||
cout << "[" << z->states()[zi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
@ -513,7 +502,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
Util::normalize (zxPiNextMessage);
|
||||
LogAware::normalize (zxPiNextMessage);
|
||||
}
|
||||
|
||||
|
||||
@ -527,10 +516,10 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
if (x->hasEvidence()) {
|
||||
return;
|
||||
}
|
||||
Params& yxLambdaNextMessage = link->getNextMessage();
|
||||
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
|
||||
const Params& yLambdaValues = ninf(y)->getLambdaValues();
|
||||
int parentIndex = y->getIndexOfParent (x);
|
||||
Params& yxLambdaNextMessage = link->getNextMessage();
|
||||
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
|
||||
const Params& yLambdaValues = ninf(y)->getLambdaValues();
|
||||
int parentIndex = y->indexOfParent (x);
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
@ -548,11 +537,11 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
while (indexer[parentIndex] != 0) {
|
||||
++ indexer;
|
||||
}
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double messageProduct = Util::multIdenty();
|
||||
double messageProduct = LogAware::multIdenty();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
||||
if (yParentLinks[i]->getSource() != x) {
|
||||
@ -562,9 +551,9 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
} else {
|
||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
||||
if (yParentLinks[i]->getSource() != x) {
|
||||
if (DL >= 5) {
|
||||
if (messageProduct != Util::multIdenty()) *calcs1 << "*" ;
|
||||
if (messageProduct != Util::multIdenty()) *calcs2 << "*" ;
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (messageProduct != LogAware::multIdenty()) *calcs1 << "*" ;
|
||||
if (messageProduct != LogAware::multIdenty()) *calcs2 << "*" ;
|
||||
*calcs1 << yParentLinks[i]->toString (indexer[i]);
|
||||
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
|
||||
}
|
||||
@ -574,7 +563,7 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
}
|
||||
messageProducts[k] = messageProduct;
|
||||
++ indexer;
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " mp" << k;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (yParentLinks.size() == 1) {
|
||||
@ -591,55 +580,54 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double outerSum = Util::addIdenty();
|
||||
double outerSum = LogAware::addIdenty();
|
||||
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
||||
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
||||
}
|
||||
double innerSum = Util::addIdenty();
|
||||
double innerSum = LogAware::addIdenty();
|
||||
indexer.reset();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned k = 0; k < N; k++) {
|
||||
while (indexer[parentIndex] != xi) {
|
||||
++ indexer;
|
||||
}
|
||||
Util::logSum (innerSum, y->getProbability (
|
||||
yi, indexer.linearIndex()) + messageProducts[k]);
|
||||
innerSum = Util::logSum (innerSum,
|
||||
y->getProbability (yi, indexer) + messageProducts[k]);
|
||||
++ indexer;
|
||||
}
|
||||
Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
|
||||
outerSum = Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
|
||||
} else {
|
||||
for (unsigned k = 0; k < N; k++) {
|
||||
while (indexer[parentIndex] != xi) {
|
||||
++ indexer;
|
||||
}
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
if (k != 0) *calcs1 << " + " ;
|
||||
if (k != 0) *calcs2 << " + " ;
|
||||
*calcs1 << y->cptEntryToString (yi, indexer.indices());
|
||||
*calcs1 << ".mp" << k;
|
||||
*calcs2 << y->getProbability (yi, indexer.linearIndex());
|
||||
*calcs2 << y->getProbability (yi, indexer);
|
||||
*calcs2 << "*" << messageProducts[k];
|
||||
}
|
||||
innerSum += y->getProbability (
|
||||
yi, indexer.linearIndex()) * messageProducts[k];
|
||||
innerSum += y->getProbability (yi, indexer) * messageProducts[k];
|
||||
++ indexer;
|
||||
}
|
||||
outerSum += innerSum * yLambdaValues[yi];
|
||||
}
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
|
||||
*calcs1 << "[" << y->states()[yi] << "]";
|
||||
*calcs2 << "}*" << yLambdaValues[yi];
|
||||
}
|
||||
}
|
||||
yxLambdaNextMessage[xi] = outerSum;
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " " << link->toString();
|
||||
cout << "[" << x->states()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
@ -649,7 +637,7 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
Util::normalize (yxLambdaNextMessage);
|
||||
LogAware::normalize (yxLambdaNextMessage);
|
||||
}
|
||||
|
||||
|
||||
@ -674,7 +662,7 @@ BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||
assert (jointVars[i]->hasEvidence() == false);
|
||||
VarIds reqVars = {jointVarIds[i]};
|
||||
reqVars.insert (reqVars.end(), observedVids.begin(), observedVids.end());
|
||||
Util::addToVector (reqVars, observedVids);
|
||||
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
|
||||
Params newBeliefs;
|
||||
VarNodes observedVars;
|
||||
@ -720,8 +708,7 @@ BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
||||
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
|
||||
cout << setw (16) << "belief" ;
|
||||
cout << endl;
|
||||
cout << "--------------------------------" ;
|
||||
cout << "--------------------------------" ;
|
||||
Util::printDashedLine();
|
||||
cout << endl;
|
||||
const States& states = var->states();
|
||||
const Params& piVals = ninf(var)->getPiValues();
|
||||
@ -731,7 +718,7 @@ BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
||||
cout << setw (10) << states[xi];
|
||||
cout << setw (19) << piVals[xi];
|
||||
cout << setw (19) << ldVals[xi];
|
||||
cout.precision (PRECISION);
|
||||
cout.precision (Constants::PRECISION);
|
||||
cout << setw (16) << beliefs[xi];
|
||||
cout << endl;
|
||||
}
|
||||
@ -754,8 +741,8 @@ BnBpSolver::printAllMessageStatus (void) const
|
||||
BpNodeInfo::BpNodeInfo (BayesNode* node)
|
||||
{
|
||||
node_ = node;
|
||||
piVals_.resize (node->nrStates(), Util::one());
|
||||
ldVals_.resize (node->nrStates(), Util::one());
|
||||
piVals_.resize (node->nrStates(), LogAware::one());
|
||||
ldVals_.resize (node->nrStates(), LogAware::one());
|
||||
}
|
||||
|
||||
|
||||
|
@ -27,11 +27,11 @@ class BpLink
|
||||
destin_ = d;
|
||||
orientation_ = o;
|
||||
if (orientation_ == LinkOrientation::DOWN) {
|
||||
v1_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
|
||||
v2_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
|
||||
v1_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
|
||||
v2_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
|
||||
} else {
|
||||
v1_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
|
||||
v2_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
|
||||
v1_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
|
||||
v2_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
|
||||
}
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
@ -39,6 +39,22 @@ class BpLink
|
||||
msgSended_ = false;
|
||||
}
|
||||
|
||||
BayesNode* getSource (void) const { return source_; }
|
||||
|
||||
BayesNode* getDestination (void) const { return destin_; }
|
||||
|
||||
LinkOrientation getOrientation (void) const { return orientation_; }
|
||||
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
|
||||
Params& getNextMessage (void) { return *nextMsg_;}
|
||||
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
|
||||
double getResidual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void) { residual_ = 0;}
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
@ -47,7 +63,7 @@ class BpLink
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
@ -74,29 +90,19 @@ class BpLink
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
BayesNode* getSource (void) const { return source_; }
|
||||
BayesNode* getDestination (void) const { return destin_; }
|
||||
LinkOrientation getOrientation (void) const { return orientation_; }
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
Params& getNextMessage (void) { return *nextMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0;}
|
||||
|
||||
private:
|
||||
BayesNode* source_;
|
||||
BayesNode* destin_;
|
||||
LinkOrientation orientation_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
|
||||
typedef vector<BpLink*> BpLinkSet;
|
||||
|
||||
|
||||
@ -105,32 +111,41 @@ class BpNodeInfo
|
||||
public:
|
||||
BpNodeInfo (BayesNode*);
|
||||
|
||||
Params getBeliefs (void) const;
|
||||
bool receivedBottomInfluence (void) const;
|
||||
Params& getPiValues (void) { return piVals_; }
|
||||
|
||||
Params& getPiValues (void) { return piVals_; }
|
||||
Params& getLambdaValues (void) { return ldVals_; }
|
||||
Params& getLambdaValues (void) { return ldVals_; }
|
||||
|
||||
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
|
||||
|
||||
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
|
||||
|
||||
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
|
||||
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
|
||||
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
|
||||
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
|
||||
|
||||
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
|
||||
|
||||
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
|
||||
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
|
||||
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
|
||||
|
||||
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
|
||||
|
||||
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
|
||||
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
|
||||
|
||||
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
|
||||
|
||||
Params getBeliefs (void) const;
|
||||
|
||||
bool receivedBottomInfluence (void) const;
|
||||
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
|
||||
|
||||
const BayesNode* node_;
|
||||
Params piVals_; // pi values
|
||||
Params ldVals_; // lambda values
|
||||
BpLinkSet inParentLinks_;
|
||||
BpLinkSet inChildLinks_;
|
||||
BpLinkSet outParentLinks_;
|
||||
BpLinkSet outChildLinks_;
|
||||
const BayesNode* node_;
|
||||
Params piVals_;
|
||||
Params ldVals_;
|
||||
BpLinkSet inParentLinks_;
|
||||
BpLinkSet inChildLinks_;
|
||||
BpLinkSet outParentLinks_;
|
||||
BpLinkSet outChildLinks_;
|
||||
};
|
||||
|
||||
|
||||
@ -139,32 +154,43 @@ class BnBpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
BnBpSolver (const BayesNet&);
|
||||
|
||||
~BnBpSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
Params getPosterioriOf (VarId);
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
void runSolver (void);
|
||||
Params getPosterioriOf (VarId);
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
|
||||
|
||||
void initializeSolver (void);
|
||||
void runLoopySolver (void);
|
||||
void maxResidualSchedule (void);
|
||||
bool converged (void) const;
|
||||
void updatePiValues (BayesNode*);
|
||||
void updateLambdaValues (BayesNode*);
|
||||
void calculateLambdaMessage (BpLink*);
|
||||
void calculatePiMessage (BpLink*);
|
||||
Params getJointByJunctionNode (const VarIds&);
|
||||
Params getJointByConditioning (const VarIds&) const;
|
||||
void printPiLambdaValues (const BayesNode*) const;
|
||||
void printAllMessageStatus (void) const;
|
||||
void initializeSolver (void);
|
||||
|
||||
void runLoopySolver (void);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
|
||||
bool converged (void) const;
|
||||
|
||||
void updatePiValues (BayesNode*);
|
||||
|
||||
void updateLambdaValues (BayesNode*);
|
||||
|
||||
void calculateLambdaMessage (BpLink*);
|
||||
|
||||
void calculatePiMessage (BpLink*);
|
||||
|
||||
Params getJointByJunctionNode (const VarIds&);
|
||||
|
||||
Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
void printPiLambdaValues (const BayesNode*) const;
|
||||
|
||||
void printAllMessageStatus (void) const;
|
||||
|
||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||
@ -180,7 +206,7 @@ class BnBpSolver : public Solver
|
||||
|
||||
void calculateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||
@ -195,7 +221,7 @@ class BnBpSolver : public Solver
|
||||
|
||||
void updateMessage (BpLink* link)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
link->updateMessage();
|
||||
|
@ -1,7 +1,6 @@
|
||||
|
||||
#include "CFactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Distribution.h"
|
||||
|
||||
|
||||
bool CFactorGraph::checkForIdenticalFactors = true;
|
||||
@ -73,27 +72,34 @@ CFactorGraph::setInitialColors (void)
|
||||
|
||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
||||
if (checkForIdenticalFactors) {
|
||||
for (unsigned i = 0, s = facNodes.size(); i < s; i++) {
|
||||
Distribution* dist1 = facNodes[i]->getDistribution();
|
||||
for (unsigned j = 0; j < i; j++) {
|
||||
Distribution* dist2 = facNodes[j]->getDistribution();
|
||||
if (dist1 != dist2 && dist1->params == dist2->params) {
|
||||
if (facNodes[i]->factor()->getRanges() ==
|
||||
facNodes[j]->factor()->getRanges()) {
|
||||
facNodes[i]->factor()->setDistribution (dist2);
|
||||
}
|
||||
unsigned groupCount = 1;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
Factor* f1 = facNodes[i]->factor();
|
||||
if (f1->distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
f1->setDistId (groupCount);
|
||||
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
||||
Factor* f2 = facNodes[j]->factor();
|
||||
if (f2->distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
if (f1->size() == f2->size() &&
|
||||
f1->ranges() == f2->ranges() &&
|
||||
f1->params() == f2->params()) {
|
||||
f2->setDistId (groupCount);
|
||||
}
|
||||
}
|
||||
groupCount ++;
|
||||
}
|
||||
}
|
||||
|
||||
// create the initial factor colors
|
||||
DistColorMap distColors;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const Distribution* dist = facNodes[i]->getDistribution();
|
||||
DistColorMap::iterator it = distColors.find (dist);
|
||||
unsigned distId = facNodes[i]->factor()->distId();
|
||||
DistColorMap::iterator it = distColors.find (distId);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (dist, getFreeColor())).first;
|
||||
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
||||
}
|
||||
setColor (facNodes[i], it->second);
|
||||
}
|
||||
@ -104,11 +110,11 @@ CFactorGraph::setInitialColors (void)
|
||||
void
|
||||
CFactorGraph::createGroups (void)
|
||||
{
|
||||
VarSignMap varGroups;
|
||||
VarSignMap varGroups;
|
||||
FacSignMap factorGroups;
|
||||
unsigned nIters = 0;
|
||||
bool groupsHaveChanged = true;
|
||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
||||
|
||||
while (groupsHaveChanged || nIters == 1) {
|
||||
@ -164,8 +170,9 @@ CFactorGraph::createGroups (void)
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::createClusters (const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups)
|
||||
CFactorGraph::createClusters (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups)
|
||||
{
|
||||
varClusters_.reserve (varGroups.size());
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
@ -249,7 +256,7 @@ CFactorGraph::getCompressedFactorGraph (void)
|
||||
myGroundVars.push_back (v);
|
||||
}
|
||||
Factor* newFactor = new Factor (myGroundVars,
|
||||
facClusters_[i]->getGroundFactors()[0]->getDistribution());
|
||||
facClusters_[i]->getGroundFactors()[0]->params());
|
||||
FgFacNode* fn = new FgFacNode (newFactor);
|
||||
facClusters_[i]->setRepresentativeFactor (fn);
|
||||
fg->addFactor (fn);
|
||||
@ -293,8 +300,9 @@ CFactorGraph::getGroundEdgeCount (
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::printGroups (const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups) const
|
||||
CFactorGraph::printGroups (
|
||||
const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups) const
|
||||
{
|
||||
unsigned count = 1;
|
||||
cout << "variable groups:" << endl;
|
||||
|
@ -15,23 +15,25 @@ class Signature;
|
||||
class SignatureHash;
|
||||
|
||||
|
||||
typedef long Color;
|
||||
typedef unordered_map<unsigned, vector<Color> > VarColorMap;
|
||||
typedef unordered_map<const Distribution*, Color> DistColorMap;
|
||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||
typedef vector<VarCluster*> VarClusterSet;
|
||||
typedef vector<FacCluster*> FacClusterSet;
|
||||
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
||||
typedef long Color;
|
||||
|
||||
typedef unordered_map<unsigned, vector<Color>> VarColorMap;
|
||||
|
||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||
|
||||
typedef vector<VarCluster*> VarClusterSet;
|
||||
typedef vector<FacCluster*> FacClusterSet;
|
||||
|
||||
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
||||
|
||||
|
||||
|
||||
struct Signature
|
||||
{
|
||||
Signature (unsigned size)
|
||||
{
|
||||
colors.resize (size);
|
||||
}
|
||||
Signature (unsigned size) : colors(size) { }
|
||||
|
||||
bool operator< (const Signature& sig) const
|
||||
{
|
||||
if (colors.size() < sig.colors.size()) {
|
||||
@ -49,6 +51,7 @@ struct Signature
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operator== (const Signature& sig) const
|
||||
{
|
||||
if (colors.size() != sig.colors.size()) {
|
||||
@ -61,12 +64,14 @@ struct Signature
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
vector<Color> colors;
|
||||
};
|
||||
|
||||
|
||||
|
||||
struct SignatureHash {
|
||||
struct SignatureHash
|
||||
{
|
||||
size_t operator() (const Signature &sig) const
|
||||
{
|
||||
size_t val = hash<size_t>()(sig.colors.size());
|
||||
@ -141,10 +146,12 @@ class FacCluster
|
||||
{
|
||||
return representFactor_;
|
||||
}
|
||||
|
||||
void setRepresentativeFactor (FgFacNode* fn)
|
||||
{
|
||||
representFactor_ = fn;
|
||||
}
|
||||
|
||||
const FgFacSet& getGroundFactors (void) const
|
||||
{
|
||||
return groundFactors_;
|
||||
@ -162,31 +169,28 @@ class CFactorGraph
|
||||
{
|
||||
public:
|
||||
CFactorGraph (const FactorGraph&);
|
||||
|
||||
~CFactorGraph (void);
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||
const VarClusterSet& getVarClusters (void) { return varClusters_; }
|
||||
|
||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
||||
|
||||
FgVarNode* getEquivalentVariable (VarId vid)
|
||||
{
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->getRepresentativeVariable();
|
||||
}
|
||||
|
||||
const VarClusterSet& getVarClusters (void) { return varClusters_; }
|
||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
|
||||
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||
|
||||
static bool checkForIdenticalFactors;
|
||||
|
||||
private:
|
||||
void setInitialColors (void);
|
||||
void createGroups (void);
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
const Signature& getSignature (const FgVarNode*);
|
||||
const Signature& getSignature (const FgFacNode*);
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
Color getFreeColor (void) {
|
||||
Color getFreeColor (void)
|
||||
{
|
||||
++ freeColor_;
|
||||
return freeColor_ - 1;
|
||||
}
|
||||
@ -214,14 +218,26 @@ class CFactorGraph
|
||||
return vid2VarCluster_.find (vid)->second;
|
||||
}
|
||||
|
||||
void setInitialColors (void);
|
||||
|
||||
void createGroups (void);
|
||||
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
|
||||
const Signature& getSignature (const FgVarNode*);
|
||||
|
||||
const Signature& getSignature (const FgFacNode*);
|
||||
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
Color freeColor_;
|
||||
vector<Color> varColors_;
|
||||
vector<Color> factorColors_;
|
||||
vector<Signature> varSignatures_;
|
||||
vector<Signature> factorSignatures_;
|
||||
VarClusterSet varClusters_;
|
||||
FacClusterSet facClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
FacClusterSet facClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* groundFg_;
|
||||
};
|
||||
|
||||
|
@ -20,24 +20,24 @@ CbpSolver::getPosterioriOf (VarId vid)
|
||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->nrStates(), Util::noEvidence());
|
||||
probs[var->getEvidence()] = Util::withEvidence();
|
||||
probs.resize (var->nrStates(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->nrStates(), Util::multIdenty());
|
||||
probs.resize (var->nrStates(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (probs, l->getPoweredMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (probs, l->getPoweredMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
@ -62,7 +62,7 @@ void
|
||||
CbpSolver::initializeSolver (void)
|
||||
{
|
||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||
if (COLLECT_STATISTICS) {
|
||||
if (Constants::COLLECT_STATS) {
|
||||
nGroundVars = factorGraph_->getVarNodes().size();
|
||||
nGroundFacs = factorGraph_->getFactorNodes().size();
|
||||
const FgVarSet& vars = factorGraph_->getVarNodes();
|
||||
@ -82,7 +82,7 @@ CbpSolver::initializeSolver (void)
|
||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
||||
|
||||
if (COLLECT_STATISTICS) {
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nClusterVars = factorGraph_->getVarNodes().size();
|
||||
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
|
||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
||||
@ -123,7 +123,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (DL >= 2 && DL < 5) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
@ -131,7 +131,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
@ -142,7 +142,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
@ -159,7 +159,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
if (DL >= 2 && DL < 5) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[j]);
|
||||
@ -174,7 +174,7 @@ CbpSolver::maxResidualSchedule (void)
|
||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getVariable() != link->getVariable()) {
|
||||
if (DL >= 2 && DL < 5) {
|
||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||
cout << " calculating " << links[i]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[i]);
|
||||
@ -196,15 +196,15 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->nrStates(), Util::noEvidence());
|
||||
msg.resize (src->nrStates(), LogAware::noEvidence());
|
||||
double value = link->getMessage()[src->getEvidence()];
|
||||
msg[src->getEvidence()] = Util::pow (value, l->getNumberOfEdges() - 1);
|
||||
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1);
|
||||
} else {
|
||||
msg = link->getMessage();
|
||||
Util::pow (msg, l->getNumberOfEdges() - 1);
|
||||
LogAware::pow (msg, l->getNumberOfEdges() - 1);
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " " << "init: " << Util::parametersToString (msg) << endl;
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " " << "init: " << msg << endl;
|
||||
}
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
@ -219,16 +219,16 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (msg, l->getPoweredMessage());
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
||||
cout << l->getPoweredMessage() << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (DL >= 5) {
|
||||
cout << " result = " << Util::parametersToString (msg) << endl;
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " result = " << msg << endl;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
@ -241,12 +241,9 @@ CbpSolver::printLinkInformation (void) const
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
||||
cout << " powered = " ;
|
||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
||||
cout << " curr msg = " << l->getMessage() << endl;
|
||||
cout << " next msg = " << l->getNextMessage() << endl;
|
||||
cout << " powered = " << l->getPoweredMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
@ -12,23 +12,24 @@ class CbpSolverLink : public SpLink
|
||||
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
||||
{
|
||||
edgeCount_ = c;
|
||||
poweredMsg_.resize (vn->nrStates(), Util::one());
|
||||
poweredMsg_.resize (vn->nrStates(), LogAware::one());
|
||||
}
|
||||
|
||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||
|
||||
const Params& getPoweredMessage (void) const { return poweredMsg_; }
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
poweredMsg_ = *nextMsg_;
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
Util::pow (poweredMsg_, edgeCount_);
|
||||
LogAware::pow (poweredMsg_, edgeCount_);
|
||||
}
|
||||
|
||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||
const Params& getPoweredMessage (void) const { return poweredMsg_; }
|
||||
|
||||
private:
|
||||
Params poweredMsg_;
|
||||
unsigned edgeCount_;
|
||||
Params poweredMsg_;
|
||||
unsigned edgeCount_;
|
||||
};
|
||||
|
||||
|
||||
@ -37,21 +38,22 @@ class CbpSolver : public FgBpSolver
|
||||
{
|
||||
public:
|
||||
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
|
||||
|
||||
~CbpSolver (void);
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
Params getPosterioriOf (VarId);
|
||||
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
void createLinks (void);
|
||||
void initializeSolver (void);
|
||||
void createLinks (void);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
Params getVar2FactorMsg (const SpLink*) const;
|
||||
void printLinkInformation (void) const;
|
||||
void maxResidualSchedule (void);
|
||||
Params getVar2FactorMsg (const SpLink*) const;
|
||||
void printLinkInformation (void) const;
|
||||
|
||||
|
||||
CFactorGraph* lfg_;
|
||||
CFactorGraph* lfg_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CBP_H
|
||||
|
@ -1,10 +1,11 @@
|
||||
#include <queue>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "ConstraintTree.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::addChild (CTNode* child, bool updateLevels)
|
||||
{
|
||||
@ -42,6 +43,26 @@ CTNode::removeChild (CTNode* child)
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::removeAndDeleteChild (CTNode* child)
|
||||
{
|
||||
removeChild (child);
|
||||
CTNode::deleteSubtree (child);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::removeAndDeleteAllChilds (void)
|
||||
{
|
||||
for (unsigned i = 0; i < childs_.size(); i++) {
|
||||
deleteSubtree (childs_[i]);
|
||||
}
|
||||
childs_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
SymbolSet
|
||||
CTNode::childSymbols (void) const
|
||||
{
|
||||
@ -66,6 +87,32 @@ CTNode::updateChildLevels (CTNode* n, unsigned level)
|
||||
|
||||
|
||||
|
||||
CTNode*
|
||||
CTNode::copySubtree (const CTNode* n)
|
||||
{
|
||||
CTNode* newNode = new CTNode (*n);
|
||||
const CTNodes& childs = n->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
newNode->addChild (copySubtree (childs[i]));
|
||||
}
|
||||
return newNode;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::deleteSubtree (CTNode* n)
|
||||
{
|
||||
assert (n);
|
||||
const CTNodes& childs = n->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
deleteSubtree (childs[i]);
|
||||
}
|
||||
delete n;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &out, const CTNode& n)
|
||||
{
|
||||
// out << "(" << n.level() << ") " ;
|
||||
@ -75,6 +122,17 @@ ostream& operator<< (ostream &out, const CTNode& n)
|
||||
|
||||
|
||||
|
||||
ConstraintTree::ConstraintTree (unsigned nrLvs)
|
||||
{
|
||||
for (unsigned i = 0; i < nrLvs; i++) {
|
||||
logVars_.push_back (LogVar (i));
|
||||
}
|
||||
root_ = new CTNode (0, 0);
|
||||
logVarSet_ = LogVarSet (logVars_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
ConstraintTree::ConstraintTree (const LogVars& logVars)
|
||||
{
|
||||
root_ = new CTNode (0, 0);
|
||||
@ -99,7 +157,7 @@ ConstraintTree::ConstraintTree (const LogVars& logVars,
|
||||
|
||||
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
||||
{
|
||||
root_ = copySubtree (ct.root_);
|
||||
root_ = CTNode::copySubtree (ct.root_);
|
||||
logVars_ = ct.logVars_;
|
||||
logVarSet_ = ct.logVarSet_;
|
||||
}
|
||||
@ -108,7 +166,7 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
||||
|
||||
ConstraintTree::~ConstraintTree (void)
|
||||
{
|
||||
deleteSubtree (root_);
|
||||
CTNode::deleteSubtree (root_);
|
||||
}
|
||||
|
||||
|
||||
@ -200,21 +258,28 @@ ConstraintTree::moveToBottom (const LogVars& lvs)
|
||||
|
||||
void
|
||||
ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
||||
{
|
||||
{
|
||||
if (logVarSet_.empty()) {
|
||||
delete root_;
|
||||
root_ = CTNode::copySubtree (ct->root());
|
||||
logVars_ = ct->logVars();
|
||||
logVarSet_ = ct->logVarSet();
|
||||
return;
|
||||
}
|
||||
|
||||
LogVarSet intersect = logVarSet_ & ct->logVarSet_;
|
||||
if (intersect.empty()) {
|
||||
const CTNodes& childs = ct->root()->childs();
|
||||
CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back()));
|
||||
for (unsigned i = 0; i < leafs.size(); i++) {
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
leafs[i]->addChild (copySubtree (childs[j]));
|
||||
leafs[i]->addChild (CTNode::copySubtree (childs[j]));
|
||||
}
|
||||
}
|
||||
logVars_.insert (logVars_.end(), ct->logVars_.begin(), ct->logVars_.end());
|
||||
Util::addToVector (logVars_, ct->logVars_);
|
||||
logVarSet_ |= ct->logVarSet_;
|
||||
|
||||
} else {
|
||||
|
||||
moveToBottom (intersect.elements());
|
||||
ct->moveToTop (intersect.elements());
|
||||
|
||||
@ -222,25 +287,27 @@ ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
||||
CTNodes nodes = getNodesAtLevel (level);
|
||||
|
||||
Tuples tuples;
|
||||
CTNodes continuationNodes;
|
||||
CTNodes continNodes;
|
||||
getTuples (ct->root(),
|
||||
Tuples(),
|
||||
intersect.size(),
|
||||
tuples,
|
||||
continuationNodes);
|
||||
continNodes);
|
||||
|
||||
for (unsigned i = 0; i < tuples.size(); i++) {
|
||||
bool tupleFounded = false;
|
||||
for (unsigned j = 0; j < nodes.size(); j++) {
|
||||
tupleFounded |= join (nodes[j], tuples[i], 0, continuationNodes[i]);
|
||||
tupleFounded |= join (nodes[j], tuples[i], 0, continNodes[i]);
|
||||
}
|
||||
if (assertWhenNotFound) {
|
||||
assert (tupleFounded);
|
||||
}
|
||||
}
|
||||
LogVarSet newLvs = ct->logVarSet_ - intersect;
|
||||
logVars_.insert (logVars_.end(), newLvs.begin(), newLvs.end());
|
||||
logVarSet_ |= newLvs;
|
||||
|
||||
LogVars newLvs (ct->logVars().begin() + intersect.size(),
|
||||
ct->logVars().end());
|
||||
Util::addToVector (logVars_, newLvs);
|
||||
logVarSet_ |= LogVarSet (newLvs);
|
||||
}
|
||||
}
|
||||
|
||||
@ -280,6 +347,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
|
||||
void
|
||||
ConstraintTree::applySubstitution (const Substitution& theta)
|
||||
{
|
||||
LogVars discardedLvs = theta.getDiscardedLogVars();
|
||||
for (unsigned i = 0; i < discardedLvs.size(); i++) {
|
||||
remove(discardedLvs[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < logVars_.size(); i++) {
|
||||
logVars_[i] = theta.newNameFor (logVars_[i]);
|
||||
}
|
||||
@ -308,11 +379,7 @@ ConstraintTree::remove (const LogVarSet& X)
|
||||
unsigned level = getLevel (X.front()) - 1;
|
||||
CTNodes nodes = getNodesAtLevel (level);
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
CTNodes childs = nodes[i]->childs();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
nodes[i]->removeChild (childs[j]);
|
||||
deleteSubtree (childs[j]);
|
||||
}
|
||||
nodes[i]->removeAndDeleteAllChilds();
|
||||
}
|
||||
logVars_.resize (logVars_.size() - X.size());
|
||||
logVarSet_ -= X;
|
||||
@ -545,16 +612,16 @@ ConstraintTree::split (
|
||||
for (unsigned i = 0; i < commNodes.size(); i++) {
|
||||
commCt->root()->addChild (commNodes[i]);
|
||||
}
|
||||
//cout << commCt->tupleSet() << " + " ;
|
||||
//cout << exclCt->tupleSet() << " = " ;
|
||||
//cout << tupleSet() << endl << endl;
|
||||
// cout << commCt->tupleSet() << " + " ;
|
||||
// cout << exclCt->tupleSet() << " = " ;
|
||||
// cout << tupleSet() << endl << endl;
|
||||
// if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) {
|
||||
// exportToGraphViz ("_fail.dot", true);
|
||||
// commCt->exportToGraphViz ("_fail_comm.dot", true);
|
||||
// exclCt->exportToGraphViz ("_fail_excl.dot", true);
|
||||
// }
|
||||
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
||||
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
||||
// assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
||||
// assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
||||
return {commCt, exclCt};
|
||||
}
|
||||
|
||||
@ -601,36 +668,32 @@ ConstraintTree::jointCountNormalize (
|
||||
LogVar X_new1,
|
||||
LogVar X_new2)
|
||||
{
|
||||
exportToGraphViz ("C.dot", true);
|
||||
commCt->exportToGraphViz ("C_comm.dot", true);
|
||||
exclCt->exportToGraphViz ("C_exlc.dot", true);
|
||||
unsigned N = getConditionalCount (X);
|
||||
cout << "My tuples: " << tupleSet() << endl;
|
||||
cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
||||
cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
||||
cout << "Counted Lv: " << X << endl;
|
||||
cout << "Original N: " << N << endl;
|
||||
cout << endl;
|
||||
// cout << "My tuples: " << tupleSet() << endl;
|
||||
// cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
||||
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
||||
// cout << "Counted Lv: " << X << endl;
|
||||
// cout << "X_new1: " << X_new1 << endl;
|
||||
// cout << "X_new2: " << X_new2 << endl;
|
||||
// cout << "Original N: " << N << endl;
|
||||
// cout << endl;
|
||||
|
||||
ConstraintTrees normCts1 = commCt->countNormalize (X);
|
||||
vector<unsigned> counts1 (normCts1.size());
|
||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||
counts1[i] = normCts1[i]->getConditionalCount (X);
|
||||
cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||
cout << " " << normCts1[i]->tupleSet() << endl;
|
||||
// cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||
// cout << " " << normCts1[i]->tupleSet() << endl;
|
||||
}
|
||||
|
||||
ConstraintTrees normCts2 = exclCt->countNormalize (X);
|
||||
vector<unsigned> counts2 (normCts2.size());
|
||||
for (unsigned i = 0; i < normCts2.size(); i++) {
|
||||
counts2[i] = normCts2[i]->getConditionalCount (X);
|
||||
cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||
cout << " " << normCts2[i]->tupleSet() << endl;
|
||||
// cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||
// cout << " " << normCts2[i]->tupleSet() << endl;
|
||||
}
|
||||
cout << endl;
|
||||
|
||||
cout << "1###### " << normCts1.size() << endl;
|
||||
cout << "2###### " << normCts2.size() << endl;
|
||||
// cout << endl;
|
||||
|
||||
ConstraintTree* excl1 = 0;
|
||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||
@ -638,7 +701,7 @@ ConstraintTree::jointCountNormalize (
|
||||
excl1 = normCts1[i];
|
||||
normCts1.erase (normCts1.begin() + i);
|
||||
counts1.erase (counts1.begin() + i);
|
||||
cout << ">joint-count(" << N << ",0)" << endl;
|
||||
// cout << "joint-count(" << N << ",0)" << endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -649,22 +712,21 @@ ConstraintTree::jointCountNormalize (
|
||||
excl2 = normCts2[i];
|
||||
normCts2.erase (normCts2.begin() + i);
|
||||
counts2.erase (counts2.begin() + i);
|
||||
cout << ">>joint-count(0," << N << ")" << endl;
|
||||
// cout << "joint-count(0," << N << ")" << endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
cout << "3###### " << normCts1.size() << endl;
|
||||
cout << "4###### " << normCts2.size() << endl;
|
||||
|
||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||
unsigned j;
|
||||
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
|
||||
cout << "joint-count(" << counts1[i] << "," << counts2[j] << ")" << endl;
|
||||
// cout << "joint-count(" << counts1[i] ;
|
||||
// cout << "," << counts2[j] << ")" << endl;
|
||||
const CTNodes& childs = normCts2[j]->root_->childs();
|
||||
for (unsigned k = 0; k < childs.size(); k++) {
|
||||
normCts1[i]->root_->addChild (childs[k]);
|
||||
normCts1[i]->root_->addChild (CTNode::copySubtree (childs[k]));
|
||||
}
|
||||
delete normCts2[j];
|
||||
}
|
||||
|
||||
ConstraintTrees cts = normCts1;
|
||||
@ -683,11 +745,6 @@ ConstraintTree::jointCountNormalize (
|
||||
cts.push_back (excl2);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < cts.size(); i++) {
|
||||
stringstream ss;
|
||||
ss << "aaacts_" << i + 1 << ".dot" ;
|
||||
cts[i]->exportToGraphViz (ss.str().c_str(), true);
|
||||
}
|
||||
return cts;
|
||||
}
|
||||
|
||||
@ -735,11 +792,11 @@ ConstraintTree::expand (LogVar X)
|
||||
unsigned nrSymbols = getConditionalCount (X);
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
Symbols symbols;
|
||||
CTNodes childs = nodes[i]->childs();
|
||||
const CTNodes& childs = nodes[i]->childs();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
symbols.push_back (childs[j]->symbol());
|
||||
nodes[i]->removeChild (childs[j]);
|
||||
}
|
||||
nodes[i]->removeAndDeleteAllChilds();
|
||||
CTNode* prev = nodes[i];
|
||||
assert (symbols.size() == nrSymbols);
|
||||
for (unsigned j = 0; j < nrSymbols; j++) {
|
||||
@ -768,7 +825,7 @@ ConstraintTree::ground (LogVar X)
|
||||
ConstraintTrees cts;
|
||||
const CTNodes& nodes = root_->childs();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
CTNode* copy = copySubtree (nodes[i]);
|
||||
CTNode* copy = CTNode::copySubtree (nodes[i]);
|
||||
copy->setSymbol (nodes[i]->symbol());
|
||||
ConstraintTree* newCt = new ConstraintTree (logVars_);
|
||||
newCt->root()->addChild (copy);
|
||||
@ -884,7 +941,7 @@ ConstraintTree::join (
|
||||
if (currIdx == tuple.size() - 1) {
|
||||
const CTNodes& childs = appendNode->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
n->addChild (copySubtree (childs[i]));
|
||||
n->addChild (CTNode::copySubtree (childs[i]));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -985,7 +1042,7 @@ ConstraintTree::countNormalize (
|
||||
{
|
||||
if (n->level() == stopLevel) {
|
||||
return vector<pair<CTNode*, unsigned>>() = {
|
||||
make_pair (copySubtree (n), countTuples (n))
|
||||
make_pair (CTNode::copySubtree (n), countTuples (n))
|
||||
};
|
||||
}
|
||||
|
||||
@ -1004,65 +1061,6 @@ ConstraintTree::countNormalize (
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
void
|
||||
ConstraintTree::split (
|
||||
CTNode* n1,
|
||||
CTNode* n2,
|
||||
CTNodes& nodes,
|
||||
unsigned stopLevel)
|
||||
{
|
||||
CTNodes& childs1 = n1->childs();
|
||||
CTNodes& childs2 = n2->childs();
|
||||
// cout << string (n1->level() * 8, '-') << "Level = " << n1->level() + 1;
|
||||
// cout << ", #I = " << childs1.size();
|
||||
// cout << ", #J = " << childs2.size() << endl;
|
||||
for (unsigned i = 0; i < childs1.size(); i++) {
|
||||
for (unsigned j = 0; j < childs2.size(); j++) {
|
||||
if (childs1[i]->symbol() != childs2[j]->symbol()) {
|
||||
continue;
|
||||
}
|
||||
if (childs1[i]->level() == stopLevel) {
|
||||
CTNode* newNode = copySubtree (childs1[i]);
|
||||
newNode->setSymbol (childs1[i]->symbol());
|
||||
nodes.push_back (newNode);
|
||||
childs1[i]->setSymbol (Symbol::invalid());
|
||||
break;
|
||||
} else {
|
||||
CTNodes lowerNodes;
|
||||
split (childs1[i], childs2[j], lowerNodes, stopLevel);
|
||||
if (lowerNodes.empty() == false) {
|
||||
CTNode* me = new CTNode (childs1[i]->symbol(), childs1[i]->level());
|
||||
for (unsigned k = 0; k < lowerNodes.size(); k++) {
|
||||
me->addChild (lowerNodes[k]);
|
||||
}
|
||||
nodes.push_back (me);
|
||||
}
|
||||
if (childs1[i]->isLeaf()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int)childs1.size(); i++) {
|
||||
// cout << string (n1->level() * 8, '-') << childs1[i];
|
||||
if (childs1[i]->symbol() == Symbol::invalid()) {
|
||||
// cout << " empty, removing..." ;
|
||||
n1->removeChild (childs1[i]);
|
||||
i --;
|
||||
} else if (childs1[i]->isLeaf() &&
|
||||
childs1[i]->level() != stopLevel) {
|
||||
// cout << " leaf, removing..." ;
|
||||
n1->removeChild (childs1[i]);
|
||||
i --;
|
||||
}
|
||||
// cout << endl;
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
|
||||
void
|
||||
ConstraintTree::split (
|
||||
@ -1085,7 +1083,7 @@ ConstraintTree::split (
|
||||
continue;
|
||||
}
|
||||
if (childs1[i]->level() == stopLevel) {
|
||||
CTNode* newNode = copySubtree (childs1[i]);
|
||||
CTNode* newNode = CTNode::copySubtree (childs1[i]);
|
||||
nodes.push_back (newNode);
|
||||
childs1[i]->setSymbol (Symbol::invalid());
|
||||
} else {
|
||||
@ -1103,11 +1101,11 @@ ConstraintTree::split (
|
||||
|
||||
for (int i = 0; i < (int)childs1.size(); i++) {
|
||||
if (childs1[i]->symbol() == Symbol::invalid()) {
|
||||
n1->removeChild (childs1[i]);
|
||||
n1->removeAndDeleteChild (childs1[i]);
|
||||
i --;
|
||||
} else if (childs1[i]->isLeaf() &&
|
||||
childs1[i]->level() != stopLevel) {
|
||||
n1->removeChild (childs1[i]);
|
||||
n1->removeAndDeleteChild (childs1[i]);
|
||||
i --;
|
||||
}
|
||||
}
|
||||
@ -1141,29 +1139,3 @@ ConstraintTree::overlap (
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
CTNode*
|
||||
ConstraintTree::copySubtree (const CTNode* n)
|
||||
{
|
||||
CTNode* newNode = new CTNode (*n);
|
||||
const CTNodes& childs = n->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
newNode->addChild (copySubtree (childs[i]));
|
||||
}
|
||||
return newNode;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ConstraintTree::deleteSubtree (CTNode* n)
|
||||
{
|
||||
assert (n);
|
||||
const CTNodes& childs = n->childs();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
deleteSubtree (childs[i]);
|
||||
}
|
||||
delete n;
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,6 @@ typedef vector<ConstraintTree*> ConstraintTrees;
|
||||
|
||||
|
||||
|
||||
|
||||
class CTNode
|
||||
{
|
||||
public:
|
||||
@ -47,29 +46,42 @@ class CTNode
|
||||
|
||||
bool isLeaf (void) const { return childs_.empty(); }
|
||||
|
||||
void addChild (CTNode*, bool = true);
|
||||
void removeChild (CTNode*);
|
||||
SymbolSet childSymbols (void) const;
|
||||
void addChild (CTNode*, bool = true);
|
||||
|
||||
void removeChild (CTNode*);
|
||||
|
||||
void removeAndDeleteChild (CTNode*);
|
||||
|
||||
void removeAndDeleteAllChilds (void);
|
||||
|
||||
SymbolSet childSymbols (void) const;
|
||||
|
||||
static CTNode* copySubtree (const CTNode*);
|
||||
|
||||
static void deleteSubtree (CTNode*);
|
||||
|
||||
private:
|
||||
void updateChildLevels (CTNode*, unsigned);
|
||||
void updateChildLevels (CTNode*, unsigned);
|
||||
|
||||
Symbol symbol_;
|
||||
CTNodes childs_;
|
||||
unsigned level_;
|
||||
Symbol symbol_;
|
||||
CTNodes childs_;
|
||||
unsigned level_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &out, const CTNode&);
|
||||
|
||||
|
||||
class ConstraintTree
|
||||
{
|
||||
public:
|
||||
ConstraintTree (unsigned);
|
||||
|
||||
ConstraintTree (const LogVars&);
|
||||
|
||||
ConstraintTree (const LogVars&, const Tuples&);
|
||||
|
||||
ConstraintTree (const ConstraintTree&);
|
||||
|
||||
~ConstraintTree (void);
|
||||
|
||||
CTNode* root (void) const { return root_; }
|
||||
@ -94,94 +106,95 @@ class ConstraintTree
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
}
|
||||
|
||||
void addTuple (const Tuple&);
|
||||
bool containsTuple (const Tuple&);
|
||||
void moveToTop (const LogVars&);
|
||||
void moveToBottom (const LogVars&);
|
||||
void join (ConstraintTree*, bool = false);
|
||||
unsigned getLevel (LogVar) const;
|
||||
void rename (LogVar, LogVar);
|
||||
void applySubstitution (const Substitution&);
|
||||
void project (const LogVarSet&);
|
||||
void remove (const LogVarSet&);
|
||||
bool isSingleton (LogVar);
|
||||
LogVarSet singletons (void);
|
||||
TupleSet tupleSet (unsigned = 0) const;
|
||||
TupleSet tupleSet (const LogVars&);
|
||||
unsigned size (void) const;
|
||||
unsigned nrSymbols (LogVar);
|
||||
void exportToGraphViz (const char*, bool = false) const;
|
||||
bool isCountNormalized (const LogVarSet&);
|
||||
unsigned getConditionalCount (const LogVarSet&);
|
||||
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
||||
bool isCarteesianProduct (const LogVarSet&) const;
|
||||
void addTuple (const Tuple&);
|
||||
|
||||
bool containsTuple (const Tuple&);
|
||||
|
||||
void moveToTop (const LogVars&);
|
||||
|
||||
void moveToBottom (const LogVars&);
|
||||
|
||||
void join (ConstraintTree*, bool = false);
|
||||
|
||||
unsigned getLevel (LogVar) const;
|
||||
|
||||
void rename (LogVar, LogVar);
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
void project (const LogVarSet&);
|
||||
|
||||
void remove (const LogVarSet&);
|
||||
|
||||
bool isSingleton (LogVar);
|
||||
|
||||
LogVarSet singletons (void);
|
||||
|
||||
TupleSet tupleSet (unsigned = 0) const;
|
||||
|
||||
TupleSet tupleSet (const LogVars&);
|
||||
|
||||
unsigned size (void) const;
|
||||
|
||||
unsigned nrSymbols (LogVar);
|
||||
|
||||
void exportToGraphViz (const char*, bool = false) const;
|
||||
|
||||
bool isCountNormalized (const LogVarSet&);
|
||||
|
||||
unsigned getConditionalCount (const LogVarSet&);
|
||||
|
||||
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
||||
|
||||
bool isCarteesianProduct (const LogVarSet&) const;
|
||||
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||
const Tuple&,
|
||||
unsigned);
|
||||
const Tuple&, unsigned);
|
||||
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||
const ConstraintTree*,
|
||||
unsigned) const;
|
||||
const ConstraintTree*, unsigned) const;
|
||||
|
||||
ConstraintTrees countNormalize (const LogVarSet&);
|
||||
ConstraintTrees countNormalize (const LogVarSet&);
|
||||
|
||||
ConstraintTrees jointCountNormalize (
|
||||
ConstraintTree*,
|
||||
ConstraintTree*,
|
||||
LogVar,
|
||||
LogVar,
|
||||
LogVar);
|
||||
ConstraintTree*, ConstraintTree*, LogVar, LogVar, LogVar);
|
||||
|
||||
static bool identical (
|
||||
const ConstraintTree*,
|
||||
const ConstraintTree*,
|
||||
unsigned);
|
||||
static bool identical (
|
||||
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||
|
||||
static bool overlap (
|
||||
const ConstraintTree*,
|
||||
const ConstraintTree*,
|
||||
unsigned);
|
||||
static bool overlap (
|
||||
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||
|
||||
LogVars expand (LogVar);
|
||||
ConstraintTrees ground (LogVar);
|
||||
LogVars expand (LogVar);
|
||||
ConstraintTrees ground (LogVar);
|
||||
|
||||
private:
|
||||
unsigned countTuples (const CTNode*) const;
|
||||
CTNodes getNodesBelow (CTNode*) const;
|
||||
CTNodes getNodesAtLevel (unsigned) const;
|
||||
void swapLogVar (LogVar);
|
||||
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
|
||||
unsigned countTuples (const CTNode*) const;
|
||||
|
||||
bool indenticalSubtrees (
|
||||
const CTNode*,
|
||||
const CTNode*,
|
||||
bool) const;
|
||||
CTNodes getNodesBelow (CTNode*) const;
|
||||
|
||||
void getTuples (
|
||||
CTNode*,
|
||||
Tuples,
|
||||
unsigned,
|
||||
Tuples&,
|
||||
CTNodes&) const;
|
||||
CTNodes getNodesAtLevel (unsigned) const;
|
||||
|
||||
void swapLogVar (LogVar);
|
||||
|
||||
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
|
||||
|
||||
bool indenticalSubtrees (
|
||||
const CTNode*, const CTNode*, bool) const;
|
||||
|
||||
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
|
||||
|
||||
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||
const CTNode*,
|
||||
unsigned);
|
||||
const CTNode*, unsigned);
|
||||
|
||||
static void split (
|
||||
CTNode*,
|
||||
CTNode*,
|
||||
CTNodes&,
|
||||
unsigned);
|
||||
static void split (
|
||||
CTNode*, CTNode*, CTNodes&, unsigned);
|
||||
|
||||
static bool overlap (const CTNode*, const CTNode*, unsigned);
|
||||
static CTNode* copySubtree (const CTNode*);
|
||||
static void deleteSubtree (CTNode*);
|
||||
static bool overlap (const CTNode*, const CTNode*, unsigned);
|
||||
|
||||
CTNode* root_;
|
||||
LogVars logVars_;
|
||||
LogVarSet logVarSet_;
|
||||
CTNode* root_;
|
||||
LogVars logVars_;
|
||||
LogVarSet logVarSet_;
|
||||
};
|
||||
|
||||
|
||||
|
@ -1,45 +0,0 @@
|
||||
#ifndef HORUS_DISTRIBUTION_H
|
||||
#define HORUS_DISTRIBUTION_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
//TODO die die die die die
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct Distribution
|
||||
{
|
||||
public:
|
||||
Distribution (int id)
|
||||
{
|
||||
this->id = id;
|
||||
}
|
||||
|
||||
Distribution (const Params& params, int id = -1)
|
||||
{
|
||||
this->id = id;
|
||||
this->params = params;
|
||||
}
|
||||
|
||||
void updateParameters (const Params& params)
|
||||
{
|
||||
this->params = params;
|
||||
}
|
||||
|
||||
bool shared (void)
|
||||
{
|
||||
return id != -1;
|
||||
}
|
||||
|
||||
int id;
|
||||
Params params;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
||||
};
|
||||
|
||||
#endif // HORUS_DISTRIBUTION_H
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include <limits>
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesNet.h"
|
||||
|
||||
|
@ -17,15 +17,15 @@ enum ElimHeuristic
|
||||
};
|
||||
|
||||
|
||||
class EgNode : public VarNode {
|
||||
class EgNode : public VarNode
|
||||
{
|
||||
public:
|
||||
EgNode (VarNode* var) : VarNode (var) { }
|
||||
void addNeighbor (EgNode* n)
|
||||
{
|
||||
neighs_.push_back (n);
|
||||
}
|
||||
|
||||
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
||||
|
||||
const vector<EgNode*>& neighbors (void) const { return neighs_; }
|
||||
|
||||
private:
|
||||
vector<EgNode*> neighs_;
|
||||
};
|
||||
@ -35,6 +35,7 @@ class ElimGraph
|
||||
{
|
||||
public:
|
||||
ElimGraph (const BayesNet&);
|
||||
|
||||
~ElimGraph (void);
|
||||
|
||||
void addEdge (EgNode* n1, EgNode* n2)
|
||||
@ -43,13 +44,19 @@ class ElimGraph
|
||||
n1->addNeighbor (n2);
|
||||
n2->addNeighbor (n1);
|
||||
}
|
||||
void addNode (EgNode*);
|
||||
EgNode* getEgNode (VarId) const;
|
||||
VarIds getEliminatingOrder (const VarIds&);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
void setIndexes();
|
||||
|
||||
void addNode (EgNode*);
|
||||
|
||||
EgNode* getEgNode (VarId) const;
|
||||
|
||||
VarIds getEliminatingOrder (const VarIds&);
|
||||
|
||||
void printGraphicalModel (void) const;
|
||||
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
|
||||
void setIndexes();
|
||||
|
||||
static void setEliminationHeuristic (ElimHeuristic h)
|
||||
{
|
||||
@ -57,14 +64,19 @@ class ElimGraph
|
||||
}
|
||||
|
||||
private:
|
||||
EgNode* getLowestCostNode (void) const;
|
||||
unsigned getNeighborsCost (const EgNode*) const;
|
||||
unsigned getWeightCost (const EgNode*) const;
|
||||
unsigned getFillCost (const EgNode*) const;
|
||||
unsigned getWeightedFillCost (const EgNode*) const;
|
||||
void connectAllNeighbors (const EgNode*);
|
||||
bool neighbors (const EgNode*, const EgNode*) const;
|
||||
EgNode* getLowestCostNode (void) const;
|
||||
|
||||
unsigned getNeighborsCost (const EgNode*) const;
|
||||
|
||||
unsigned getWeightCost (const EgNode*) const;
|
||||
|
||||
unsigned getFillCost (const EgNode*) const;
|
||||
|
||||
unsigned getWeightedFillCost (const EgNode*) const;
|
||||
|
||||
void connectAllNeighbors (const EgNode*);
|
||||
|
||||
bool neighbors (const EgNode*, const EgNode*) const;
|
||||
|
||||
vector<EgNode*> nodes_;
|
||||
vector<bool> marked_;
|
||||
|
@ -8,7 +8,7 @@
|
||||
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const Factor& g)
|
||||
@ -18,206 +18,73 @@ Factor::Factor (const Factor& g)
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarId vid, unsigned nStates)
|
||||
Factor::Factor (VarId vid, unsigned nrStates)
|
||||
{
|
||||
varids_.push_back (vid);
|
||||
ranges_.push_back (nStates);
|
||||
dist_ = new Distribution (Params (nStates, 1.0));
|
||||
args_.push_back (vid);
|
||||
ranges_.push_back (nrStates);
|
||||
params_.resize (nrStates, 1.0);
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarNodes& vars)
|
||||
{
|
||||
int nParams = 1;
|
||||
int nrParams = 1;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
args_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
nParams *= vars[i]->nrStates();
|
||||
nrParams *= vars[i]->nrStates();
|
||||
}
|
||||
// create a uniform distribution
|
||||
double val = 1.0 / nParams;
|
||||
dist_ = new Distribution (Params (nParams, val));
|
||||
double val = 1.0 / nrParams;
|
||||
params_.resize (nrParams, val);
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarId vid, unsigned nStates, const Params& params)
|
||||
Factor::Factor (
|
||||
VarId vid,
|
||||
unsigned nrStates,
|
||||
const Params& params)
|
||||
{
|
||||
varids_.push_back (vid);
|
||||
ranges_.push_back (nStates);
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarNodes& vars, Distribution* dist)
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarNodes& vars, const Params& params)
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarIds& vids,
|
||||
const Ranges& ranges,
|
||||
const Params& params)
|
||||
{
|
||||
varids_ = vids;
|
||||
ranges_ = ranges;
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::~Factor (void)
|
||||
{
|
||||
if (dist_->shared() == false) {
|
||||
delete dist_;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::setParameters (const Params& params)
|
||||
{
|
||||
assert (dist_->params.size() == params.size());
|
||||
dist_->params = params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::copyFromFactor (const Factor& g)
|
||||
{
|
||||
varids_ = g.getVarIds();
|
||||
ranges_ = g.getRanges();
|
||||
dist_ = new Distribution (g.getParameters());
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::multiply (const Factor& g)
|
||||
{
|
||||
if (varids_.size() == 0) {
|
||||
copyFromFactor (g);
|
||||
return;
|
||||
}
|
||||
|
||||
const VarIds& g_varids = g.getVarIds();
|
||||
const Ranges& g_ranges = g.getRanges();
|
||||
const Params& g_params = g.getParameters();
|
||||
|
||||
if (varids_ == g_varids) {
|
||||
// optimization: if the factors contain the same set of variables,
|
||||
// we can do a 1 to 1 operation on the parameters
|
||||
if (Globals::logDomain) {
|
||||
Util::add (dist_->params, g_params);
|
||||
} else {
|
||||
Util::multiply (dist_->params, g_params);
|
||||
}
|
||||
} else {
|
||||
bool sharedVars = false;
|
||||
vector<unsigned> gvarpos;
|
||||
for (unsigned i = 0; i < g_varids.size(); i++) {
|
||||
int idx = indexOf (g_varids[i]);
|
||||
if (idx == -1) {
|
||||
insertVariable (g_varids[i], g_ranges[i]);
|
||||
gvarpos.push_back (varids_.size() - 1);
|
||||
} else {
|
||||
sharedVars = true;
|
||||
gvarpos.push_back (idx);
|
||||
}
|
||||
}
|
||||
if (sharedVars == false) {
|
||||
// optimization: if the original factors doesn't have common variables,
|
||||
// we don't need to marry the states of the common variables
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
if (Globals::logDomain) {
|
||||
dist_->params[i] += g_params[count];
|
||||
} else {
|
||||
dist_->params[i] *= g_params[count];
|
||||
}
|
||||
count ++;
|
||||
if (count >= g_params.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StatesIndexer indexer (ranges_, false);
|
||||
while (indexer.valid()) {
|
||||
unsigned g_li = 0;
|
||||
unsigned prod = 1;
|
||||
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
||||
g_li += indexer[gvarpos[j]] * prod;
|
||||
prod *= g_ranges[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
dist_->params[indexer] += g_params[g_li];
|
||||
} else {
|
||||
dist_->params[indexer] *= g_params[g_li];
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::insertVariable (VarId varId, unsigned nrStates)
|
||||
{
|
||||
assert (indexOf (varId) == -1);
|
||||
Params oldParams = dist_->params;
|
||||
dist_->params.clear();
|
||||
dist_->params.reserve (oldParams.size() * nrStates);
|
||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||
dist_->params.push_back (oldParams[i]);
|
||||
}
|
||||
}
|
||||
varids_.push_back (varId);
|
||||
args_.push_back (vid);
|
||||
ranges_.push_back (nrStates);
|
||||
params_ = params;
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::insertVariables (const VarIds& varIds, const Ranges& ranges)
|
||||
Factor::Factor (
|
||||
const VarNodes& vars,
|
||||
const Params& params,
|
||||
unsigned distId)
|
||||
{
|
||||
Params oldParams = dist_->params;
|
||||
unsigned nrStates = 1;
|
||||
for (unsigned i = 0; i < varIds.size(); i++) {
|
||||
assert (indexOf (varIds[i]) == -1);
|
||||
varids_.push_back (varIds[i]);
|
||||
ranges_.push_back (ranges[i]);
|
||||
nrStates *= ranges[i];
|
||||
}
|
||||
dist_->params.clear();
|
||||
dist_->params.reserve (oldParams.size() * nrStates);
|
||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||
dist_->params.push_back (oldParams[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
args_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (
|
||||
const VarIds& vids,
|
||||
const Ranges& ranges,
|
||||
const Params& params)
|
||||
{
|
||||
args_ = vids;
|
||||
ranges_ = ranges;
|
||||
params_ = params;
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
@ -226,10 +93,10 @@ void
|
||||
Factor::sumOutAllExcept (VarId vid)
|
||||
{
|
||||
assert (indexOf (vid) != -1);
|
||||
while (varids_.back() != vid) {
|
||||
while (args_.back() != vid) {
|
||||
sumOutLastVariable();
|
||||
}
|
||||
while (varids_.front() != vid) {
|
||||
while (args_.front() != vid) {
|
||||
sumOutFirstVariable();
|
||||
}
|
||||
}
|
||||
@ -239,9 +106,10 @@ Factor::sumOutAllExcept (VarId vid)
|
||||
void
|
||||
Factor::sumOutAllExcept (const VarIds& vids)
|
||||
{
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
if (std::find (vids.begin(), vids.end(), varids_[i]) == vids.end()) {
|
||||
sumOut (varids_[i]);
|
||||
for (int i = 0; i < (int)args_.size(); i++) {
|
||||
if (Util::contains (vids, args_[i]) == false) {
|
||||
sumOut (args_[i]);
|
||||
i --;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -254,11 +122,11 @@ Factor::sumOut (VarId vid)
|
||||
int idx = indexOf (vid);
|
||||
assert (idx != -1);
|
||||
|
||||
if (vid == varids_.back()) {
|
||||
if (vid == args_.back()) {
|
||||
sumOutLastVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
if (vid == varids_.front()) {
|
||||
if (vid == args_.front()) {
|
||||
sumOutFirstVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
@ -271,7 +139,7 @@ Factor::sumOut (VarId vid)
|
||||
// on the left of `var', with the states of the remaining vars fixed
|
||||
unsigned leftVarOffset = 1;
|
||||
|
||||
for (int i = varids_.size() - 1; i > idx; i--) {
|
||||
for (int i = args_.size() - 1; i > idx; i--) {
|
||||
varOffset *= ranges_[i];
|
||||
leftVarOffset *= ranges_[i];
|
||||
}
|
||||
@ -280,25 +148,24 @@ Factor::sumOut (VarId vid)
|
||||
unsigned offset = 0;
|
||||
unsigned count1 = 0;
|
||||
unsigned count2 = 0;
|
||||
unsigned newpsSize = dist_->params.size() / ranges_[idx];
|
||||
unsigned newpsSize = params_.size() / ranges_[idx];
|
||||
|
||||
Params newps;
|
||||
newps.reserve (newpsSize);
|
||||
Params& params = dist_->params;
|
||||
|
||||
while (newps.size() < newpsSize) {
|
||||
double sum = Util::addIdenty();
|
||||
double sum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
||||
if (Globals::logDomain) {
|
||||
Util::logSum (sum, params[offset]);
|
||||
sum = Util::logSum (sum, params_[offset]);
|
||||
} else {
|
||||
sum += params[offset];
|
||||
sum += params_[offset];
|
||||
}
|
||||
offset += varOffset;
|
||||
}
|
||||
newps.push_back (sum);
|
||||
count1 ++;
|
||||
if (idx == (int)varids_.size() - 1) {
|
||||
if (idx == (int)args_.size() - 1) {
|
||||
offset = count1 * ranges_[idx];
|
||||
} else {
|
||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||
@ -308,9 +175,9 @@ Factor::sumOut (VarId vid)
|
||||
offset = (leftVarOffset * count2) + count1;
|
||||
}
|
||||
}
|
||||
varids_.erase (varids_.begin() + idx);
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
dist_->params = newps;
|
||||
params_ = newps;
|
||||
}
|
||||
|
||||
|
||||
@ -318,20 +185,19 @@ Factor::sumOut (VarId vid)
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
{
|
||||
Params& params = dist_->params;
|
||||
unsigned nStates = ranges_.front();
|
||||
unsigned sep = params.size() / nStates;
|
||||
unsigned sep = params_.size() / nStates;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = sep; i < params.size(); i++) {
|
||||
Util::logSum (params[i % sep], params[i]);
|
||||
for (unsigned i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = sep; i < params.size(); i++) {
|
||||
params[i % sep] += params[i];
|
||||
for (unsigned i = sep; i < params_.size(); i++) {
|
||||
params_[i % sep] += params_[i];
|
||||
}
|
||||
}
|
||||
params.resize (sep);
|
||||
varids_.erase (varids_.begin());
|
||||
params_.resize (sep);
|
||||
args_.erase (args_.begin());
|
||||
ranges_.erase (ranges_.begin());
|
||||
}
|
||||
|
||||
@ -340,143 +206,55 @@ Factor::sumOutFirstVariable (void)
|
||||
void
|
||||
Factor::sumOutLastVariable (void)
|
||||
{
|
||||
Params& params = dist_->params;
|
||||
unsigned nStates = ranges_.back();
|
||||
unsigned idx1 = 0;
|
||||
unsigned idx2 = 0;
|
||||
if (Globals::logDomain) {
|
||||
while (idx1 < params.size()) {
|
||||
params[idx2] = params[idx1];
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
Util::logSum (params[idx2], params[idx1]);
|
||||
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
} else {
|
||||
while (idx1 < params.size()) {
|
||||
params[idx2] = params[idx1];
|
||||
while (idx1 < params_.size()) {
|
||||
params_[idx2] = params_[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
params[idx2] += params[idx1];
|
||||
params_[idx2] += params_[idx1];
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
}
|
||||
params.resize (idx2);
|
||||
varids_.pop_back();
|
||||
params_.resize (idx2);
|
||||
args_.pop_back();
|
||||
ranges_.pop_back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::orderVariables (void)
|
||||
Factor::multiply (Factor& g)
|
||||
{
|
||||
VarIds sortedVarIds = varids_;
|
||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||
reorderVariables (sortedVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::reorderVariables (const VarIds& newVarIds)
|
||||
{
|
||||
assert (newVarIds.size() == varids_.size());
|
||||
if (newVarIds == varids_) {
|
||||
if (args_.size() == 0) {
|
||||
copyFromFactor (g);
|
||||
return;
|
||||
}
|
||||
|
||||
Ranges newRanges;
|
||||
vector<unsigned> positions;
|
||||
for (unsigned i = 0; i < newVarIds.size(); i++) {
|
||||
unsigned idx = indexOf (newVarIds[i]);
|
||||
newRanges.push_back (ranges_[idx]);
|
||||
positions.push_back (idx);
|
||||
}
|
||||
|
||||
unsigned N = ranges_.size();
|
||||
Params newParams (dist_->params.size());
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
unsigned li = i;
|
||||
// calculate vector index corresponding to linear index
|
||||
vector<unsigned> vi (N);
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
vi[k] = li % ranges_[k];
|
||||
li /= ranges_[k];
|
||||
}
|
||||
// convert permuted vector index to corresponding linear index
|
||||
unsigned prod = 1;
|
||||
unsigned new_li = 0;
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
new_li += vi[positions[k]] * prod;
|
||||
prod *= ranges_[positions[k]];
|
||||
}
|
||||
newParams[new_li] = dist_->params[i];
|
||||
}
|
||||
varids_ = newVarIds;
|
||||
ranges_ = newRanges;
|
||||
dist_->params = newParams;
|
||||
TFactor<VarId>::multiply (g);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::absorveEvidence (VarId vid, unsigned evidence)
|
||||
Factor::reorderAccordingVarIds (void)
|
||||
{
|
||||
int idx = indexOf (vid);
|
||||
assert (idx != -1);
|
||||
|
||||
Params oldParams = dist_->params;
|
||||
dist_->params.clear();
|
||||
dist_->params.reserve (oldParams.size() / ranges_[idx]);
|
||||
StatesIndexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < evidence; i++) {
|
||||
indexer.increment (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
dist_->params.push_back (oldParams[indexer]);
|
||||
indexer.incrementExcluding (idx);
|
||||
}
|
||||
varids_.erase (varids_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::normalize (void)
|
||||
{
|
||||
Util::normalize (dist_->params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Factor::contains (const VarIds& vars) const
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
if (indexOf (vars[i]) == -1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Factor::indexOf (VarId vid) const
|
||||
{
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
if (varids_[i] == vid) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
VarIds sortedVarIds = args_;
|
||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||
reorderArguments (sortedVarIds);
|
||||
}
|
||||
|
||||
|
||||
@ -486,9 +264,9 @@ Factor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "f(" ;
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << VarNode (varids_[i], ranges_[i]).label();
|
||||
ss << VarNode (args_[i], ranges_[i]).label();
|
||||
}
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
@ -500,13 +278,13 @@ void
|
||||
Factor::print (void) const
|
||||
{
|
||||
VarNodes vars;
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
vars.push_back (new VarNode (varids_[i], ranges_[i]));
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
vars.push_back (new VarNode (args_[i], ranges_[i]));
|
||||
}
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << dist_->params[i] << endl;
|
||||
cout << " = " << params_[i] << endl;
|
||||
}
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
@ -515,3 +293,13 @@ Factor::print (void) const
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::copyFromFactor (const Factor& g)
|
||||
{
|
||||
args_ = g.arguments();
|
||||
ranges_ = g.ranges();
|
||||
params_ = g.params();
|
||||
distId_ = g.distId();
|
||||
}
|
||||
|
||||
|
@ -3,64 +3,293 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Distribution.h"
|
||||
#include "VarNode.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Distribution;
|
||||
|
||||
template <typename T>
|
||||
class TFactor
|
||||
{
|
||||
public:
|
||||
const vector<T>& arguments (void) const { return args_; }
|
||||
|
||||
vector<T>& arguments (void) { return args_; }
|
||||
|
||||
const Ranges& ranges (void) const { return ranges_; }
|
||||
|
||||
const Params& params (void) const { return params_; }
|
||||
|
||||
Params& params (void) { return params_; }
|
||||
|
||||
unsigned nrArguments (void) const { return args_.size(); }
|
||||
|
||||
unsigned size (void) const { return params_.size(); }
|
||||
|
||||
unsigned distId (void) const { return distId_; }
|
||||
|
||||
void setDistId (unsigned id) { distId_ = id; }
|
||||
|
||||
void setParams (const Params& newParams)
|
||||
{
|
||||
params_ = newParams;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
void normalize (void)
|
||||
{
|
||||
LogAware::normalize (params_);
|
||||
}
|
||||
|
||||
int indexOf (const T& t) const
|
||||
{
|
||||
int idx = -1;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i] == t) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
const T& argument (unsigned idx) const
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
T& argument (unsigned idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
unsigned range (unsigned idx) const
|
||||
{
|
||||
assert (idx < ranges_.size());
|
||||
return ranges_[idx];
|
||||
}
|
||||
|
||||
void multiply (TFactor<T>& g)
|
||||
{
|
||||
const vector<T>& g_args = g.arguments();
|
||||
const Ranges& g_ranges = g.ranges();
|
||||
const Params& g_params = g.params();
|
||||
if (args_ == g_args) {
|
||||
// optimization: if the factors contain the same set of args,
|
||||
// we can do a 1 to 1 operation on the parameters
|
||||
if (Globals::logDomain) {
|
||||
Util::add (params_, g_params);
|
||||
} else {
|
||||
Util::multiply (params_, g_params);
|
||||
}
|
||||
} else {
|
||||
bool sharedArgs = false;
|
||||
vector<unsigned> gvarpos;
|
||||
for (unsigned i = 0; i < g_args.size(); i++) {
|
||||
int idx = indexOf (g_args[i]);
|
||||
if (idx == -1) {
|
||||
insertArgument (g_args[i], g_ranges[i]);
|
||||
gvarpos.push_back (args_.size() - 1);
|
||||
} else {
|
||||
sharedArgs = true;
|
||||
gvarpos.push_back (idx);
|
||||
}
|
||||
}
|
||||
if (sharedArgs == false) {
|
||||
// optimization: if the original factors doesn't have common args,
|
||||
// we don't need to marry the states of the common args
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
if (Globals::logDomain) {
|
||||
params_[i] += g_params[count];
|
||||
} else {
|
||||
params_[i] *= g_params[count];
|
||||
}
|
||||
count ++;
|
||||
if (count >= g_params.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StatesIndexer indexer (ranges_, false);
|
||||
while (indexer.valid()) {
|
||||
unsigned g_li = 0;
|
||||
unsigned prod = 1;
|
||||
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
||||
g_li += indexer[gvarpos[j]] * prod;
|
||||
prod *= g_ranges[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
params_[indexer] += g_params[g_li];
|
||||
} else {
|
||||
params_[indexer] *= g_params[g_li];
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void absorveEvidence (const T& arg, unsigned evidence)
|
||||
{
|
||||
int idx = indexOf (arg);
|
||||
assert (idx != -1);
|
||||
assert (evidence < ranges_[idx]);
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() / ranges_[idx]);
|
||||
StatesIndexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < evidence; i++) {
|
||||
indexer.increment (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
params_.push_back (copy[indexer]);
|
||||
indexer.incrementExcluding (idx);
|
||||
}
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
void reorderArguments (const vector<T> newArgs)
|
||||
{
|
||||
assert (newArgs.size() == args_.size());
|
||||
if (newArgs == args_) {
|
||||
return; // already in the wanted order
|
||||
}
|
||||
Ranges newRanges;
|
||||
vector<unsigned> positions;
|
||||
for (unsigned i = 0; i < newArgs.size(); i++) {
|
||||
unsigned idx = indexOf (newArgs[i]);
|
||||
newRanges.push_back (ranges_[idx]);
|
||||
positions.push_back (idx);
|
||||
}
|
||||
unsigned N = ranges_.size();
|
||||
Params newParams (params_.size());
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
unsigned li = i;
|
||||
// calculate vector index corresponding to linear index
|
||||
vector<unsigned> vi (N);
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
vi[k] = li % ranges_[k];
|
||||
li /= ranges_[k];
|
||||
}
|
||||
// convert permuted vector index to corresponding linear index
|
||||
unsigned prod = 1;
|
||||
unsigned new_li = 0;
|
||||
for (int k = N - 1; k >= 0; k--) {
|
||||
new_li += vi[positions[k]] * prod;
|
||||
prod *= ranges_[positions[k]];
|
||||
}
|
||||
newParams[new_li] = params_[i];
|
||||
}
|
||||
args_ = newArgs;
|
||||
ranges_ = newRanges;
|
||||
params_ = newParams;
|
||||
}
|
||||
|
||||
bool contains (const T& arg) const
|
||||
{
|
||||
return Util::contains (args_, arg);
|
||||
}
|
||||
|
||||
bool contains (const vector<T>& args) const
|
||||
{
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (contains (args[i]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
vector<T> args_;
|
||||
Ranges ranges_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
|
||||
private:
|
||||
void insertArgument (const T& arg, unsigned range)
|
||||
{
|
||||
assert (indexOf (arg) == -1);
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() * range);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
for (unsigned reps = 0; reps < range; reps++) {
|
||||
params_.push_back (copy[i]);
|
||||
}
|
||||
}
|
||||
args_.push_back (arg);
|
||||
ranges_.push_back (range);
|
||||
}
|
||||
|
||||
void insertArguments (const vector<T>& args, const Ranges& ranges)
|
||||
{
|
||||
Params copy = params_;
|
||||
unsigned nrStates = 1;
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
assert (indexOf (args[i]) == -1);
|
||||
args_.push_back (args[i]);
|
||||
ranges_.push_back (ranges[i]);
|
||||
nrStates *= ranges[i];
|
||||
}
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() * nrStates);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||
params_.push_back (copy[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class Factor
|
||||
|
||||
class Factor : public TFactor<VarId>
|
||||
{
|
||||
public:
|
||||
Factor (void) { }
|
||||
|
||||
Factor (const Factor&);
|
||||
|
||||
Factor (VarId, unsigned);
|
||||
|
||||
Factor (const VarNodes&);
|
||||
|
||||
Factor (VarId, unsigned, const Params&);
|
||||
Factor (VarNodes&, Distribution*);
|
||||
Factor (const VarNodes&, const Params&);
|
||||
|
||||
Factor (const VarNodes&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
|
||||
Factor (const VarIds&, const Ranges&, const Params&);
|
||||
~Factor (void);
|
||||
|
||||
void setParameters (const Params&);
|
||||
void copyFromFactor (const Factor& f);
|
||||
void multiply (const Factor&);
|
||||
void insertVariable (VarId, unsigned);
|
||||
void insertVariables (const VarIds&, const Ranges&);
|
||||
void sumOutAllExcept (VarId);
|
||||
void sumOutAllExcept (const VarIds&);
|
||||
void sumOut (VarId);
|
||||
void sumOutFirstVariable (void);
|
||||
void sumOutLastVariable (void);
|
||||
void orderVariables (void);
|
||||
void reorderVariables (const VarIds&);
|
||||
void absorveEvidence (VarId, unsigned);
|
||||
void normalize (void);
|
||||
bool contains (const VarIds&) const;
|
||||
int indexOf (VarId) const;
|
||||
string getLabel (void) const;
|
||||
void print (void) const;
|
||||
void sumOutAllExcept (VarId);
|
||||
|
||||
const VarIds& getVarIds (void) const { return varids_; }
|
||||
const Ranges& getRanges (void) const { return ranges_; }
|
||||
const Params& getParameters (void) const { return dist_->params; }
|
||||
Distribution* getDistribution (void) const { return dist_; }
|
||||
unsigned nrVariables (void) const { return varids_.size(); }
|
||||
unsigned nrParameters() const { return dist_->params.size(); }
|
||||
void sumOutAllExcept (const VarIds&);
|
||||
|
||||
void setDistribution (Distribution* dist)
|
||||
{
|
||||
dist_ = dist;
|
||||
}
|
||||
void sumOut (VarId);
|
||||
|
||||
void sumOutFirstVariable (void);
|
||||
|
||||
void sumOutLastVariable (void);
|
||||
|
||||
void multiply (Factor&);
|
||||
|
||||
void reorderAccordingVarIds (void);
|
||||
|
||||
string getLabel (void) const;
|
||||
|
||||
void print (void) const;
|
||||
|
||||
private:
|
||||
void copyFromFactor (const Factor& f);
|
||||
|
||||
VarIds varids_;
|
||||
Ranges ranges_;
|
||||
Distribution* dist_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FACTOR_H
|
||||
|
@ -53,10 +53,10 @@ FactorGraph::FactorGraph (const BayesNet& bn)
|
||||
neighs.push_back (varNodes_[parents[j]->getIndex()]);
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (
|
||||
new Factor (neighs, nodes[i]->getDistribution()));
|
||||
new Factor (neighs, nodes[i]->params(), nodes[i]->distId()));
|
||||
if (orderFactorVariables) {
|
||||
sort (neighs.begin(), neighs.end(), CompVarId());
|
||||
fn->factor()->orderVariables();
|
||||
fn->factor()->reorderAccordingVarIds();
|
||||
}
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
@ -131,10 +131,10 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nParams;
|
||||
is >> nParams;
|
||||
if (facNodes_[i]->getParameters().size() != nParams) {
|
||||
if (facNodes_[i]->params().size() != nParams) {
|
||||
cerr << "error: invalid number of parameters for factor " ;
|
||||
cerr << facNodes_[i]->getLabel() ;
|
||||
cerr << ", expected: " << facNodes_[i]->getParameters().size();
|
||||
cerr << ", expected: " << facNodes_[i]->params().size();
|
||||
cerr << ", given: " << nParams << endl;
|
||||
abort();
|
||||
}
|
||||
@ -147,7 +147,7 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
facNodes_[i]->factor()->setParameters (params);
|
||||
facNodes_[i]->factor()->setParams (params);
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
@ -335,21 +335,6 @@ FactorGraph::setIndexes (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::freeDistributions (void)
|
||||
{
|
||||
set<Distribution*> dists;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
dists.insert (facNodes_[i]->factor()->getDistribution());
|
||||
}
|
||||
for (set<Distribution*>::iterator it = dists.begin();
|
||||
it != dists.end(); it++) {
|
||||
delete *it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::printGraphicalModel (void) const
|
||||
{
|
||||
@ -440,7 +425,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
Params params = facNodes_[i]->getParameters();
|
||||
Params params = facNodes_[i]->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
@ -477,7 +462,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
out << factorVars[j]->nrStates() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
Params params = facNodes_[i]->factor()->getParameters();
|
||||
Params params = facNodes_[i]->factor()->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
|
@ -4,7 +4,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "Distribution.h"
|
||||
#include "Factor.h"
|
||||
#include "Horus.h"
|
||||
|
||||
@ -13,18 +12,21 @@ using namespace std;
|
||||
class BayesNet;
|
||||
class FgFacNode;
|
||||
|
||||
|
||||
class FgVarNode : public VarNode
|
||||
{
|
||||
public:
|
||||
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
|
||||
|
||||
FgVarNode (const VarNode* v) : VarNode (v) { }
|
||||
|
||||
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
|
||||
const FgFacSet& neighbors (void) const { return neighs_; }
|
||||
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
|
||||
|
||||
const FgFacSet& neighbors (void) const { return neighs_; }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
||||
// members
|
||||
|
||||
FgFacSet neighs_;
|
||||
};
|
||||
|
||||
@ -32,13 +34,18 @@ class FgVarNode : public VarNode
|
||||
class FgFacNode
|
||||
{
|
||||
public:
|
||||
FgFacNode (const FgFacNode* fn) {
|
||||
FgFacNode (const FgFacNode* fn)
|
||||
{
|
||||
factor_ = new Factor (*fn->factor());
|
||||
index_ = -1;
|
||||
}
|
||||
|
||||
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
||||
Factor* factor() const { return factor_; }
|
||||
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
||||
|
||||
Factor* factor() const { return factor_; }
|
||||
|
||||
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
||||
|
||||
const FgVarSet& neighbors (void) const { return neighs_; }
|
||||
|
||||
int getIndex (void) const
|
||||
@ -46,28 +53,28 @@ class FgFacNode
|
||||
assert (index_ != -1);
|
||||
return index_;
|
||||
}
|
||||
|
||||
void setIndex (int index)
|
||||
{
|
||||
index_ = index;
|
||||
}
|
||||
Distribution* getDistribution (void)
|
||||
|
||||
const Params& params (void) const
|
||||
{
|
||||
return factor_->getDistribution();
|
||||
}
|
||||
const Params& getParameters (void) const
|
||||
{
|
||||
return factor_->getParameters();
|
||||
return factor_->params();
|
||||
}
|
||||
|
||||
string getLabel (void)
|
||||
{
|
||||
return factor_->getLabel();
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
|
||||
|
||||
Factor* factor_;
|
||||
int index_;
|
||||
FgVarSet neighs_;
|
||||
Factor* factor_;
|
||||
FgVarSet neighs_;
|
||||
int index_;
|
||||
};
|
||||
|
||||
|
||||
@ -83,29 +90,17 @@ struct CompVarId
|
||||
class FactorGraph : public GraphicalModel
|
||||
{
|
||||
public:
|
||||
FactorGraph (void) {};
|
||||
FactorGraph (void) { };
|
||||
|
||||
FactorGraph (const FactorGraph&);
|
||||
|
||||
FactorGraph (const BayesNet&);
|
||||
|
||||
~FactorGraph (void);
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
void readFromLibDaiFormat (const char*);
|
||||
void addVariable (FgVarNode*);
|
||||
void addFactor (FgFacNode*);
|
||||
void addEdge (FgVarNode*, FgFacNode*);
|
||||
void addEdge (FgFacNode*, FgVarNode*);
|
||||
VarNode* getVariableNode (unsigned) const;
|
||||
VarNodes getVariableNodes (void) const;
|
||||
bool isTree (void) const;
|
||||
void setIndexes (void);
|
||||
void freeDistributions (void);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToGraphViz (const char*) const;
|
||||
void exportToUaiFormat (const char*) const;
|
||||
void exportToLibDaiFormat (const char*) const;
|
||||
|
||||
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
||||
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
||||
|
||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
||||
|
||||
FgVarNode* getFgVarNode (VarId vid) const
|
||||
{
|
||||
@ -117,21 +112,52 @@ class FactorGraph : public GraphicalModel
|
||||
}
|
||||
}
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
|
||||
void readFromLibDaiFormat (const char*);
|
||||
|
||||
void addVariable (FgVarNode*);
|
||||
|
||||
void addFactor (FgFacNode*);
|
||||
|
||||
void addEdge (FgVarNode*, FgFacNode*);
|
||||
|
||||
void addEdge (FgFacNode*, FgVarNode*);
|
||||
|
||||
VarNode* getVariableNode (unsigned) const;
|
||||
|
||||
VarNodes getVariableNodes (void) const;
|
||||
|
||||
bool isTree (void) const;
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void printGraphicalModel (void) const;
|
||||
|
||||
void exportToGraphViz (const char*) const;
|
||||
|
||||
void exportToUaiFormat (const char*) const;
|
||||
|
||||
void exportToLibDaiFormat (const char*) const;
|
||||
|
||||
static bool orderFactorVariables;
|
||||
|
||||
private:
|
||||
//DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||
bool containsCycle (void) const;
|
||||
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||
|
||||
FgVarSet varNodes_;
|
||||
FgFacSet facNodes_;
|
||||
bool containsCycle (void) const;
|
||||
|
||||
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
FgVarSet varNodes_;
|
||||
FgFacSet facNodes_;
|
||||
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
IndexMap varMap_;
|
||||
IndexMap varMap_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FACTORGRAPH_H
|
||||
|
@ -38,11 +38,11 @@ void
|
||||
FgBpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (COLLECT_STATISTICS) {
|
||||
if (Constants::COLLECT_STATS) {
|
||||
start = clock();
|
||||
}
|
||||
runLoopySolver();
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
@ -53,18 +53,13 @@ FgBpSolver::runSolver (void)
|
||||
}
|
||||
}
|
||||
unsigned size = factorGraph_->getVarNodes().size();
|
||||
if (COLLECT_STATISTICS) {
|
||||
if (Constants::COLLECT_STATS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = factorGraph_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
||||
stringstream ss;
|
||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
||||
factorGraph_->exportToGraphViz (ss.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -76,22 +71,22 @@ FgBpSolver::getPosterioriOf (VarId vid)
|
||||
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
||||
Params probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->nrStates(), Util::noEvidence());
|
||||
probs[var->getEvidence()] = Util::withEvidence();
|
||||
probs.resize (var->nrStates(), LogAware::noEvidence());
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->nrStates(), Util::multIdenty());
|
||||
probs.resize (var->nrStates(), LogAware::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::add (probs, links[i]->getMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
LogAware::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
} else {
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::multiply (probs, links[i]->getMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
LogAware::normalize (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
@ -102,9 +97,9 @@ FgBpSolver::getPosterioriOf (VarId vid)
|
||||
Params
|
||||
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
{
|
||||
int idx = -1;
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
|
||||
const FgFacSet& factorNodes = vn->neighbors();
|
||||
int idx = -1;
|
||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
||||
idx = i;
|
||||
@ -114,18 +109,18 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
if (idx == -1) {
|
||||
return getJointByConditioning (jointVarIds);
|
||||
} else {
|
||||
Factor r (*factorNodes[idx]->factor());
|
||||
Factor res (*factorNodes[idx]->factor());
|
||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Factor msg (links[i]->getVariable()->varId(),
|
||||
links[i]->getVariable()->nrStates(),
|
||||
getVar2FactorMsg (links[i]));
|
||||
r.multiply (msg);
|
||||
res.multiply (msg);
|
||||
}
|
||||
r.sumOutAllExcept (jointVarIds);
|
||||
r.reorderVariables (jointVarIds);
|
||||
r.normalize();
|
||||
Params jointDist = r.getParameters();
|
||||
res.sumOutAllExcept (jointVarIds);
|
||||
res.reorderArguments (jointVarIds);
|
||||
res.normalize();
|
||||
Params jointDist = res.params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (jointDist);
|
||||
}
|
||||
@ -144,13 +139,8 @@ FgBpSolver::runLoopySolver (void)
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_ ++;
|
||||
if (DL >= 2) {
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIters_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printHeader (" Iteration " + nIters_);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
@ -178,7 +168,7 @@ FgBpSolver::runLoopySolver (void)
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
@ -256,12 +246,12 @@ FgBpSolver::converged (void)
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
if (DL == 0) break;
|
||||
if (Constants::DEBUG == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -283,7 +273,7 @@ FgBpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
if (Constants::DEBUG >= 2) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
@ -317,9 +307,8 @@ FgBpSolver::maxResidualSchedule (void)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 2) {
|
||||
cout << "----------------------------------------" ;
|
||||
cout << "----------------------------------------" << endl;
|
||||
if (Constants::DEBUG >= 2) {
|
||||
Util::printDashedLine();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -339,7 +328,7 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
msgSize *= links[i]->getVariable()->nrStates();
|
||||
}
|
||||
unsigned repetitions = 1;
|
||||
Params msgProduct (msgSize, Util::multIdenty());
|
||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||
if (Globals::logDomain) {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
@ -354,7 +343,7 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
} else {
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " << endl;
|
||||
}
|
||||
@ -368,34 +357,29 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
}
|
||||
}
|
||||
|
||||
Factor result (src->factor()->getVarIds(),
|
||||
src->factor()->getRanges(),
|
||||
Factor result (src->factor()->arguments(),
|
||||
src->factor()->ranges(),
|
||||
msgProduct);
|
||||
result.multiply (*(src->factor()));
|
||||
if (DL >= 5) {
|
||||
cout << " message product: " ;
|
||||
cout << Util::parametersToString (msgProduct) << endl;
|
||||
cout << " original factor: " ;
|
||||
cout << Util::parametersToString (src->getParameters()) << endl;
|
||||
cout << " factor product: " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (DL >= 5) {
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " marginalized: " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||
cout << result.params() << endl;
|
||||
}
|
||||
const Params& resultParams = result.getParameters();
|
||||
const Params& resultParams = result.params();
|
||||
Params& message = link->getNextMessage();
|
||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
||||
message[i] = resultParams[i];
|
||||
}
|
||||
Util::normalize (message);
|
||||
if (DL >= 5) {
|
||||
cout << " curr msg: " ;
|
||||
cout << Util::parametersToString (link->getMessage()) << endl;
|
||||
cout << " next msg: " ;
|
||||
cout << Util::parametersToString (message) << endl;
|
||||
LogAware::normalize (message);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " curr msg: " << link->getMessage() << endl;
|
||||
cout << " next msg: " << message << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -408,16 +392,16 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
Params msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->nrStates(), Util::noEvidence());
|
||||
msg[src->getEvidence()] = Util::withEvidence();
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
msg.resize (src->nrStates(), LogAware::noEvidence());
|
||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->nrStates(), Util::one());
|
||||
msg.resize (src->nrStates(), LogAware::one());
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << msg;
|
||||
}
|
||||
const SpLinkSet& links = ninf (src)->getLinks();
|
||||
if (Globals::logDomain) {
|
||||
@ -430,14 +414,14 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::multiply (msg, links[i]->getMessage());
|
||||
if (DL >= 5) {
|
||||
cout << " x " << Util::parametersToString (links[i]->getMessage());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " x " << links[i]->getMessage();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " = " << Util::parametersToString (msg);
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
@ -503,9 +487,9 @@ FgBpSolver::printLinkInformation (void) const
|
||||
SpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
||||
cout << l->getMessage() << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
||||
cout << l->getNextMessage() << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
@ -13,7 +13,6 @@
|
||||
using namespace std;
|
||||
|
||||
|
||||
|
||||
class SpLink
|
||||
{
|
||||
public:
|
||||
@ -21,15 +20,34 @@ class SpLink
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
||||
v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
||||
v1_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
|
||||
v2_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
virtual ~SpLink (void) {};
|
||||
virtual ~SpLink (void) { };
|
||||
|
||||
FgFacNode* getFactor (void) const { return fac_; }
|
||||
|
||||
FgVarNode* getVariable (void) const { return var_; }
|
||||
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
|
||||
Params& getNextMessage (void) { return *nextMsg_; }
|
||||
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
|
||||
double getResidual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_,v2_);
|
||||
}
|
||||
|
||||
virtual void updateMessage (void)
|
||||
{
|
||||
@ -37,11 +55,6 @@ class SpLink
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
@ -50,38 +63,28 @@ class SpLink
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
FgFacNode* getFactor (void) const { return fac_; }
|
||||
FgVarNode* getVariable (void) const { return var_; }
|
||||
const Params& getMessage (void) const { return *currMsg_; }
|
||||
Params& getNextMessage (void) { return *nextMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
protected:
|
||||
FgFacNode* fac_;
|
||||
FgVarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
FgFacNode* fac_;
|
||||
FgVarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
|
||||
typedef vector<SpLink*> SpLinkSet;
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||
const SpLinkSet& getLinks (void) { return links_; }
|
||||
|
||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||
const SpLinkSet& getLinks (void) { return links_; }
|
||||
private:
|
||||
SpLinkSet links_;
|
||||
SpLinkSet links_;
|
||||
};
|
||||
|
||||
|
||||
@ -89,51 +92,29 @@ class FgBpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
FgBpSolver (const FactorGraph&);
|
||||
|
||||
virtual ~FgBpSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
virtual Params getJointDistributionOf (const VarIds&);
|
||||
void runSolver (void);
|
||||
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
|
||||
virtual Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
protected:
|
||||
virtual void initializeSolver (void);
|
||||
virtual void createLinks (void);
|
||||
virtual void maxResidualSchedule (void);
|
||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
virtual void printLinkInformation (void) const;
|
||||
virtual void initializeSolver (void);
|
||||
|
||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
virtual void createLinks (void);
|
||||
|
||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
virtual void maxResidualSchedule (void);
|
||||
|
||||
void updateMessage (SpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
||||
|
||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
virtual void printLinkInformation (void) const;
|
||||
|
||||
SPNodeInfo* ninf (const FgVarNode* var) const
|
||||
{
|
||||
@ -145,7 +126,39 @@ class FgBpSolver : public Solver
|
||||
return facsI_[fac->getIndex()];
|
||||
}
|
||||
|
||||
struct CompareResidual {
|
||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
void updateMessage (SpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Constants::DEBUG >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
struct CompareResidual
|
||||
{
|
||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
||||
{
|
||||
return link1->getResidual() > link2->getResidual();
|
||||
@ -165,10 +178,8 @@ class FgBpSolver : public Solver
|
||||
SpLinkMap linkMap_;
|
||||
|
||||
private:
|
||||
void runLoopySolver (void);
|
||||
bool converged (void);
|
||||
|
||||
|
||||
void runLoopySolver (void);
|
||||
bool converged (void);
|
||||
};
|
||||
|
||||
#endif // HORUS_FGBPSOLVER_H
|
||||
|
@ -8,7 +8,9 @@
|
||||
|
||||
|
||||
vector<LiftedOperator*>
|
||||
LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
||||
LiftedOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<SumOutOperator*> sumOutOps;
|
||||
@ -28,12 +30,15 @@ LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedOperator::printValidOps (ParfactorList& pfList, const Grounds& query)
|
||||
LiftedOperator::printValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList, query);
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
cout << "-> " << validOps[i]->toString() << endl;
|
||||
delete validOps[i];
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,14 +61,14 @@ SumOutOperator::getCost (void)
|
||||
pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
if ((*pfIter)->containsGroup (groupSet[i])) {
|
||||
int idx = (*pfIter)->indexOfFormulaWithGroup (groupSet[i]);
|
||||
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
|
||||
cost *= (*pfIter)->range (idx);
|
||||
break;
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
@ -77,14 +82,13 @@ SumOutOperator::apply (void)
|
||||
pfList_.remove (iters[0]);
|
||||
for (unsigned i = 1; i < iters.size(); i++) {
|
||||
product->multiply (**(iters[i]));
|
||||
delete *(iters[i]);
|
||||
pfList_.remove (iters[i]);
|
||||
pfList_.removeAndDelete (iters[i]);
|
||||
}
|
||||
if (product->nrFormulas() == 1) {
|
||||
if (product->nrArguments() == 1) {
|
||||
delete product;
|
||||
return;
|
||||
}
|
||||
int fIdx = product->indexOfFormulaWithGroup (group_);
|
||||
int fIdx = product->indexOfGroup (group_);
|
||||
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
||||
if (product->constr()->isCountNormalized (excl)) {
|
||||
product->sumOut (fIdx);
|
||||
@ -96,21 +100,21 @@ SumOutOperator::apply (void)
|
||||
pfList_.add (pfs[i]);
|
||||
}
|
||||
delete product;
|
||||
pfList_.shatter();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<SumOutOperator*>
|
||||
SumOutOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
||||
SumOutOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<SumOutOperator*> validOps;
|
||||
set<unsigned> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
assert (*it);
|
||||
const ProbFormulas& formulas = (*it)->formulas();
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
allGroups.insert (formulas[i].group());
|
||||
}
|
||||
@ -134,8 +138,8 @@ SumOutOperator::toString (void)
|
||||
stringstream ss;
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = parfactorsWithGroup (pfList_, group_);
|
||||
int idx = (*pfIters[0])->indexOfFormulaWithGroup (group_);
|
||||
ProbFormula f = (*pfIters[0])->formula (idx);
|
||||
int idx = (*pfIters[0])->indexOfGroup (group_);
|
||||
ProbFormula f = (*pfIters[0])->argument (idx);
|
||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
@ -158,9 +162,9 @@ SumOutOperator::validOp (
|
||||
}
|
||||
unordered_map<unsigned, unsigned> groupToRange;
|
||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
||||
int fIdx = (*pfIters[i])->indexOfFormulaWithGroup (group);
|
||||
if ((*pfIters[i])->formulas()[fIdx].contains (
|
||||
(*pfIters[i])->elimLogVars()) == false) {
|
||||
int fIdx = (*pfIters[i])->indexOfGroup (group);
|
||||
if ((*pfIters[i])->argument (fIdx).contains (
|
||||
(*pfIters[i])->elimLogVars()) == false) {
|
||||
return false;
|
||||
}
|
||||
vector<unsigned> ranges = (*pfIters[i])->ranges();
|
||||
@ -206,8 +210,8 @@ SumOutOperator::isToEliminate (
|
||||
unsigned group,
|
||||
const Grounds& query)
|
||||
{
|
||||
int fIdx = g->indexOfFormulaWithGroup (group);
|
||||
const ProbFormula& formula = g->formula (fIdx);
|
||||
int fIdx = g->indexOfGroup (group);
|
||||
const ProbFormula& formula = g->argument (fIdx);
|
||||
bool toElim = true;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
if (formula.functor() == query[i].functor() &&
|
||||
@ -228,7 +232,7 @@ unsigned
|
||||
CountingOperator::getCost (void)
|
||||
{
|
||||
unsigned cost = 0;
|
||||
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
|
||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
unsigned range = (*pfIter_)->range (fIdx);
|
||||
unsigned size = (*pfIter_)->size() / range;
|
||||
TinySet<unsigned> counts;
|
||||
@ -247,18 +251,19 @@ CountingOperator::apply (void)
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
||||
(*pfIter_)->countConvert (X_);
|
||||
} else {
|
||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
|
||||
(*pfIter_)->countedLogVars() | X_);
|
||||
pfs[i]->countedLogVars() | X_);
|
||||
if (condCount > 1 && cartProduct) {
|
||||
pfs[i]->countConvert (X_);
|
||||
}
|
||||
pfList_.add (pfs[i]);
|
||||
}
|
||||
pfList_.deleteAndRemove (pfIter_);
|
||||
pfList_.shatter();
|
||||
delete pf;
|
||||
}
|
||||
}
|
||||
|
||||
@ -289,14 +294,17 @@ CountingOperator::toString (void)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getHeaderString();
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
ss << " º " << pfs[i]->getHeaderString() << endl;
|
||||
ss << " º " << pfs[i]->getLabel() << endl;
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
delete pfs[i];
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -308,8 +316,8 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
if (g->nrFormulas (X) != 1) {
|
||||
return false;
|
||||
}
|
||||
int fIdx = g->indexOfFormulaWithLogVar (X);
|
||||
if (g->formulas()[fIdx].isCounting()) {
|
||||
int fIdx = g->indexOfLogVar (X);
|
||||
if (g->argument (fIdx).isCounting()) {
|
||||
return false;
|
||||
}
|
||||
bool countNormalized = g->constr()->isCountNormalized (X);
|
||||
@ -332,10 +340,10 @@ GroundOperator::getCost (void)
|
||||
unsigned cost = 0;
|
||||
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||
if (isCountingLv) {
|
||||
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
|
||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
unsigned currSize = (*pfIter_)->size();
|
||||
unsigned nrHists = (*pfIter_)->range (fIdx);
|
||||
unsigned range = (*pfIter_)->formula(fIdx).range();
|
||||
unsigned range = (*pfIter_)->argument (fIdx).range();
|
||||
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
||||
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
||||
} else {
|
||||
@ -350,18 +358,17 @@ void
|
||||
GroundOperator::apply (void)
|
||||
{
|
||||
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
if (countedLv) {
|
||||
(*pfIter_)->fullExpand (X_);
|
||||
(*pfIter_)->setNewGroups();
|
||||
pfList_.shatter();
|
||||
pf->fullExpand (X_);
|
||||
pfList_.add (pf);
|
||||
} else {
|
||||
ConstraintTrees cts = (*pfIter_)->constr()->ground (X_);
|
||||
ConstraintTrees cts = pf->constr()->ground (X_);
|
||||
for (unsigned i = 0; i < cts.size(); i++) {
|
||||
Parfactor* newPf = new Parfactor (*pfIter_, cts[i]);
|
||||
pfList_.add (newPf);
|
||||
pfList_.add (new Parfactor (pf, cts[i]));
|
||||
}
|
||||
pfList_.deleteAndRemove (pfIter_);
|
||||
pfList_.shatter();
|
||||
delete pf;
|
||||
}
|
||||
}
|
||||
|
||||
@ -393,24 +400,13 @@ GroundOperator::toString (void)
|
||||
((*pfIter_)->countedLogVars().contains (X_))
|
||||
? ss << "full expanding "
|
||||
: ss << "grounding " ;
|
||||
ss << X_ << " in " << (*pfIter_)->getHeaderString();
|
||||
ss << X_ << " in " << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << getCost() << "]" << endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FoveSolver::FoveSolver (const ParfactorList* pfList)
|
||||
{
|
||||
for (ParfactorList::const_iterator it = pfList->begin();
|
||||
it != pfList->end();
|
||||
it ++) {
|
||||
pfList_.addShattered (new Parfactor (**it));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
FoveSolver::getPosterioriOf (const Ground& query)
|
||||
{
|
||||
@ -422,14 +418,12 @@ FoveSolver::getPosterioriOf (const Ground& query)
|
||||
Params
|
||||
FoveSolver::getJointDistributionOf (const Grounds& query)
|
||||
{
|
||||
shatterAgainstQuery (query);
|
||||
runSolver (query);
|
||||
(*pfList_.begin())->normalize();
|
||||
Params params = (*pfList_.begin())->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
delete *pfList_.begin();
|
||||
return params;
|
||||
}
|
||||
|
||||
@ -438,32 +432,38 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
|
||||
void
|
||||
FoveSolver::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
const ObservedFormulas& obsFormulas)
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
bool increment = true;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
if (absorved (pfList, it, obsFormulas[i])) {
|
||||
it = pfList.deleteAndRemove (it);
|
||||
increment = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (increment) {
|
||||
++ it;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
Parfactor* pf = *it;
|
||||
it = pfList.remove (it);
|
||||
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||
if (absorvedPfs.empty() == false) {
|
||||
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||
// just remove pf;
|
||||
} else {
|
||||
Util::addToVector (newPfs, absorvedPfs);
|
||||
}
|
||||
delete pf;
|
||||
} else {
|
||||
it = pfList.insertShattered (it, pf);
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
pfList.shatter();
|
||||
if (obsFormulas.empty() == false) {
|
||||
cout << "*******************************************************" << endl;
|
||||
if (Constants::DEBUG > 1 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
cout << " -> " << *obsFormulas[i] << endl;
|
||||
cout << " -> " << obsFormulas[i] << endl;
|
||||
}
|
||||
cout << "*******************************************************" << endl;
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
}
|
||||
pfList.print();
|
||||
}
|
||||
|
||||
|
||||
@ -473,14 +473,14 @@ FoveSolver::countNormalize (
|
||||
Parfactor* g,
|
||||
const LogVarSet& set)
|
||||
{
|
||||
if (set.empty()) {
|
||||
assert (false); // TODO
|
||||
return {};
|
||||
}
|
||||
Parfactors normPfs;
|
||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||
for (unsigned i = 0; i < normCts.size(); i++) {
|
||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||
if (set.empty()) {
|
||||
normPfs.push_back (new Parfactor (*g));
|
||||
} else {
|
||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||
for (unsigned i = 0; i < normCts.size(); i++) {
|
||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||
}
|
||||
}
|
||||
return normPfs;
|
||||
}
|
||||
@ -490,17 +490,25 @@ FoveSolver::countNormalize (
|
||||
void
|
||||
FoveSolver::runSolver (const Grounds& query)
|
||||
{
|
||||
shatterAgainstQuery (query);
|
||||
runWeakBayesBall (query);
|
||||
while (true) {
|
||||
cout << "---------------------------------------------------" << endl;
|
||||
pfList_.print();
|
||||
LiftedOperator::printValidOps (pfList_, query);
|
||||
if (Constants::DEBUG > 1) {
|
||||
Util::printDashedLine();
|
||||
pfList_.print();
|
||||
LiftedOperator::printValidOps (pfList_, query);
|
||||
}
|
||||
LiftedOperator* op = getBestOperation (query);
|
||||
if (op == 0) {
|
||||
break;
|
||||
}
|
||||
cout << "best operation: " << op->toString() << endl;
|
||||
if (Constants::DEBUG > 1) {
|
||||
cout << "best operation: " << op->toString() << endl;
|
||||
}
|
||||
op->apply();
|
||||
delete op;
|
||||
}
|
||||
assert (pfList_.size() > 0);
|
||||
if (pfList_.size() > 1) {
|
||||
ParfactorList::iterator pfIter = pfList_.begin();
|
||||
pfIter ++;
|
||||
@ -514,26 +522,6 @@ FoveSolver::runSolver (const Grounds& query)
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::allEliminated (const Grounds&)
|
||||
{
|
||||
ParfactorList::iterator pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
const ProbFormulas formulas = (*pfIter)->formulas();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
//bool toElim = false;
|
||||
//for (unsigned j = 0; j < queries.size(); j++) {
|
||||
// if ((*pfIter)->containsGround (queries[i]) == false) {
|
||||
// return
|
||||
// }
|
||||
}
|
||||
++ pfIter;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedOperator*
|
||||
FoveSolver::getBestOperation (const Grounds& query)
|
||||
{
|
||||
@ -548,156 +536,170 @@ FoveSolver::getBestOperation (const Grounds& query)
|
||||
bestCost = cost;
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||
if (validOps[i] != bestOp) {
|
||||
delete validOps[i];
|
||||
}
|
||||
}
|
||||
return bestOp;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||
{
|
||||
queue<unsigned> todo; // groups to process
|
||||
set<unsigned> done; // processed or in queue
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
int group = (*it)->findGroup (query[i]);
|
||||
if (group != -1) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
unsigned group = todo.front();
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<unsigned> groups = (*it)->getAllGroups();
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
done.insert (groups[i]);
|
||||
}
|
||||
}
|
||||
requiredPfs.insert (*it);
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
todo.pop();
|
||||
}
|
||||
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false) {
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
if (Constants::DEBUG > 1) {
|
||||
Util::printHeader ("REQUIRED PARFACTORS");
|
||||
pfList_.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||
{
|
||||
// return;
|
||||
return ;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
if (query[i].isAtom()) {
|
||||
continue;
|
||||
}
|
||||
ParfactorList pfListCopy = pfList_;
|
||||
pfList_.clear();
|
||||
for (ParfactorList::iterator it = pfListCopy.begin();
|
||||
it != pfListCopy.end(); ++ it) {
|
||||
Parfactor* pf = *it;
|
||||
if (pf->containsGround (query[i])) {
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGround (query[i])) {
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split =
|
||||
pf->constr()->split (query[i].args(), query[i].arity());
|
||||
(*it)->constr()->split (query[i].args(), query[i].arity());
|
||||
ConstraintTree* commCt = split.first;
|
||||
ConstraintTree* exclCt = split.second;
|
||||
pfList_.add (new Parfactor (pf, commCt));
|
||||
newPfs.push_back (new Parfactor (*it, commCt));
|
||||
if (exclCt->empty() == false) {
|
||||
pfList_.add (new Parfactor (pf, exclCt));
|
||||
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
delete pf;
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
pfList_.add (pf);
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
pfList_.shatter();
|
||||
pfList_.add (newPfs);
|
||||
}
|
||||
cout << endl;
|
||||
cout << "*******************************************************" << endl;
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
if (Constants::DEBUG > 1) {
|
||||
cout << endl;
|
||||
Util::printAsteriskLine();
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList_.print();
|
||||
}
|
||||
cout << "*******************************************************" << endl;
|
||||
pfList_.print();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::absorved (
|
||||
ParfactorList& pfList,
|
||||
ParfactorList::iterator pfIter,
|
||||
const ObservedFormula* obsFormula)
|
||||
Parfactors
|
||||
FoveSolver::absorve (
|
||||
ObservedFormula& obsFormula,
|
||||
Parfactor* g)
|
||||
{
|
||||
Parfactors absorvedPfs;
|
||||
Parfactor* g = *pfIter;
|
||||
const ProbFormulas& formulas = g->formulas();
|
||||
const ProbFormulas& formulas = g->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
if (obsFormula->functor() == formulas[i].functor() &&
|
||||
obsFormula->arity() == formulas[i].arity()) {
|
||||
if (obsFormula.functor() == formulas[i].functor() &&
|
||||
obsFormula.arity() == formulas[i].arity()) {
|
||||
|
||||
if (obsFormula->isAtom()) {
|
||||
if (obsFormula.isAtom()) {
|
||||
if (formulas.size() > 1) {
|
||||
g->absorveEvidence (i, obsFormula->evidence());
|
||||
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||
} else {
|
||||
return true;
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
g->constr()->moveToTop (formulas[i].logVars());
|
||||
std::pair<ConstraintTree*, ConstraintTree*> res
|
||||
= g->constr()->split (obsFormula->constr(), formulas[i].arity());
|
||||
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
|
||||
ConstraintTree* commCt = res.first;
|
||||
ConstraintTree* exclCt = res.second;
|
||||
|
||||
if (commCt->empty()) {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (exclCt->empty() == false) {
|
||||
pfList.add (new Parfactor (g, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
|
||||
if (formulas.size() > 1) {
|
||||
LogVarSet excl = g->exclusiveLogVars (i);
|
||||
Parfactors countNormPfs = countNormalize (g, excl);
|
||||
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
||||
countNormPfs[j]->absorveEvidence (i, obsFormula->evidence());
|
||||
absorvedPfs.push_back (countNormPfs[j]);
|
||||
if (commCt->empty() == false) {
|
||||
if (formulas.size() > 1) {
|
||||
LogVarSet excl = g->exclusiveLogVars (i);
|
||||
Parfactors countNormPfs = countNormalize (g, excl);
|
||||
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
||||
countNormPfs[j]->absorveEvidence (
|
||||
formulas[i], obsFormula.evidence());
|
||||
absorvedPfs.push_back (countNormPfs[j]);
|
||||
}
|
||||
} else {
|
||||
delete commCt;
|
||||
}
|
||||
if (exclCt->empty() == false) {
|
||||
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
if (absorvedPfs.empty()) {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
}
|
||||
return true;
|
||||
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::proper (
|
||||
const ProbFormula& f1,
|
||||
ConstraintTree* c1,
|
||||
const ProbFormula& f2,
|
||||
ConstraintTree* c2)
|
||||
{
|
||||
return disjoint (f1, c1, f2, c2)
|
||||
|| identical (f1, c1, f2, c2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::identical (
|
||||
const ProbFormula& f1,
|
||||
ConstraintTree* c1,
|
||||
const ProbFormula& f2,
|
||||
ConstraintTree* c2)
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return false;
|
||||
}
|
||||
c1->moveToTop (f1.logVars());
|
||||
c2->moveToTop (f2.logVars());
|
||||
return ConstraintTree::identical (
|
||||
c1, c2, f1.logVars().size());
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FoveSolver::disjoint (
|
||||
const ProbFormula& f1,
|
||||
ConstraintTree* c1,
|
||||
const ProbFormula& f2,
|
||||
ConstraintTree* c2)
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return true;
|
||||
}
|
||||
c1->moveToTop (f1.logVars());
|
||||
c2->moveToTop (f2.logVars());
|
||||
return ConstraintTree::overlap (
|
||||
c1, c2, f1.arity()) == false;
|
||||
return absorvedPfs;
|
||||
}
|
||||
|
||||
|
@ -9,10 +9,14 @@ class LiftedOperator
|
||||
{
|
||||
public:
|
||||
virtual unsigned getCost (void) = 0;
|
||||
|
||||
virtual void apply (void) = 0;
|
||||
|
||||
virtual string toString (void) = 0;
|
||||
|
||||
static vector<LiftedOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
};
|
||||
|
||||
@ -23,18 +27,26 @@ class SumOutOperator : public LiftedOperator
|
||||
public:
|
||||
SumOutOperator (unsigned group, ParfactorList& pfList)
|
||||
: group_(group), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<SumOutOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (unsigned, ParfactorList&, const Grounds&);
|
||||
|
||||
static vector<ParfactorList::iterator> parfactorsWithGroup (
|
||||
ParfactorList& pfList, unsigned group);
|
||||
|
||||
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
|
||||
unsigned group_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
unsigned group_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
@ -47,15 +59,21 @@ class CountingOperator : public LiftedOperator
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, LogVar);
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
@ -68,14 +86,19 @@ class GroundOperator : public LiftedOperator
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
unsigned getCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
@ -83,49 +106,29 @@ class GroundOperator : public LiftedOperator
|
||||
class FoveSolver
|
||||
{
|
||||
public:
|
||||
FoveSolver (const ParfactorList*);
|
||||
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||
|
||||
Params getPosterioriOf (const Ground&);
|
||||
Params getJointDistributionOf (const Grounds&);
|
||||
Params getPosterioriOf (const Ground&);
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
const ObservedFormulas& obsFormulas);
|
||||
Params getJointDistributionOf (const Grounds&);
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
|
||||
private:
|
||||
void runSolver (const Grounds&);
|
||||
bool allEliminated (const Grounds&);
|
||||
LiftedOperator* getBestOperation (const Grounds&);
|
||||
void shatterAgainstQuery (const Grounds&);
|
||||
void runSolver (const Grounds&);
|
||||
|
||||
static bool absorved (
|
||||
ParfactorList& pfList,
|
||||
ParfactorList::iterator pfIter,
|
||||
const ObservedFormula*);
|
||||
LiftedOperator* getBestOperation (const Grounds&);
|
||||
|
||||
public:
|
||||
void runWeakBayesBall (const Grounds&);
|
||||
|
||||
static bool proper (
|
||||
const ProbFormula&,
|
||||
ConstraintTree*,
|
||||
const ProbFormula&,
|
||||
ConstraintTree*);
|
||||
void shatterAgainstQuery (const Grounds&);
|
||||
|
||||
static bool identical (
|
||||
const ProbFormula&,
|
||||
ConstraintTree*,
|
||||
const ProbFormula&,
|
||||
ConstraintTree*);
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
|
||||
static bool disjoint (
|
||||
const ProbFormula&,
|
||||
ConstraintTree*,
|
||||
const ProbFormula&,
|
||||
ConstraintTree*);
|
||||
|
||||
ParfactorList pfList_;
|
||||
ParfactorList pfList_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FOVESOLVER_H
|
||||
|
@ -1,66 +1,63 @@
|
||||
#ifndef HORUS_GRAPHICALMODEL_H
|
||||
#define HORUS_GRAPHICALMODEL_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "Distribution.h"
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
struct VariableInfo
|
||||
struct VarInfo
|
||||
{
|
||||
VariableInfo (string l, const States& sts)
|
||||
{
|
||||
label = l;
|
||||
states = sts;
|
||||
}
|
||||
string label;
|
||||
States states;
|
||||
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
||||
string label;
|
||||
States states;
|
||||
};
|
||||
|
||||
|
||||
class GraphicalModel
|
||||
{
|
||||
public:
|
||||
virtual ~GraphicalModel (void) {};
|
||||
virtual VarNode* getVariableNode (VarId) const = 0;
|
||||
virtual VarNodes getVariableNodes (void) const = 0;
|
||||
virtual void printGraphicalModel (void) const = 0;
|
||||
virtual ~GraphicalModel (void) { };
|
||||
|
||||
virtual VarNode* getVariableNode (VarId) const = 0;
|
||||
|
||||
virtual VarNodes getVariableNodes (void) const = 0;
|
||||
|
||||
virtual void printGraphicalModel (void) const = 0;
|
||||
|
||||
static void addVariableInformation (VarId vid, string label,
|
||||
const States& states)
|
||||
static void addVariableInformation (
|
||||
VarId vid, string label, const States& states)
|
||||
{
|
||||
assert (varsInfo_.find (vid) == varsInfo_.end());
|
||||
varsInfo_.insert (make_pair (vid, VariableInfo (label, states)));
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
static VariableInfo getVariableInformation (VarId vid)
|
||||
|
||||
static VarInfo getVarInformation (VarId vid)
|
||||
{
|
||||
assert (varsInfo_.find (vid) != varsInfo_.end());
|
||||
assert (Util::contains (varsInfo_, vid));
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
|
||||
static bool variablesHaveInformation (void)
|
||||
{
|
||||
return varsInfo_.size() != 0;
|
||||
}
|
||||
|
||||
static void clearVariablesInformation (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
static void addDistribution (unsigned id, Distribution* dist)
|
||||
{
|
||||
distsInfo_[id] = dist;
|
||||
}
|
||||
static void updateDistribution (unsigned id, const Params& params)
|
||||
{
|
||||
distsInfo_[id]->updateParameters (params);
|
||||
}
|
||||
|
||||
private:
|
||||
static unordered_map<VarId,VariableInfo> varsInfo_;
|
||||
static unordered_map<unsigned,Distribution*> distsInfo_;
|
||||
static unordered_map<VarId,VarInfo> varsInfo_;
|
||||
};
|
||||
|
||||
#endif // HORUS_GRAPHICALMODEL_H
|
||||
|
@ -84,16 +84,34 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
|
||||
|
||||
unsigned
|
||||
HistogramSet::findIndex (
|
||||
const Histogram& hist,
|
||||
const vector<Histogram>& histograms)
|
||||
const Histogram& h,
|
||||
const vector<Histogram>& hists)
|
||||
{
|
||||
vector<Histogram>::const_iterator it = std::lower_bound (
|
||||
histograms.begin(),
|
||||
histograms.end(),
|
||||
hist,
|
||||
std::greater<Histogram>());
|
||||
assert (it != histograms.end() && *it == hist);
|
||||
return std::distance (histograms.begin(), it);
|
||||
hists.begin(), hists.end(), h, std::greater<Histogram>());
|
||||
assert (it != hists.end() && *it == h);
|
||||
return std::distance (hists.begin(), it);
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<double>
|
||||
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||
{
|
||||
HistogramSet hs (N, R);
|
||||
unsigned N_factorial = Util::factorial (N);
|
||||
unsigned H = hs.nrHistograms();
|
||||
vector<double> numAssigns;
|
||||
numAssigns.reserve (H);
|
||||
for (unsigned h = 0; h < H; h++) {
|
||||
unsigned prod = 1;
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
prod *= Util::factorial (hs[r]);
|
||||
}
|
||||
numAssigns.push_back (LogAware::tl (N_factorial / prod));
|
||||
hs.nextHistogram();
|
||||
}
|
||||
return numAssigns;
|
||||
}
|
||||
|
||||
|
||||
|
@ -26,8 +26,9 @@ class HistogramSet
|
||||
static unsigned nrHistograms (unsigned, unsigned);
|
||||
|
||||
static unsigned findIndex (
|
||||
const Histogram&,
|
||||
const vector<Histogram>&);
|
||||
const Histogram&, const vector<Histogram>&);
|
||||
|
||||
static vector<double> getNumAssigns (unsigned, unsigned);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
||||
|
||||
|
@ -1,17 +1,9 @@
|
||||
#ifndef HORUS_HORUS_H
|
||||
#define HORUS_HORUS_H
|
||||
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||
TypeName(const TypeName&); \
|
||||
@ -25,49 +17,48 @@ class FgVarNode;
|
||||
class FgFacNode;
|
||||
class Factor;
|
||||
|
||||
typedef vector<double> Params;
|
||||
typedef unsigned VarId;
|
||||
typedef vector<VarId> VarIds;
|
||||
typedef vector<VarNode*> VarNodes;
|
||||
typedef vector<BayesNode*> BnNodeSet;
|
||||
typedef vector<FgVarNode*> FgVarSet;
|
||||
typedef vector<FgFacNode*> FgFacSet;
|
||||
typedef vector<Factor*> FactorSet;
|
||||
typedef vector<string> States;
|
||||
typedef vector<unsigned> Ranges;
|
||||
typedef vector<double> Params;
|
||||
typedef unsigned VarId;
|
||||
typedef vector<VarId> VarIds;
|
||||
typedef vector<VarNode*> VarNodes;
|
||||
typedef vector<BayesNode*> BnNodeSet;
|
||||
typedef vector<FgVarNode*> FgVarSet;
|
||||
typedef vector<FgFacNode*> FgFacSet;
|
||||
typedef vector<Factor*> FactorSet;
|
||||
typedef vector<string> States;
|
||||
typedef vector<unsigned> Ranges;
|
||||
|
||||
|
||||
namespace Globals {
|
||||
extern bool logDomain;
|
||||
enum InfAlgorithms
|
||||
{
|
||||
VE, // variable elimination
|
||||
BN_BP, // bayesian network belief propagation
|
||||
FG_BP, // factor graph belief propagation
|
||||
CBP // counting bp solver
|
||||
};
|
||||
|
||||
|
||||
// level of debug information
|
||||
static const unsigned DL = 1;
|
||||
namespace Globals {
|
||||
|
||||
static const int NO_EVIDENCE = -1;
|
||||
extern bool logDomain;
|
||||
|
||||
extern InfAlgorithms infAlgorithm;
|
||||
|
||||
};
|
||||
|
||||
|
||||
namespace Constants {
|
||||
|
||||
// level of debug information
|
||||
const unsigned DEBUG = 2;
|
||||
|
||||
const int NO_EVIDENCE = -1;
|
||||
|
||||
// number of digits to show when printing a parameter
|
||||
static const unsigned PRECISION = 5;
|
||||
const unsigned PRECISION = 5;
|
||||
|
||||
static const bool COLLECT_STATISTICS = false;
|
||||
const bool COLLECT_STATS = false;
|
||||
|
||||
static const bool EXPORT_TO_GRAPHVIZ = false;
|
||||
static const unsigned EXPORT_MINIMAL_SIZE = 100;
|
||||
|
||||
static const double INF = -numeric_limits<double>::infinity();
|
||||
|
||||
|
||||
|
||||
namespace InfAlgorithms {
|
||||
enum InfAlgs
|
||||
{
|
||||
VE, // variable elimination
|
||||
BN_BP, // bayesian network belief propagation
|
||||
FG_BP, // factor graph belief propagation
|
||||
CBP // counting bp solver
|
||||
};
|
||||
extern InfAlgs infAlgorithm;
|
||||
};
|
||||
|
||||
|
||||
|
@ -10,10 +10,6 @@
|
||||
#include "FgBpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
|
||||
//#include "TinySet.h"
|
||||
#include "LiftedUtils.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
void processArguments (BayesNet&, int, const char* []);
|
||||
@ -24,38 +20,9 @@ const string USAGE = "usage: \
|
||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
|
||||
|
||||
class Cenas
|
||||
{
|
||||
public:
|
||||
Cenas (int cc)
|
||||
{
|
||||
c = cc;
|
||||
}
|
||||
//operator int (void) const
|
||||
//{
|
||||
// cout << "return int" << endl;
|
||||
// return c;
|
||||
//}
|
||||
operator double (void) const
|
||||
{
|
||||
cout << "return double" << endl;
|
||||
return 0.0;
|
||||
}
|
||||
private:
|
||||
int c;
|
||||
};
|
||||
|
||||
|
||||
int
|
||||
main (int argc, const char* argv[])
|
||||
{
|
||||
LogVar X = 3;
|
||||
LogVarSet Xs = X;
|
||||
cout << "set: " << X << endl;
|
||||
Cenas c1 (1);
|
||||
Cenas c2 (3);
|
||||
cout << (c1 < c2) << endl;
|
||||
return 0;
|
||||
if (!argv[1]) {
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
@ -99,7 +66,6 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
||||
cerr << "error: there isn't a variable labeled of " ;
|
||||
cerr << "`" << arg << "'" ;
|
||||
cerr << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
@ -109,13 +75,11 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
||||
if (label.empty()) {
|
||||
cerr << "error: missing left argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
if (state.empty()) {
|
||||
cerr << "error: missing right argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
BayesNode* node = bn.getBayesNode (label);
|
||||
@ -127,14 +91,12 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
||||
cerr << "is not a valid state for " ;
|
||||
cerr << "`" << node->label() << "'" ;
|
||||
cerr << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
cerr << "error: there isn't a variable labeled of " ;
|
||||
cerr << "`" << label << "'" ;
|
||||
cerr << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
}
|
||||
@ -142,7 +104,7 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
||||
|
||||
Solver* solver = 0;
|
||||
FactorGraph* fg = 0;
|
||||
switch (InfAlgorithms::infAlgorithm) {
|
||||
switch (Globals::infAlgorithm) {
|
||||
case InfAlgorithms::VE:
|
||||
fg = new FactorGraph (bn);
|
||||
solver = new VarElimSolver (*fg);
|
||||
@ -163,7 +125,6 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
||||
}
|
||||
runSolver (solver, queryVars);
|
||||
delete fg;
|
||||
bn.freeDistributions();
|
||||
}
|
||||
|
||||
|
||||
@ -179,7 +140,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
cerr << "error: `" << arg << "' " ;
|
||||
cerr << "is not a valid variable id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
VarId vid;
|
||||
@ -193,7 +153,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
@ -201,20 +160,17 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
if (arg.substr (0, pos).empty()) {
|
||||
cerr << "error: missing left argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
if (arg.substr (pos + 1).empty()) {
|
||||
cerr << "error: missing right argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
if (!Util::isInteger (arg.substr (0, pos))) {
|
||||
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
||||
cerr << "is not a variable id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
VarId vid;
|
||||
@ -227,7 +183,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||
cerr << "is not a state index" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
int stateIndex;
|
||||
@ -241,28 +196,23 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
cerr << "is not a valid state index for variable " ;
|
||||
cerr << "`" << var->varId() << "'" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
}
|
||||
}
|
||||
Solver* solver = 0;
|
||||
switch (InfAlgorithms::infAlgorithm) {
|
||||
switch (Globals::infAlgorithm) {
|
||||
case InfAlgorithms::VE:
|
||||
solver = new VarElimSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::BN_BP:
|
||||
case InfAlgorithms::FG_BP:
|
||||
//cout << "here!" << endl;
|
||||
//fg.printGraphicalModel();
|
||||
//fg.exportToLibDaiFormat ("net.fg");
|
||||
solver = new FgBpSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::CBP:
|
||||
@ -272,7 +222,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
assert (false);
|
||||
}
|
||||
runSolver (solver, queryVars);
|
||||
fg.freeDistributions();
|
||||
}
|
||||
|
||||
|
||||
|
@ -7,22 +7,26 @@
|
||||
|
||||
#include <YapInterface.h>
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "BayesNet.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "FoveSolver.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "BnBpSolver.h"
|
||||
#include "FgBpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "FoveSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
|
||||
|
||||
Params readParams (YAP_Term);
|
||||
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
|
||||
|
||||
int createLiftedNetwork (void)
|
||||
@ -30,107 +34,124 @@ int createLiftedNetwork (void)
|
||||
Parfactors parfactors;
|
||||
YAP_Term parfactorList = YAP_ARG1;
|
||||
while (parfactorList != YAP_TermNil()) {
|
||||
YAP_Term parfactor = YAP_HeadOfTerm (parfactorList);
|
||||
|
||||
// read dist id
|
||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, parfactor));
|
||||
|
||||
// read the ranges
|
||||
Ranges ranges;
|
||||
YAP_Term rangeList = YAP_ArgOfTerm (3, parfactor);
|
||||
while (rangeList != YAP_TermNil()) {
|
||||
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||
ranges.push_back (range);
|
||||
rangeList = YAP_TailOfTerm (rangeList);
|
||||
}
|
||||
|
||||
// read parametric random vars
|
||||
ProbFormulas formulas;
|
||||
unsigned count = 0;
|
||||
unordered_map<YAP_Term, LogVar> lvMap;
|
||||
YAP_Term pvList = YAP_ArgOfTerm (2, parfactor);
|
||||
while (pvList != YAP_TermNil()) {
|
||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||
} else {
|
||||
LogVars logVars;
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
if (it != lvMap.end()) {
|
||||
logVars.push_back (it->second);
|
||||
} else {
|
||||
unsigned newLv = lvMap.size();
|
||||
lvMap[ti] = newLv;
|
||||
logVars.push_back (newLv);
|
||||
}
|
||||
}
|
||||
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
||||
}
|
||||
count ++;
|
||||
pvList = YAP_TailOfTerm (pvList);
|
||||
}
|
||||
|
||||
// read the parameters
|
||||
const Params& params = readParams (YAP_ArgOfTerm (4, parfactor));
|
||||
|
||||
// read the constraint
|
||||
Tuples tuples;
|
||||
if (lvMap.size() >= 1) {
|
||||
YAP_Term tupleList = YAP_ArgOfTerm (5, parfactor);
|
||||
while (tupleList != YAP_TermNil()) {
|
||||
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||
assert (YAP_IsApplTerm (term));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
assert (lvMap.size() == arity);
|
||||
Tuple tuple (arity);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||
if (YAP_IsAtomTerm (ti) == false) {
|
||||
cerr << "error: bad formed constraint" << endl;
|
||||
abort();
|
||||
}
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||
}
|
||||
tuples.push_back (tuple);
|
||||
tupleList = YAP_TailOfTerm (tupleList);
|
||||
}
|
||||
}
|
||||
parfactors.push_back (new Parfactor (formulas, params, tuples, distId));
|
||||
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
|
||||
parfactors.push_back (readParfactor (pfTerm));
|
||||
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||
}
|
||||
|
||||
// LiftedUtils::printSymbolDictionary();
|
||||
cout << "*******************************************************" << endl;
|
||||
cout << "INITIAL PARFACTORS" << endl;
|
||||
cout << "*******************************************************" << endl;
|
||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||
parfactors[i]->print();
|
||||
cout << endl;
|
||||
if (Constants::DEBUG > 1) {
|
||||
// Util::printHeader ("INITIAL PARFACTORS");
|
||||
// for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||
// parfactors[i]->print();
|
||||
// cout << endl;
|
||||
// }
|
||||
// parfactors[0]->countConvert (LogVar (0));
|
||||
//parfactors[1]->fullExpand (LogVar (1));
|
||||
Util::printHeader ("SHATTERED PARFACTORS");
|
||||
}
|
||||
ParfactorList* pfList = new ParfactorList();
|
||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||
pfList->add (parfactors[i]);
|
||||
}
|
||||
cout << endl;
|
||||
cout << "*******************************************************" << endl;
|
||||
cout << "SHATTERED PARFACTORS" << endl;
|
||||
cout << "*******************************************************" << endl;
|
||||
pfList->shatter();
|
||||
pfList->print();
|
||||
|
||||
// insert the evidence
|
||||
ObservedFormulas obsFormulas;
|
||||
YAP_Term observedList = YAP_ARG2;
|
||||
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||
|
||||
if (Constants::DEBUG > 1) {
|
||||
pfList->print();
|
||||
}
|
||||
|
||||
// read evidence
|
||||
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
|
||||
|
||||
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
|
||||
YAP_Int p = (YAP_Int) (net);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor* readParfactor (YAP_Term pfTerm)
|
||||
{
|
||||
// read dist id
|
||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||
|
||||
// read the ranges
|
||||
Ranges ranges;
|
||||
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
|
||||
while (rangeList != YAP_TermNil()) {
|
||||
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||
ranges.push_back (range);
|
||||
rangeList = YAP_TailOfTerm (rangeList);
|
||||
}
|
||||
|
||||
// read parametric random vars
|
||||
ProbFormulas formulas;
|
||||
unsigned count = 0;
|
||||
unordered_map<YAP_Term, LogVar> lvMap;
|
||||
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||
while (pvList != YAP_TermNil()) {
|
||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||
} else {
|
||||
LogVars logVars;
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
if (it != lvMap.end()) {
|
||||
logVars.push_back (it->second);
|
||||
} else {
|
||||
unsigned newLv = lvMap.size();
|
||||
lvMap[ti] = newLv;
|
||||
logVars.push_back (newLv);
|
||||
}
|
||||
}
|
||||
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
||||
}
|
||||
count ++;
|
||||
pvList = YAP_TailOfTerm (pvList);
|
||||
}
|
||||
|
||||
// read the parameters
|
||||
const Params& params = readParams (YAP_ArgOfTerm (4, pfTerm));
|
||||
|
||||
// read the constraint
|
||||
Tuples tuples;
|
||||
if (lvMap.size() >= 1) {
|
||||
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
|
||||
while (tupleList != YAP_TermNil()) {
|
||||
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||
assert (YAP_IsApplTerm (term));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
assert (lvMap.size() == arity);
|
||||
Tuple tuple (arity);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||
if (YAP_IsAtomTerm (ti) == false) {
|
||||
cerr << "error: constraint has free variables" << endl;
|
||||
abort();
|
||||
}
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||
}
|
||||
tuples.push_back (tuple);
|
||||
tupleList = YAP_TailOfTerm (tupleList);
|
||||
}
|
||||
}
|
||||
return new Parfactor (formulas, params, tuples, distId);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void readLiftedEvidence (
|
||||
YAP_Term observedList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
while (observedList != YAP_TermNil()) {
|
||||
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||
@ -155,22 +176,18 @@ int createLiftedNetwork (void)
|
||||
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||
bool found = false;
|
||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||
if (obsFormulas[i]->functor() == functor &&
|
||||
obsFormulas[i]->arity() == args.size() &&
|
||||
obsFormulas[i]->evidence() == evidence) {
|
||||
obsFormulas[i]->addTuple (args);
|
||||
if (obsFormulas[i].functor() == functor &&
|
||||
obsFormulas[i].arity() == args.size() &&
|
||||
obsFormulas[i].evidence() == evidence) {
|
||||
obsFormulas[i].addTuple (args);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
obsFormulas.push_back (new ObservedFormula (functor, evidence, args));
|
||||
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||
}
|
||||
observedList = YAP_TailOfTerm (observedList);
|
||||
}
|
||||
FoveSolver::absorveEvidence (*pfList, obsFormulas);
|
||||
|
||||
YAP_Int p = (YAP_Int) (pfList);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -186,7 +203,6 @@ createGroundNetwork (void)
|
||||
// }
|
||||
BayesNet* bn = new BayesNet();
|
||||
YAP_Term varList = YAP_ARG1;
|
||||
BnNodeSet nodes;
|
||||
vector<VarIds> parents;
|
||||
while (varList != YAP_TermNil()) {
|
||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
||||
@ -201,15 +217,13 @@ createGroundNetwork (void)
|
||||
parents.back().push_back (parentId);
|
||||
parentL = YAP_TailOfTerm (parentL);
|
||||
}
|
||||
Distribution* dist = bn->getDistribution (distId);
|
||||
if (!dist) {
|
||||
dist = new Distribution (distId);
|
||||
bn->addDistribution (dist);
|
||||
}
|
||||
assert (bn->getBayesNode (vid) == 0);
|
||||
nodes.push_back (bn->addNode (vid, dsize, evidence, dist));
|
||||
BayesNode* newNode = new BayesNode (
|
||||
vid, dsize, evidence, Params(), distId);
|
||||
bn->addNode (newNode);
|
||||
varList = YAP_TailOfTerm (varList);
|
||||
}
|
||||
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
BnNodeSet ps;
|
||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
||||
@ -225,41 +239,6 @@ createGroundNetwork (void)
|
||||
|
||||
|
||||
|
||||
int
|
||||
setBayesNetParams (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
const Params params = readParams (YAP_ArgOfTerm (2, dist));
|
||||
bn->getDistribution(distId)->updateParameters (params);
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorGraphParams (void)
|
||||
{
|
||||
// FIXME
|
||||
// ParfactorGraph* pfg = (ParfactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
while (distList != YAP_TermNil()) {
|
||||
// YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
// unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
// const Params params = readParams (YAP_ArgOfTerm (2, dist));
|
||||
// pfg->getDistribution(distId)->setData (params);
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
readParams (YAP_Term paramL)
|
||||
{
|
||||
@ -279,15 +258,14 @@ readParams (YAP_Term paramL)
|
||||
int
|
||||
runLiftedSolver (void)
|
||||
{
|
||||
ParfactorList* pfList = (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<Params> results;
|
||||
|
||||
ParfactorList pfListCopy (*network->first);
|
||||
FoveSolver::absorveEvidence (pfListCopy, *network->second);
|
||||
while (taskList != YAP_TermNil()) {
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
Grounds queryVars;
|
||||
assert (YAP_IsPairTerm (taskList));
|
||||
assert (YAP_IsPairTerm (jointList));
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
@ -310,11 +288,11 @@ runLiftedSolver (void)
|
||||
}
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
FoveSolver solver (pfList);
|
||||
FoveSolver solver (pfListCopy);
|
||||
if (queryVars.size() == 1) {
|
||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
||||
} else {
|
||||
assert (false); // TODO joint dist
|
||||
results.push_back (solver.getJointDistributionOf (queryVars));
|
||||
}
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
@ -339,46 +317,40 @@ runLiftedSolver (void)
|
||||
|
||||
|
||||
int
|
||||
runOtherSolvers (void)
|
||||
runGroundSolver (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<VarIds> tasks;
|
||||
std::set<VarId> vids;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
||||
tasks.push_back (VarIds());
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||
assert (bn->getBayesNode (vid));
|
||||
tasks.back().push_back (vid);
|
||||
vids.insert (vid);
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
} else {
|
||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
||||
VarIds queryVars;
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||
assert (bn->getBayesNode (vid));
|
||||
tasks.push_back (VarIds() = {vid});
|
||||
queryVars.push_back (vid);
|
||||
vids.insert (vid);
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
tasks.push_back (queryVars);
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
Solver* bpSolver = 0;
|
||||
GraphicalModel* graphicalModel = 0;
|
||||
CFactorGraph::checkForIdenticalFactors = false;
|
||||
if (InfAlgorithms::infAlgorithm != InfAlgorithms::VE) {
|
||||
if (Globals::infAlgorithm != InfAlgorithms::VE) {
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
|
||||
VarIds (vids.begin(), vids.end()));
|
||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::BN_BP) {
|
||||
if (Globals::infAlgorithm == InfAlgorithms::BN_BP) {
|
||||
graphicalModel = mrn;
|
||||
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
|
||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::FG_BP) {
|
||||
} else if (Globals::infAlgorithm == InfAlgorithms::FG_BP) {
|
||||
graphicalModel = new FactorGraph (*mrn);
|
||||
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||
delete mrn;
|
||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::CBP) {
|
||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||
graphicalModel = new FactorGraph (*mrn);
|
||||
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||
delete mrn;
|
||||
@ -389,8 +361,7 @@ runOtherSolvers (void)
|
||||
vector<Params> results;
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
//if (i == 1) exit (0);
|
||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::VE) {
|
||||
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
|
||||
VarElimSolver* veSolver = new VarElimSolver (*mrn);
|
||||
if (tasks[i].size() == 1) {
|
||||
@ -430,10 +401,57 @@ runOtherSolvers (void)
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorsParams (void)
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
ParfactorList* pfList = network->first;
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
assert (Util::contains (paramsMap, distId) == false);
|
||||
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
ParfactorList::iterator it = pfList->begin();
|
||||
while (it != pfList->end()) {
|
||||
assert (Util::contains (paramsMap, (*it)->distId()));
|
||||
// (*it)->setParams (paramsMap[(*it)->distId()]);
|
||||
++ it;
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setBayesNetParams (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
assert (Util::contains (paramsMap, distId) == false);
|
||||
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
assert (Util::contains (paramsMap, nodes[i]->distId()));
|
||||
nodes[i]->setParams (paramsMap[nodes[i]->distId()]);
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setExtraVarsInfo (void)
|
||||
{
|
||||
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
GraphicalModel::clearVariablesInformation();
|
||||
YAP_Term varsInfoL = YAP_ARG2;
|
||||
while (varsInfoL != YAP_TermNil()) {
|
||||
@ -463,13 +481,13 @@ setHorusFlag (void)
|
||||
if (key == "inf_alg") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "ve") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::VE;
|
||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (value == "bn_bp") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::BN_BP;
|
||||
Globals::infAlgorithm = InfAlgorithms::BN_BP;
|
||||
} else if (value == "fg_bp") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::FG_BP;
|
||||
Globals::infAlgorithm = InfAlgorithms::FG_BP;
|
||||
} else if (value == "cbp") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::CBP;
|
||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
@ -543,19 +561,19 @@ setHorusFlag (void)
|
||||
int
|
||||
freeBayesNetwork (void)
|
||||
{
|
||||
//Statistics::writeStatisticsToFile ("stats.txt");
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
bn->freeDistributions();
|
||||
delete bn;
|
||||
delete (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
freeParfactorGraph (void)
|
||||
freeParfactors (void)
|
||||
{
|
||||
delete (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
delete network->first;
|
||||
delete network->second;
|
||||
delete network;
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
@ -564,15 +582,15 @@ freeParfactorGraph (void)
|
||||
extern "C" void
|
||||
init_predicates (void)
|
||||
{
|
||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
|
||||
YAP_UserCPredicate ("set_parfactor_graph_params", setParfactorGraphParams, 2);
|
||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("run_other_solvers", runOtherSolvers, 3);
|
||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||
YAP_UserCPredicate ("free_parfactor_graph", freeParfactorGraph, 1);
|
||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
|
||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,9 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
class StatesIndexer {
|
||||
|
||||
class StatesIndexer
|
||||
{
|
||||
public:
|
||||
|
||||
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
|
||||
@ -134,11 +136,11 @@ class StatesIndexer {
|
||||
return size_ ;
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &out, const StatesIndexer& idx)
|
||||
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
|
||||
{
|
||||
out << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
||||
out << idx.indices_;
|
||||
return out;
|
||||
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
||||
os << idx.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
@ -274,21 +276,14 @@ class MapIndexer
|
||||
index_ = 0;
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &out, const MapIndexer& idx)
|
||||
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
|
||||
{
|
||||
out << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
||||
out << idx.indices_;
|
||||
return out;
|
||||
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
||||
os << idx.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
MapIndexer (const Ranges& ranges) :
|
||||
ranges_(ranges),
|
||||
indices_(ranges.size(), 0),
|
||||
offsets_(ranges.size())
|
||||
{
|
||||
index_ = 0;
|
||||
}
|
||||
unsigned index_;
|
||||
bool valid_;
|
||||
vector<unsigned> ranges_;
|
||||
|
@ -95,26 +95,37 @@ ostream& operator<< (ostream &os, const Ground& gr)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ObservedFormula::addTuple (const Tuple& t)
|
||||
LogVars
|
||||
Substitution::getDiscardedLogVars (void) const
|
||||
{
|
||||
if (constr_ == 0) {
|
||||
LogVars lvs (arity_);
|
||||
for (unsigned i = 0; i < arity_; i++) {
|
||||
lvs[i] = i;
|
||||
LogVars discardedLvs;
|
||||
set<LogVar> doneLvs;
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.begin();
|
||||
while (it != subs_.end()) {
|
||||
if (Util::contains (doneLvs, it->second)) {
|
||||
discardedLvs.push_back (it->first);
|
||||
} else {
|
||||
doneLvs.insert (it->second);
|
||||
}
|
||||
constr_ = new ConstraintTree (lvs);
|
||||
it ++;
|
||||
}
|
||||
constr_->addTuple (t);
|
||||
return discardedLvs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const ObservedFormula of)
|
||||
ostream& operator<< (ostream &os, const Substitution& theta)
|
||||
{
|
||||
os << of.functor_ << "/" << of.arity_;
|
||||
os << "|" << of.constr_->tupleSet();
|
||||
os << " [evidence=" << of.evidence_ << "]";
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
os << "[" ;
|
||||
it = theta.subs_.begin();
|
||||
while (it != theta.subs_.end()) {
|
||||
if (it != theta.subs_.begin()) os << ", " ;
|
||||
os << it->first << "->" << it->second ;
|
||||
++ it;
|
||||
}
|
||||
os << "]" ;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
@ -18,11 +18,17 @@ class Symbol
|
||||
{
|
||||
public:
|
||||
Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||
|
||||
Symbol (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
|
||||
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
|
||||
|
||||
static Symbol invalid (void) { return Symbol(); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Symbol& s);
|
||||
|
||||
private:
|
||||
unsigned id_;
|
||||
};
|
||||
@ -32,7 +38,9 @@ class LogVar
|
||||
{
|
||||
public:
|
||||
LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||
|
||||
LogVar (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
|
||||
LogVar& operator++ (void)
|
||||
@ -48,6 +56,7 @@ class LogVar
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &os, const LogVar& X);
|
||||
|
||||
private:
|
||||
unsigned id_;
|
||||
};
|
||||
@ -79,8 +88,8 @@ ostream& operator<< (ostream &os, const Tuple& t);
|
||||
|
||||
|
||||
namespace LiftedUtils {
|
||||
Symbol getSymbol (const string&);
|
||||
void printSymbolDictionary (void);
|
||||
Symbol getSymbol (const string&);
|
||||
void printSymbolDictionary (void);
|
||||
}
|
||||
|
||||
|
||||
@ -89,71 +98,56 @@ class Ground
|
||||
{
|
||||
public:
|
||||
Ground (Symbol f) : functor_(f) { }
|
||||
|
||||
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
Symbols args (void) const { return args_; }
|
||||
unsigned arity (void) const { return args_.size(); }
|
||||
bool isAtom (void) const { return args_.size() == 0; }
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
Symbols args (void) const { return args_; }
|
||||
|
||||
unsigned arity (void) const { return args_.size(); }
|
||||
|
||||
bool isAtom (void) const { return args_.size() == 0; }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Ground& gr);
|
||||
|
||||
private:
|
||||
Symbol functor_;
|
||||
Symbols args_;
|
||||
Symbol functor_;
|
||||
Symbols args_;
|
||||
};
|
||||
|
||||
typedef vector<Ground> Grounds;
|
||||
|
||||
|
||||
|
||||
class ConstraintTree;
|
||||
class ObservedFormula
|
||||
{
|
||||
public:
|
||||
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||
: functor_(f), arity_(a), evidence_(ev), constr_(0) { }
|
||||
|
||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(0)
|
||||
{
|
||||
addTuple (tuple);
|
||||
}
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
unsigned arity (void) const { return arity_; }
|
||||
unsigned evidence (void) const { return evidence_; }
|
||||
ConstraintTree* constr (void) const { return constr_; }
|
||||
bool isAtom (void) const { return arity_ == 0; }
|
||||
|
||||
void addTuple (const Tuple& t);
|
||||
friend ostream& operator<< (ostream &os, const ObservedFormula opv);
|
||||
private:
|
||||
Symbol functor_;
|
||||
unsigned arity_;
|
||||
unsigned evidence_;
|
||||
ConstraintTree* constr_;
|
||||
};
|
||||
typedef vector<ObservedFormula*> ObservedFormulas;
|
||||
|
||||
|
||||
|
||||
class Substitution
|
||||
{
|
||||
public:
|
||||
void add (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old) == false);
|
||||
subs_.insert (make_pair (X_old, X_new));
|
||||
}
|
||||
|
||||
void rename (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (subs_.find (X_old) != subs_.end());
|
||||
assert (Util::contains (subs_, X_old));
|
||||
subs_.find (X_old)->second = X_new;
|
||||
}
|
||||
|
||||
LogVar newNameFor (LogVar X) const
|
||||
{
|
||||
assert (subs_.find (X) != subs_.end());
|
||||
assert (Util::contains (subs_, X));
|
||||
return subs_.find (X)->second;
|
||||
}
|
||||
|
||||
LogVars getDiscardedLogVars (void) const;
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||
|
||||
private:
|
||||
unordered_map<LogVar, LogVar> subs_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
@ -60,7 +60,6 @@ HEADERS = \
|
||||
$(srcdir)/CbpSolver.h \
|
||||
$(srcdir)/FoveSolver.h \
|
||||
$(srcdir)/VarNode.h \
|
||||
$(srcdir)/Distribution.h \
|
||||
$(srcdir)/Indexer.h \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include "Parfactor.h"
|
||||
#include "Histogram.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
@ -11,55 +12,58 @@ Parfactor::Parfactor (
|
||||
const Tuples& tuples,
|
||||
unsigned distId)
|
||||
{
|
||||
formulas_ = formulas;
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
args_ = formulas;
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
|
||||
LogVars logVars;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
ranges_.push_back (formulas_[i].range());
|
||||
const LogVars& lvs = formulas_[i].logVars();
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
ranges_.push_back (args_[i].range());
|
||||
const LogVars& lvs = args_[i].logVars();
|
||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||
if (std::find (logVars.begin(), logVars.end(), lvs[j]) ==
|
||||
logVars.end()) {
|
||||
if (Util::contains (logVars, lvs[j]) == false) {
|
||||
logVars.push_back (lvs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
constr_ = new ConstraintTree (logVars, tuples);
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
||||
{
|
||||
formulas_ = g->formulas();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||
args_ = g->arguments();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
||||
{
|
||||
formulas_ = g->formulas();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = constr;
|
||||
args_ = g->arguments();
|
||||
params_ = g->params();
|
||||
ranges_ = g->ranges();
|
||||
distId_ = g->distId();
|
||||
constr_ = constr;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor::Parfactor (const Parfactor& g)
|
||||
{
|
||||
formulas_ = g.formulas();
|
||||
params_ = g.params();
|
||||
ranges_ = g.ranges();
|
||||
distId_ = g.distId();
|
||||
constr_ = new ConstraintTree (*g.constr());
|
||||
args_ = g.arguments();
|
||||
params_ = g.params();
|
||||
ranges_ = g.ranges();
|
||||
distId_ = g.distId();
|
||||
constr_ = new ConstraintTree (*g.constr());
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
@ -75,9 +79,9 @@ LogVarSet
|
||||
Parfactor::countedLogVars (void) const
|
||||
{
|
||||
LogVarSet set;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].isCounting()) {
|
||||
set.insert (formulas_[i].countedLogVar());
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].isCounting()) {
|
||||
set.insert (args_[i].countedLogVar());
|
||||
}
|
||||
}
|
||||
return set;
|
||||
@ -107,14 +111,14 @@ Parfactor::elimLogVars (void) const
|
||||
LogVarSet
|
||||
Parfactor::exclusiveLogVars (unsigned fIdx) const
|
||||
{
|
||||
assert (fIdx < formulas_.size());
|
||||
assert (fIdx < args_.size());
|
||||
LogVarSet remaining;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != fIdx) {
|
||||
remaining |= formulas_[i].logVarSet();
|
||||
remaining |= args_[i].logVarSet();
|
||||
}
|
||||
}
|
||||
return formulas_[fIdx].logVarSet() - remaining;
|
||||
return args_[fIdx].logVarSet() - remaining;
|
||||
}
|
||||
|
||||
|
||||
@ -131,44 +135,51 @@ Parfactor::setConstraintTree (ConstraintTree* newTree)
|
||||
void
|
||||
Parfactor::sumOut (unsigned fIdx)
|
||||
{
|
||||
assert (fIdx < formulas_.size());
|
||||
assert (formulas_[fIdx].contains (elimLogVars()));
|
||||
assert (fIdx < args_.size());
|
||||
assert (args_[fIdx].contains (elimLogVars()));
|
||||
|
||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||
unsigned condCount = constr_->getConditionalCount (excl);
|
||||
Util::pow (params_, condCount);
|
||||
if (args_[fIdx].isCounting()) {
|
||||
LogAware::pow (params_, constr_->getConditionalCount (
|
||||
excl - args_[fIdx].countedLogVar()));
|
||||
} else {
|
||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||
}
|
||||
|
||||
vector<unsigned> numAssigns (ranges_[fIdx], 1);
|
||||
if (formulas_[fIdx].isCounting()) {
|
||||
if (args_[fIdx].isCounting()) {
|
||||
unsigned N = constr_->getConditionalCount (
|
||||
formulas_[fIdx].countedLogVar());
|
||||
unsigned R = formulas_[fIdx].range();
|
||||
unsigned H = ranges_[fIdx];
|
||||
HistogramSet hs (N, R);
|
||||
unsigned N_factorial = Util::factorial (N);
|
||||
for (unsigned h = 0; h < H; h++) {
|
||||
unsigned prod = 1;
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
prod *= Util::factorial (hs[r]);
|
||||
args_[fIdx].countedLogVar());
|
||||
unsigned R = args_[fIdx].range();
|
||||
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||
StatesIndexer sindexer (ranges_, fIdx);
|
||||
while (sindexer.valid()) {
|
||||
unsigned h = sindexer[fIdx];
|
||||
if (Globals::logDomain) {
|
||||
params_[sindexer] += numAssigns[h];
|
||||
} else {
|
||||
params_[sindexer] *= numAssigns[h];
|
||||
}
|
||||
numAssigns[h] = N_factorial / prod;
|
||||
hs.nextHistogram();
|
||||
++ sindexer;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.resize (copy.size() / ranges_[fIdx], 0.0);
|
||||
|
||||
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
|
||||
MapIndexer indexer (ranges_, fIdx);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
unsigned h = indexer[fIdx];
|
||||
// TODO NOT LOG DOMAIN AWARE :(
|
||||
params_[indexer] += numAssigns[h] * copy[i];
|
||||
++ indexer;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
|
||||
++ indexer;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
params_[indexer] += copy[i];
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
formulas_.erase (formulas_.begin() + fIdx);
|
||||
|
||||
args_.erase (args_.begin() + fIdx);
|
||||
ranges_.erase (ranges_.begin() + fIdx);
|
||||
constr_->remove (excl);
|
||||
}
|
||||
@ -179,55 +190,7 @@ void
|
||||
Parfactor::multiply (Parfactor& g)
|
||||
{
|
||||
alignAndExponentiate (this, &g);
|
||||
bool sharedVars = false;
|
||||
vector<unsigned> g_varpos;
|
||||
const ProbFormulas& g_formulas = g.formulas();
|
||||
const Params& g_params = g.params();
|
||||
const Ranges& g_ranges = g.ranges();
|
||||
|
||||
for (unsigned i = 0; i < g_formulas.size(); i++) {
|
||||
int group = g_formulas[i].group();
|
||||
if (indexOfFormulaWithGroup (group) == -1) {
|
||||
insertDimension (g.ranges()[i]);
|
||||
formulas_.push_back (g_formulas[i]);
|
||||
g_varpos.push_back (formulas_.size() - 1);
|
||||
} else {
|
||||
sharedVars = true;
|
||||
g_varpos.push_back (indexOfFormulaWithGroup (group));
|
||||
}
|
||||
}
|
||||
|
||||
if (sharedVars == false) {
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
if (Globals::logDomain) {
|
||||
params_[i] += g_params[count];
|
||||
} else {
|
||||
params_[i] *= g_params[count];
|
||||
}
|
||||
count ++;
|
||||
if (count >= g_params.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StatesIndexer indexer (ranges_, false);
|
||||
while (indexer.valid()) {
|
||||
unsigned g_li = 0;
|
||||
unsigned prod = 1;
|
||||
for (int j = g_varpos.size() - 1; j >= 0; j--) {
|
||||
g_li += indexer[g_varpos[j]] * prod;
|
||||
prod *= g_ranges[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
params_[indexer] += g_params[g_li];
|
||||
} else {
|
||||
params_[indexer] *= g_params[g_li];
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
}
|
||||
|
||||
TFactor<ProbFormula>::multiply (g);
|
||||
constr_->join (g.constr(), true);
|
||||
}
|
||||
|
||||
@ -236,7 +199,7 @@ Parfactor::multiply (Parfactor& g)
|
||||
void
|
||||
Parfactor::countConvert (LogVar X)
|
||||
{
|
||||
int fIdx = indexOfFormulaWithLogVar (X);
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (constr_->isCountNormalized (X));
|
||||
assert (constr_->getConditionalCount (X) > 1);
|
||||
@ -248,12 +211,12 @@ Parfactor::countConvert (LogVar X)
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
|
||||
StatesIndexer indexer (ranges_);
|
||||
vector<Params> summout (params_.size() / R);
|
||||
vector<Params> sumout (params_.size() / R);
|
||||
unsigned count = 0;
|
||||
while (indexer.valid()) {
|
||||
summout[count].reserve (R);
|
||||
sumout[count].reserve (R);
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
summout[count].push_back (params_[indexer]);
|
||||
sumout[count].push_back (params_[indexer]);
|
||||
indexer.increment (fIdx);
|
||||
}
|
||||
count ++;
|
||||
@ -262,45 +225,42 @@ Parfactor::countConvert (LogVar X)
|
||||
}
|
||||
|
||||
params_.clear();
|
||||
params_.reserve (summout.size() * H);
|
||||
params_.reserve (sumout.size() * H);
|
||||
|
||||
vector<bool> mapDims (ranges_.size(), true);
|
||||
ranges_[fIdx] = H;
|
||||
mapDims[fIdx] = false;
|
||||
MapIndexer mapIndexer (ranges_, mapDims);
|
||||
MapIndexer mapIndexer (ranges_, fIdx);
|
||||
while (mapIndexer.valid()) {
|
||||
double prod = 1.0;
|
||||
double prod = LogAware::multIdenty();
|
||||
unsigned i = mapIndexer.mappedIndex();
|
||||
unsigned h = mapIndexer[fIdx];
|
||||
for (unsigned r = 0; r < R; r++) {
|
||||
// TODO not log domain aware
|
||||
prod *= Util::pow (summout[i][r], histograms[h][r]);
|
||||
if (Globals::logDomain) {
|
||||
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||
} else {
|
||||
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||
}
|
||||
}
|
||||
params_.push_back (prod);
|
||||
++ mapIndexer;
|
||||
}
|
||||
formulas_[fIdx].setCountedLogVar (X);
|
||||
args_[fIdx].setCountedLogVar (X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::expandPotential (
|
||||
LogVar X,
|
||||
LogVar X_new1,
|
||||
LogVar X_new2)
|
||||
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
|
||||
{
|
||||
int fIdx = indexOfFormulaWithLogVar (X);
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (formulas_[fIdx].isCounting());
|
||||
assert (args_[fIdx].isCounting());
|
||||
|
||||
unsigned N1 = constr_->getConditionalCount (X_new1);
|
||||
unsigned N2 = constr_->getConditionalCount (X_new2);
|
||||
unsigned N = N1 + N2;
|
||||
unsigned R = formulas_[fIdx].range();
|
||||
unsigned R = args_[fIdx].range();
|
||||
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
||||
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
||||
unsigned H = ranges_[fIdx];
|
||||
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||
@ -320,48 +280,11 @@ Parfactor::expandPotential (
|
||||
}
|
||||
}
|
||||
|
||||
unsigned size = (params_.size() / H) * H1 * H2;
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (size);
|
||||
expandPotential (fIdx, H1 * H2, sumIndexes);
|
||||
|
||||
unsigned prod = 1;
|
||||
vector<unsigned> offsets_ (ranges_.size());
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
|
||||
unsigned index = 0;
|
||||
ranges_[fIdx] = H1 * H2;
|
||||
vector<unsigned> indices (ranges_.size(), 0);
|
||||
for (unsigned k = 0; k < size; k++) {
|
||||
params_.push_back (copy[index]);
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
indices[i] ++;
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index += offsets_[i];
|
||||
}
|
||||
if (indices[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
formulas_.insert (formulas_.begin() + fIdx + 1, formulas_[fIdx]);
|
||||
formulas_[fIdx].rename (X, X_new1);
|
||||
formulas_[fIdx + 1].rename (X, X_new2);
|
||||
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
|
||||
args_[fIdx].rename (X, X_new1);
|
||||
args_[fIdx + 1].rename (X, X_new2);
|
||||
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
|
||||
ranges_[fIdx] = H1;
|
||||
}
|
||||
@ -371,13 +294,12 @@ Parfactor::expandPotential (
|
||||
void
|
||||
Parfactor::fullExpand (LogVar X)
|
||||
{
|
||||
int fIdx = indexOfFormulaWithLogVar (X);
|
||||
int fIdx = indexOfLogVar (X);
|
||||
assert (fIdx != -1);
|
||||
assert (formulas_[fIdx].isCounting());
|
||||
assert (args_[fIdx].isCounting());
|
||||
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = formulas_[fIdx].range();
|
||||
unsigned H = ranges_[fIdx];
|
||||
unsigned R = args_[fIdx].range();
|
||||
|
||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||
@ -400,54 +322,17 @@ Parfactor::fullExpand (LogVar X)
|
||||
++ indexer;
|
||||
}
|
||||
|
||||
unsigned size = (params_.size() / H) * std::pow (R, N);
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (size);
|
||||
expandPotential (fIdx, std::pow (R, N), sumIndexes);
|
||||
|
||||
unsigned prod = 1;
|
||||
vector<unsigned> offsets_ (ranges_.size());
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
|
||||
unsigned index = 0;
|
||||
ranges_[fIdx] = std::pow (R, N);
|
||||
vector<unsigned> indices (ranges_.size(), 0);
|
||||
for (unsigned k = 0; k < size; k++) {
|
||||
params_.push_back (copy[index]);
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
indices[i] ++;
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index += offsets_[i];
|
||||
}
|
||||
if (indices[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ProbFormula f = formulas_[fIdx];
|
||||
formulas_.erase (formulas_.begin() + fIdx);
|
||||
ProbFormula f = args_[fIdx];
|
||||
args_.erase (args_.begin() + fIdx);
|
||||
ranges_.erase (ranges_.begin() + fIdx);
|
||||
LogVars newLvs = constr_->expand (X);
|
||||
assert (newLvs.size() == N);
|
||||
for (unsigned i = 0 ; i < N; i++) {
|
||||
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
|
||||
newFormula.rename (X, newLvs[i]);
|
||||
formulas_.insert (formulas_.begin() + fIdx + i, newFormula);
|
||||
args_.insert (args_.begin() + fIdx + i, newFormula);
|
||||
ranges_.insert (ranges_.begin() + fIdx + i, R);
|
||||
}
|
||||
}
|
||||
@ -459,117 +344,43 @@ Parfactor::reorderAccordingGrounds (const Grounds& grounds)
|
||||
{
|
||||
ProbFormulas newFormulas;
|
||||
for (unsigned i = 0; i < grounds.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas_.size(); j++) {
|
||||
if (grounds[i].functor() == formulas_[j].functor() &&
|
||||
grounds[i].arity() == formulas_[j].arity()) {
|
||||
constr_->moveToTop (formulas_[j].logVars());
|
||||
for (unsigned j = 0; j < args_.size(); j++) {
|
||||
if (grounds[i].functor() == args_[j].functor() &&
|
||||
grounds[i].arity() == args_[j].arity()) {
|
||||
constr_->moveToTop (args_[j].logVars());
|
||||
if (constr_->containsTuple (grounds[i].args())) {
|
||||
newFormulas.push_back (formulas_[j]);
|
||||
newFormulas.push_back (args_[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
assert (newFormulas.size() == i + 1);
|
||||
}
|
||||
reorderFormulas (newFormulas);
|
||||
reorderArguments (newFormulas);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::reorderFormulas (const ProbFormulas& newFormulas)
|
||||
{
|
||||
assert (newFormulas.size() == formulas_.size());
|
||||
if (newFormulas == formulas_) {
|
||||
return;
|
||||
}
|
||||
|
||||
Ranges newRanges;
|
||||
vector<unsigned> positions;
|
||||
for (unsigned i = 0; i < newFormulas.size(); i++) {
|
||||
unsigned idx = indexOf (newFormulas[i]);
|
||||
newRanges.push_back (ranges_[idx]);
|
||||
positions.push_back (idx);
|
||||
}
|
||||
|
||||
unsigned N = ranges_.size();
|
||||
Params newParams (params_.size());
|
||||
for (unsigned i = 0; i < params_.size(); i++) {
|
||||
unsigned li = i;
|
||||
// calculate vector index corresponding to linear index
|
||||
vector<unsigned> vi (N);
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
vi[k] = li % ranges_[k];
|
||||
li /= ranges_[k];
|
||||
}
|
||||
// convert permuted vector index to corresponding linear index
|
||||
unsigned prod = 1;
|
||||
unsigned new_li = 0;
|
||||
for (int k = N - 1; k >= 0; k--) {
|
||||
new_li += vi[positions[k]] * prod;
|
||||
prod *= ranges_[positions[k]];
|
||||
}
|
||||
newParams[new_li] = params_[i];
|
||||
}
|
||||
formulas_ = newFormulas;
|
||||
ranges_ = newRanges;
|
||||
params_ = newParams;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::absorveEvidence (unsigned fIdx, unsigned evidence)
|
||||
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
||||
{
|
||||
int fIdx = indexOf (formula);
|
||||
assert (fIdx != -1);
|
||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||
assert (fIdx < formulas_.size());
|
||||
assert (evidence < formulas_[fIdx].range());
|
||||
assert (formulas_[fIdx].isCounting() == false);
|
||||
assert (args_[fIdx].isCounting() == false);
|
||||
assert (constr_->isCountNormalized (excl));
|
||||
|
||||
Util::pow (params_, constr_->getConditionalCount (excl));
|
||||
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() / formulas_[fIdx].range());
|
||||
|
||||
StatesIndexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < evidence; i++) {
|
||||
indexer.increment (fIdx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
params_.push_back (copy[indexer]);
|
||||
indexer.incrementExcluding (fIdx);
|
||||
}
|
||||
formulas_.erase (formulas_.begin() + fIdx);
|
||||
ranges_.erase (ranges_.begin() + fIdx);
|
||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||
constr_->remove (excl);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::normalize (void)
|
||||
{
|
||||
Util::normalize (params_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
|
||||
{
|
||||
assert (indexOf (f) != -1);
|
||||
formulas_[indexOf (f)].setGroup (group);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setNewGroups (void)
|
||||
{
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
formulas_[i].setGroup (ProbFormula::getNewGroup());
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
args_[i].setGroup (ProbFormula::getNewGroup());
|
||||
}
|
||||
}
|
||||
|
||||
@ -578,14 +389,14 @@ Parfactor::setNewGroups (void)
|
||||
void
|
||||
Parfactor::applySubstitution (const Substitution& theta)
|
||||
{
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
LogVars& lvs = formulas_[i].logVars();
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
LogVars& lvs = args_[i].logVars();
|
||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||
lvs[j] = theta.newNameFor (lvs[j]);
|
||||
}
|
||||
if (formulas_[i].isCounting()) {
|
||||
LogVar clv = formulas_[i].countedLogVar();
|
||||
formulas_[i].setCountedLogVar (theta.newNameFor (clv));
|
||||
if (args_[i].isCounting()) {
|
||||
LogVar clv = args_[i].countedLogVar();
|
||||
args_[i].setCountedLogVar (theta.newNameFor (clv));
|
||||
}
|
||||
}
|
||||
constr_->applySubstitution (theta);
|
||||
@ -593,19 +404,29 @@ Parfactor::applySubstitution (const Substitution& theta)
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGround (const Ground& ground) const
|
||||
int
|
||||
Parfactor::findGroup (const Ground& ground) const
|
||||
{
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].functor() == ground.functor() &&
|
||||
formulas_[i].arity() == ground.arity()) {
|
||||
constr_->moveToTop (formulas_[i].logVars());
|
||||
int group = -1;
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].functor() == ground.functor() &&
|
||||
args_[i].arity() == ground.arity()) {
|
||||
constr_->moveToTop (args_[i].logVars());
|
||||
if (constr_->containsTuple (ground.args())) {
|
||||
return true;
|
||||
group = args_[i].group();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return group;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGround (const Ground& ground) const
|
||||
{
|
||||
return findGroup (ground) != -1;
|
||||
}
|
||||
|
||||
|
||||
@ -613,8 +434,8 @@ Parfactor::containsGround (const Ground& ground) const
|
||||
bool
|
||||
Parfactor::containsGroup (unsigned group) const
|
||||
{
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].group() == group) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].group() == group) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -623,30 +444,12 @@ Parfactor::containsGroup (unsigned group) const
|
||||
|
||||
|
||||
|
||||
const ProbFormula&
|
||||
Parfactor::formula (unsigned fIdx) const
|
||||
{
|
||||
assert (fIdx < formulas_.size());
|
||||
return formulas_[fIdx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Parfactor::range (unsigned fIdx) const
|
||||
{
|
||||
assert (fIdx < ranges_.size());
|
||||
return ranges_[fIdx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Parfactor::nrFormulas (LogVar X) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].contains (X)) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].contains (X)) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
@ -656,27 +459,12 @@ Parfactor::nrFormulas (LogVar X) const
|
||||
|
||||
|
||||
int
|
||||
Parfactor::indexOf (const ProbFormula& f) const
|
||||
{
|
||||
int idx = -1;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (f == formulas_[i]) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Parfactor::indexOfFormulaWithLogVar (LogVar X) const
|
||||
Parfactor::indexOfLogVar (LogVar X) const
|
||||
{
|
||||
int idx = -1;
|
||||
assert (nrFormulas (X) == 1);
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].contains (X)) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].contains (X)) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
@ -687,11 +475,11 @@ Parfactor::indexOfFormulaWithLogVar (LogVar X) const
|
||||
|
||||
|
||||
int
|
||||
Parfactor::indexOfFormulaWithGroup (unsigned group) const
|
||||
Parfactor::indexOfGroup (unsigned group) const
|
||||
{
|
||||
int pos = -1;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
if (formulas_[i].group() == group) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].group() == group) {
|
||||
pos = i;
|
||||
break;
|
||||
}
|
||||
@ -704,9 +492,9 @@ Parfactor::indexOfFormulaWithGroup (unsigned group) const
|
||||
vector<unsigned>
|
||||
Parfactor::getAllGroups (void) const
|
||||
{
|
||||
vector<unsigned> groups (formulas_.size());
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
groups[i] = formulas_[i].group();
|
||||
vector<unsigned> groups (args_.size());
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
groups[i] = args_[i].group();
|
||||
}
|
||||
return groups;
|
||||
}
|
||||
@ -714,13 +502,13 @@ Parfactor::getAllGroups (void) const
|
||||
|
||||
|
||||
string
|
||||
Parfactor::getHeaderString (void) const
|
||||
Parfactor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "phi(" ;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << formulas_[i];
|
||||
ss << args_[i];
|
||||
}
|
||||
ss << ")" ;
|
||||
ConstraintTree copy (*constr_);
|
||||
@ -735,32 +523,35 @@ void
|
||||
Parfactor::print (bool printParams) const
|
||||
{
|
||||
cout << "Formulas: " ;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
cout << formulas_[i];
|
||||
cout << args_[i];
|
||||
}
|
||||
cout << endl;
|
||||
vector<string> groups;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
groups.push_back (string ("g") + Util::toString (formulas_[i].group()));
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
||||
}
|
||||
cout << "Groups: " << groups << endl;
|
||||
cout << "LogVars: " << constr_->logVars() << endl;
|
||||
cout << "LogVars: " << constr_->logVarSet() << endl;
|
||||
cout << "Ranges: " << ranges_ << endl;
|
||||
if (printParams == false) {
|
||||
cout << "Params: " << params_ << endl;
|
||||
}
|
||||
cout << "Tuples: " << constr_->tupleSet() << endl;
|
||||
ConstraintTree copy (*constr_);
|
||||
copy.moveToTop (copy.logVarSet().elements());
|
||||
cout << "Tuples: " << copy.tupleSet() << endl;
|
||||
if (printParams) {
|
||||
vector<string> jointStrings;
|
||||
StatesIndexer indexer (ranges_);
|
||||
while (indexer.valid()) {
|
||||
stringstream ss;
|
||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
||||
for (unsigned i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
if (formulas_[i].isCounting()) {
|
||||
unsigned N = constr_->getConditionalCount (formulas_[i].countedLogVar());
|
||||
HistogramSet hs (N, formulas_[i].range());
|
||||
if (args_[i].isCounting()) {
|
||||
unsigned N = constr_->getConditionalCount (
|
||||
args_[i].countedLogVar());
|
||||
HistogramSet hs (N, args_[i].range());
|
||||
unsigned c = 0;
|
||||
while (c < indexer[i]) {
|
||||
hs.nextHistogram();
|
||||
@ -784,17 +575,50 @@ Parfactor::print (bool printParams) const
|
||||
|
||||
|
||||
void
|
||||
Parfactor::insertDimension (unsigned range)
|
||||
Parfactor::expandPotential (
|
||||
int fIdx,
|
||||
unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes)
|
||||
{
|
||||
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
|
||||
Params copy = params_;
|
||||
params_.clear();
|
||||
params_.reserve (copy.size() * range);
|
||||
for (unsigned i = 0; i < copy.size(); i++) {
|
||||
for (unsigned reps = 0; reps < range; reps++) {
|
||||
params_.push_back (copy[i]);
|
||||
params_.reserve (size);
|
||||
|
||||
unsigned prod = 1;
|
||||
vector<unsigned> offsets_ (ranges_.size());
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
|
||||
unsigned index = 0;
|
||||
ranges_[fIdx] = newRange;
|
||||
vector<unsigned> indices (ranges_.size(), 0);
|
||||
for (unsigned k = 0; k < size; k++) {
|
||||
params_.push_back (copy[index]);
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
indices[i] ++;
|
||||
if (i == fIdx) {
|
||||
assert (indices[i] - 1 < sumIndexes.size());
|
||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index += offsets_[i];
|
||||
}
|
||||
if (indices[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
if (i == fIdx) {
|
||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
||||
index += diff * offsets_[i];
|
||||
} else {
|
||||
index -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
ranges_.push_back (range);
|
||||
}
|
||||
|
||||
|
||||
@ -803,29 +627,27 @@ void
|
||||
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
LogVars X_1, X_2;
|
||||
const ProbFormulas& formulas1 = g1->formulas();
|
||||
const ProbFormulas& formulas2 = g2->formulas();
|
||||
const ProbFormulas& formulas1 = g1->arguments();
|
||||
const ProbFormulas& formulas2 = g2->arguments();
|
||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||
if (formulas1[i].group() == formulas2[j].group()) {
|
||||
X_1.insert (X_1.end(),
|
||||
formulas1[i].logVars().begin(),
|
||||
formulas1[i].logVars().end());
|
||||
X_2.insert (X_2.end(),
|
||||
formulas2[j].logVars().begin(),
|
||||
formulas2[j].logVars().end());
|
||||
Util::addToVector (X_1, formulas1[i].logVars());
|
||||
Util::addToVector (X_2, formulas2[j].logVars());
|
||||
}
|
||||
}
|
||||
}
|
||||
align (g1, X_1, g2, X_2);
|
||||
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
||||
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
||||
assert (g1->constr()->isCountNormalized (Y_1));
|
||||
assert (g2->constr()->isCountNormalized (Y_2));
|
||||
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
||||
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
||||
Util::pow (g1->params(), 1.0 / condCount2);
|
||||
Util::pow (g2->params(), 1.0 / condCount1);
|
||||
LogAware::pow (g1->params(), 1.0 / condCount2);
|
||||
LogAware::pow (g2->params(), 1.0 / condCount1);
|
||||
// this must be done in the end or else X_1 and X_2
|
||||
// will refer the old log var names in the code above
|
||||
align (g1, X_1, g2, X_2);
|
||||
}
|
||||
|
||||
|
||||
@ -838,7 +660,6 @@ Parfactor::align (
|
||||
LogVar freeLogVar = 0;
|
||||
Substitution theta1;
|
||||
Substitution theta2;
|
||||
|
||||
const LogVarSet& allLvs1 = g1->logVarSet();
|
||||
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
||||
theta1.add (allLvs1[i], freeLogVar);
|
||||
@ -850,7 +671,7 @@ Parfactor::align (
|
||||
theta2.add (allLvs2[i], freeLogVar);
|
||||
++ freeLogVar;
|
||||
}
|
||||
|
||||
|
||||
assert (alignLvs1.size() == alignLvs2.size());
|
||||
for (unsigned i = 0; i < alignLvs1.size(); i++) {
|
||||
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
|
||||
|
@ -9,8 +9,9 @@
|
||||
#include "LiftedUtils.h"
|
||||
#include "Horus.h"
|
||||
|
||||
#include "Factor.h"
|
||||
|
||||
class Parfactor
|
||||
class Parfactor : public TFactor<ProbFormula>
|
||||
{
|
||||
public:
|
||||
Parfactor (
|
||||
@ -18,27 +19,15 @@ class Parfactor
|
||||
const Params&,
|
||||
const Tuples&,
|
||||
unsigned);
|
||||
|
||||
Parfactor (const Parfactor*, const Tuple&);
|
||||
|
||||
Parfactor (const Parfactor*, ConstraintTree*);
|
||||
|
||||
Parfactor (const Parfactor&);
|
||||
|
||||
~Parfactor (void);
|
||||
|
||||
ProbFormulas& formulas (void) { return formulas_; }
|
||||
|
||||
const ProbFormulas& formulas (void) const { return formulas_; }
|
||||
|
||||
unsigned nrFormulas (void) const { return formulas_.size(); }
|
||||
|
||||
Params& params (void) { return params_; }
|
||||
|
||||
const Params& params (void) const { return params_; }
|
||||
|
||||
unsigned size (void) const { return params_.size(); }
|
||||
|
||||
const Ranges& ranges (void) const { return ranges_; }
|
||||
|
||||
unsigned distId (void) const { return distId_; }
|
||||
|
||||
ConstraintTree* constr (void) { return constr_; }
|
||||
|
||||
const ConstraintTree* constr (void) const { return constr_; }
|
||||
@ -57,64 +46,52 @@ class Parfactor
|
||||
|
||||
void setConstraintTree (ConstraintTree*);
|
||||
|
||||
void sumOut (unsigned);
|
||||
void sumOut (unsigned fIdx);
|
||||
|
||||
void multiply (Parfactor&);
|
||||
|
||||
void countConvert (LogVar);
|
||||
|
||||
void expandPotential (LogVar, LogVar, LogVar);
|
||||
void expand (LogVar, LogVar, LogVar);
|
||||
|
||||
void fullExpand (LogVar);
|
||||
|
||||
void reorderAccordingGrounds (const Grounds&);
|
||||
|
||||
void reorderFormulas (const ProbFormulas&);
|
||||
|
||||
void absorveEvidence (unsigned, unsigned);
|
||||
|
||||
void normalize (void);
|
||||
|
||||
void setFormulaGroup (const ProbFormula&, int);
|
||||
void absorveEvidence (const ProbFormula&, unsigned);
|
||||
|
||||
void setNewGroups (void);
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
int findGroup (const Ground&) const;
|
||||
|
||||
bool containsGround (const Ground&) const;
|
||||
|
||||
bool containsGroup (unsigned) const;
|
||||
|
||||
const ProbFormula& formula (unsigned) const;
|
||||
|
||||
unsigned range (unsigned) const;
|
||||
|
||||
|
||||
unsigned nrFormulas (LogVar) const;
|
||||
|
||||
int indexOf (const ProbFormula&) const;
|
||||
int indexOfLogVar (LogVar) const;
|
||||
|
||||
int indexOfFormulaWithLogVar (LogVar) const;
|
||||
|
||||
int indexOfFormulaWithGroup (unsigned) const;
|
||||
int indexOfGroup (unsigned) const;
|
||||
|
||||
vector<unsigned> getAllGroups (void) const;
|
||||
|
||||
void print (bool = false) const;
|
||||
|
||||
string getHeaderString (void) const;
|
||||
string getLabel (void) const;
|
||||
|
||||
private:
|
||||
void expandPotential (int fIdx, unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes);
|
||||
|
||||
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||
|
||||
static void align (
|
||||
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
|
||||
|
||||
void insertDimension (unsigned);
|
||||
|
||||
ProbFormulas formulas_;
|
||||
Ranges ranges_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
ConstraintTree* constr_;
|
||||
ConstraintTree* constr_;
|
||||
};
|
||||
|
||||
|
||||
|
@ -3,9 +3,32 @@
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
ParfactorList::ParfactorList (Parfactors& pfs)
|
||||
ParfactorList::ParfactorList (const ParfactorList& pfList)
|
||||
{
|
||||
pfList_.insert (pfList_.end(), pfs.begin(), pfs.end());
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
addShattered (new Parfactor (**it));
|
||||
++ it;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParfactorList::ParfactorList (const Parfactors& pfs)
|
||||
{
|
||||
add (pfs);
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParfactorList::~ParfactorList (void)
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
delete *it;
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -14,17 +37,17 @@ void
|
||||
ParfactorList::add (Parfactor* pf)
|
||||
{
|
||||
pf->setNewGroups();
|
||||
pfList_.push_back (pf);
|
||||
addToShatteredList (pf);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::add (Parfactors& pfs)
|
||||
ParfactorList::add (const Parfactors& pfs)
|
||||
{
|
||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||
pfs[i]->setNewGroups();
|
||||
pfList_.push_back (pfs[i]);
|
||||
addToShatteredList (pfs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,7 +56,20 @@ ParfactorList::add (Parfactors& pfs)
|
||||
void
|
||||
ParfactorList::addShattered (Parfactor* pf)
|
||||
{
|
||||
assert (isAllShattered());
|
||||
pfList_.push_back (pf);
|
||||
assert (isAllShattered());
|
||||
}
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::insertShattered (
|
||||
list<Parfactor*>::iterator it,
|
||||
Parfactor* pf)
|
||||
{
|
||||
return pfList_.insert (it, pf);
|
||||
assert (isAllShattered());
|
||||
}
|
||||
|
||||
|
||||
@ -47,7 +83,7 @@ ParfactorList::remove (list<Parfactor*>::iterator it)
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
|
||||
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||
{
|
||||
delete *it;
|
||||
return pfList_.erase (it);
|
||||
@ -55,58 +91,21 @@ ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::shatter (void)
|
||||
bool
|
||||
ParfactorList::isAllShattered (void) const
|
||||
{
|
||||
list<Parfactor*> tempList;
|
||||
Parfactors newPfs;
|
||||
newPfs.insert (newPfs.end(), pfList_.begin(), pfList_.end());
|
||||
while (newPfs.empty() == false) {
|
||||
tempList.insert (tempList.end(), newPfs.begin(), newPfs.end());
|
||||
newPfs.clear();
|
||||
list<Parfactor*>::iterator iter1 = tempList.begin();
|
||||
while (tempList.size() > 1 && iter1 != -- tempList.end()) {
|
||||
list<Parfactor*>::iterator iter2 = iter1;
|
||||
++ iter2;
|
||||
bool incIter1 = true;
|
||||
while (iter2 != tempList.end()) {
|
||||
assert (iter1 != iter2);
|
||||
std::pair<Parfactors, Parfactors> res = shatter (
|
||||
(*iter1)->formulas(), *iter1, (*iter2)->formulas(), *iter2);
|
||||
bool incIter2 = true;
|
||||
if (res.second.empty() == false) {
|
||||
// cout << "second unshattered" << endl;
|
||||
delete *iter2;
|
||||
iter2 = tempList.erase (iter2);
|
||||
incIter2 = false;
|
||||
newPfs.insert (
|
||||
newPfs.begin(), res.second.begin(), res.second.end());
|
||||
}
|
||||
if (res.first.empty() == false) {
|
||||
// cout << "first unshattered" << endl;
|
||||
delete *iter1;
|
||||
iter1 = tempList.erase (iter1);
|
||||
newPfs.insert (
|
||||
newPfs.begin(), res.first.begin(), res.first.end());
|
||||
incIter1 = false;
|
||||
break;
|
||||
}
|
||||
if (incIter2) {
|
||||
++ iter2;
|
||||
}
|
||||
}
|
||||
if (incIter1) {
|
||||
++ iter1;
|
||||
if (pfList_.size() <= 1) {
|
||||
return true;
|
||||
}
|
||||
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
|
||||
for (unsigned i = 0; i < pfs.size() - 1; i++) {
|
||||
for (unsigned j = i + 1; j < pfs.size(); j++) {
|
||||
if (isShattered (pfs[i], pfs[j]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
||||
// cout << "||||||||||||| SHATTERING ITERATION ||||||||||||||" << endl;
|
||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
||||
// printParfactors (newPfs);
|
||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
||||
}
|
||||
pfList_.clear();
|
||||
pfList_.insert (pfList_.end(), tempList.begin(), tempList.end());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -123,19 +122,83 @@ ParfactorList::print (void) const
|
||||
|
||||
|
||||
|
||||
std::pair<Parfactors, Parfactors>
|
||||
ParfactorList::shatter (
|
||||
ProbFormulas& formulas1,
|
||||
Parfactor* g1,
|
||||
ProbFormulas& formulas2,
|
||||
Parfactor* g2)
|
||||
bool
|
||||
ParfactorList::isShattered (
|
||||
const Parfactor* g1,
|
||||
const Parfactor* g2) const
|
||||
{
|
||||
assert (g1 != g2);
|
||||
const ProbFormulas& fms1 = g1->arguments();
|
||||
const ProbFormulas& fms2 = g2->arguments();
|
||||
for (unsigned i = 0; i < fms1.size(); i++) {
|
||||
for (unsigned j = 0; j < fms2.size(); j++) {
|
||||
if (fms1[i].group() == fms2[j].group()) {
|
||||
if (identical (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (disjoint (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::addToShatteredList (Parfactor* g)
|
||||
{
|
||||
queue<Parfactor*> residuals;
|
||||
residuals.push (g);
|
||||
while (residuals.empty() == false) {
|
||||
Parfactor* pf = residuals.front();
|
||||
bool pfSplitted = false;
|
||||
list<Parfactor*>::iterator pfIter;
|
||||
pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
std::pair<Parfactors, Parfactors> shattRes;
|
||||
shattRes = shatter (*pfIter, pf);
|
||||
if (shattRes.first.empty() == false) {
|
||||
pfIter = removeAndDelete (pfIter);
|
||||
Util::addToQueue (residuals, shattRes.first);
|
||||
} else {
|
||||
++ pfIter;
|
||||
}
|
||||
if (shattRes.second.empty() == false) {
|
||||
delete pf;
|
||||
Util::addToQueue (residuals, shattRes.second);
|
||||
pfSplitted = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
residuals.pop();
|
||||
if (pfSplitted == false) {
|
||||
addShattered (pf);
|
||||
}
|
||||
}
|
||||
assert (isAllShattered());
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::pair<Parfactors, Parfactors>
|
||||
ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
|
||||
{
|
||||
ProbFormulas& formulas1 = g1->arguments();
|
||||
ProbFormulas& formulas2 = g2->arguments();
|
||||
assert (g1 != 0 && g2 != 0 && g1 != g2);
|
||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
|
||||
std::pair<Parfactors, Parfactors> res
|
||||
= shatter (formulas1[i], g1, formulas2[j], g2);
|
||||
std::pair<Parfactors, Parfactors> res;
|
||||
res = shatter (i, g1, j, g2);
|
||||
if (res.first.empty() == false ||
|
||||
res.second.empty() == false) {
|
||||
return res;
|
||||
@ -150,21 +213,22 @@ ParfactorList::shatter (
|
||||
|
||||
std::pair<Parfactors, Parfactors>
|
||||
ParfactorList::shatter (
|
||||
ProbFormula& f1,
|
||||
Parfactor* g1,
|
||||
ProbFormula& f2,
|
||||
Parfactor* g2)
|
||||
unsigned fIdx1, Parfactor* g1,
|
||||
unsigned fIdx2, Parfactor* g2)
|
||||
{
|
||||
ProbFormula& f1 = g1->argument (fIdx1);
|
||||
ProbFormula& f2 = g2->argument (fIdx2);
|
||||
// cout << endl;
|
||||
// cout << "-------------------------------------------------" << endl;
|
||||
// Util::printDashLine();
|
||||
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
||||
// g1->print();
|
||||
// cout << "-> WITH" << endl;
|
||||
// g2->print();
|
||||
// cout << "-> ON: " << f1.toString (g1->constr()) << endl;
|
||||
// cout << "-> ON: " << f2.toString (g2->constr()) << endl;
|
||||
// cout << "-------------------------------------------------" << endl;
|
||||
|
||||
// cout << "-> ON: " << f1 << "|" ;
|
||||
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
||||
// cout << "-> ON: " << f2 << "|" ;
|
||||
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
||||
// Util::printDashLine();
|
||||
if (f1.isAtom()) {
|
||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
||||
f1.setGroup (group);
|
||||
@ -174,7 +238,7 @@ ParfactorList::shatter (
|
||||
assert (g1->constr()->empty() == false);
|
||||
assert (g2->constr()->empty() == false);
|
||||
if (f1.group() == f2.group()) {
|
||||
// assert (identical (f1, g1->constr(), f2, g2->constr()));
|
||||
assert (identical (f1, *(g1->constr()), f2, *(g2->constr())));
|
||||
return { };
|
||||
}
|
||||
|
||||
@ -215,7 +279,9 @@ ParfactorList::shatter (
|
||||
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
||||
|
||||
if (exclCt1->empty() && exclCt2->empty()) {
|
||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
||||
unsigned group = (f1.group() < f2.group())
|
||||
? f1.group()
|
||||
: f2.group();
|
||||
// identical
|
||||
f1.setGroup (group);
|
||||
f2.setGroup (group);
|
||||
@ -235,8 +301,8 @@ ParfactorList::shatter (
|
||||
} else {
|
||||
group = ProbFormula::getNewGroup();
|
||||
}
|
||||
Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group);
|
||||
Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group);
|
||||
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
|
||||
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
|
||||
return make_pair (res1, res2);
|
||||
}
|
||||
|
||||
@ -245,11 +311,19 @@ ParfactorList::shatter (
|
||||
Parfactors
|
||||
ParfactorList::shatter (
|
||||
Parfactor* g,
|
||||
const ProbFormula& f,
|
||||
unsigned fIdx,
|
||||
ConstraintTree* commCt,
|
||||
ConstraintTree* exclCt,
|
||||
unsigned commGroup)
|
||||
{
|
||||
ProbFormula& f = g->argument (fIdx);
|
||||
if (exclCt->empty()) {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
f.setGroup (commGroup);
|
||||
return { };
|
||||
}
|
||||
|
||||
Parfactors result;
|
||||
if (f.isCounting()) {
|
||||
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
|
||||
@ -259,7 +333,7 @@ ParfactorList::shatter (
|
||||
for (unsigned i = 0; i < cts.size(); i++) {
|
||||
Parfactor* newPf = new Parfactor (g, cts[i]);
|
||||
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
|
||||
newPf->expandPotential (f.countedLogVar(), X_new1, X_new2);
|
||||
newPf->expand (f.countedLogVar(), X_new1, X_new2);
|
||||
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
||||
cts[i]->getConditionalCount (X_new1) +
|
||||
cts[i]->getConditionalCount (X_new2));
|
||||
@ -270,20 +344,16 @@ ParfactorList::shatter (
|
||||
newPf->setNewGroups();
|
||||
result.push_back (newPf);
|
||||
}
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
} else {
|
||||
if (exclCt->empty()) {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
g->setFormulaGroup (f, commGroup);
|
||||
} else {
|
||||
Parfactor* newPf = new Parfactor (g, commCt);
|
||||
newPf->setNewGroups();
|
||||
newPf->setFormulaGroup (f, commGroup);
|
||||
result.push_back (newPf);
|
||||
newPf = new Parfactor (g, exclCt);
|
||||
newPf->setNewGroups();
|
||||
result.push_back (newPf);
|
||||
}
|
||||
Parfactor* newPf = new Parfactor (g, commCt);
|
||||
newPf->setNewGroups();
|
||||
newPf->argument (fIdx).setGroup (commGroup);
|
||||
result.push_back (newPf);
|
||||
newPf = new Parfactor (g, exclCt);
|
||||
newPf->setNewGroups();
|
||||
result.push_back (newPf);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -296,7 +366,7 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
|
||||
unsigned newGroup = ProbFormula::getNewGroup();
|
||||
for (ParfactorList::iterator it = pfList_.begin();
|
||||
it != pfList_.end(); it++) {
|
||||
ProbFormulas& formulas = (*it)->formulas();
|
||||
ProbFormulas& formulas = (*it)->arguments();
|
||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||
if (formulas[i].group() == group1 ||
|
||||
formulas[i].group() == group2) {
|
||||
@ -306,3 +376,52 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::proper (
|
||||
const ProbFormula& f1, ConstraintTree c1,
|
||||
const ProbFormula& f2, ConstraintTree c2) const
|
||||
{
|
||||
return disjoint (f1, c1, f2, c2)
|
||||
|| identical (f1, c1, f2, c2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::identical (
|
||||
const ProbFormula& f1, ConstraintTree c1,
|
||||
const ProbFormula& f2, ConstraintTree c2) const
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return false;
|
||||
}
|
||||
if (f1.isAtom()) {
|
||||
return true;
|
||||
}
|
||||
c1.moveToTop (f1.logVars());
|
||||
c2.moveToTop (f2.logVars());
|
||||
return ConstraintTree::identical (
|
||||
&c1, &c2, f1.logVars().size());
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::disjoint (
|
||||
const ProbFormula& f1, ConstraintTree c1,
|
||||
const ProbFormula& f2, ConstraintTree c2) const
|
||||
{
|
||||
if (f1.sameSkeletonAs (f2) == false) {
|
||||
return true;
|
||||
}
|
||||
if (f1.isAtom()) {
|
||||
return true;
|
||||
}
|
||||
c1.moveToTop (f1.logVars());
|
||||
c2.moveToTop (f2.logVars());
|
||||
return ConstraintTree::overlap (
|
||||
&c1, &c2, f1.arity()) == false;
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
#define HORUS_PARFACTORLIST_H
|
||||
|
||||
#include <list>
|
||||
#include <queue>
|
||||
|
||||
#include "Parfactor.h"
|
||||
#include "ProbFormula.h"
|
||||
@ -14,56 +15,82 @@ class ParfactorList
|
||||
{
|
||||
public:
|
||||
ParfactorList (void) { }
|
||||
ParfactorList (Parfactors&);
|
||||
list<Parfactor*>& getParfactors (void) { return pfList_; }
|
||||
const list<Parfactor*>& getParfactors (void) const { return pfList_; }
|
||||
|
||||
void add (Parfactor* pf);
|
||||
void add (Parfactors& pfs);
|
||||
void addShattered (Parfactor* pf);
|
||||
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||
list<Parfactor*>::iterator deleteAndRemove (list<Parfactor*>::iterator);
|
||||
ParfactorList (const ParfactorList&);
|
||||
|
||||
void clear (void) { pfList_.clear(); }
|
||||
unsigned size (void) const { return pfList_.size(); }
|
||||
|
||||
ParfactorList (const Parfactors&);
|
||||
|
||||
void shatter (void);
|
||||
~ParfactorList (void);
|
||||
|
||||
const list<Parfactor*>& parfactors (void) const { return pfList_; }
|
||||
|
||||
void clear (void) { pfList_.clear(); }
|
||||
|
||||
unsigned size (void) const { return pfList_.size(); }
|
||||
|
||||
typedef std::list<Parfactor*>::iterator iterator;
|
||||
|
||||
iterator begin (void) { return pfList_.begin(); }
|
||||
iterator end (void) { return pfList_.end(); }
|
||||
|
||||
iterator end (void) { return pfList_.end(); }
|
||||
|
||||
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
||||
|
||||
const_iterator begin (void) const { return pfList_.begin(); }
|
||||
const_iterator end (void) const { return pfList_.end(); }
|
||||
|
||||
const_iterator end (void) const { return pfList_.end(); }
|
||||
|
||||
void add (Parfactor* pf);
|
||||
|
||||
void add (const Parfactors& pfs);
|
||||
|
||||
void addShattered (Parfactor* pf);
|
||||
|
||||
list<Parfactor*>::iterator insertShattered (
|
||||
list<Parfactor*>::iterator, Parfactor*);
|
||||
|
||||
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||
|
||||
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
|
||||
|
||||
bool isAllShattered (void) const;
|
||||
|
||||
void print (void) const;
|
||||
|
||||
private:
|
||||
|
||||
bool isShattered (const Parfactor*, const Parfactor*) const;
|
||||
|
||||
static std::pair<Parfactors, Parfactors> shatter (
|
||||
ProbFormulas&,
|
||||
Parfactor*,
|
||||
ProbFormulas&,
|
||||
Parfactor*);
|
||||
void addToShatteredList (Parfactor*);
|
||||
|
||||
std::pair<Parfactors, Parfactors> shatter (
|
||||
Parfactor*, Parfactor*);
|
||||
|
||||
static std::pair<Parfactors, Parfactors> shatter (
|
||||
ProbFormula&,
|
||||
Parfactor*,
|
||||
ProbFormula&,
|
||||
Parfactor*);
|
||||
std::pair<Parfactors, Parfactors> shatter (
|
||||
unsigned, Parfactor*, unsigned, Parfactor*);
|
||||
|
||||
static Parfactors shatter (
|
||||
Parfactor*,
|
||||
const ProbFormula&,
|
||||
ConstraintTree*,
|
||||
ConstraintTree*,
|
||||
unsigned);
|
||||
Parfactors shatter (
|
||||
Parfactor*,
|
||||
unsigned,
|
||||
ConstraintTree*,
|
||||
ConstraintTree*,
|
||||
unsigned);
|
||||
|
||||
void unifyGroups (unsigned group1, unsigned group2);
|
||||
void unifyGroups (unsigned group1, unsigned group2);
|
||||
|
||||
list<Parfactor*> pfList_;
|
||||
bool proper (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
bool identical (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
bool disjoint (
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
list<Parfactor*> pfList_;
|
||||
};
|
||||
|
||||
#endif // HORUS_PARFACTORLIST_H
|
||||
|
@ -16,8 +16,7 @@ ProbFormula::sameSkeletonAs (const ProbFormula& f) const
|
||||
bool
|
||||
ProbFormula::contains (LogVar lv) const
|
||||
{
|
||||
return std::find (logVars_.begin(), logVars_.end(), lv) !=
|
||||
logVars_.end();
|
||||
return Util::contains (logVars_, lv);
|
||||
}
|
||||
|
||||
|
||||
@ -77,16 +76,15 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::operator== (const ProbFormula& f) const
|
||||
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||
{
|
||||
return functor_ == f.functor_ && logVars_ == f.logVars_ ;
|
||||
return f1.group_ == f2.group_;
|
||||
//return functor_ == f.functor_ && logVars_ == f.logVars_ ;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
{
|
||||
os << f.functor_;
|
||||
if (f.isAtom() == false) {
|
||||
@ -113,3 +111,13 @@ ProbFormula::getNewGroup (void)
|
||||
return freeGroup_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||
{
|
||||
os << of.functor_ << "/" << of.arity_;
|
||||
os << "|" << of.constr_.tupleSet();
|
||||
os << " [evidence=" << of.evidence_ << "]";
|
||||
return os;
|
||||
}
|
||||
|
||||
|
@ -8,14 +8,16 @@
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
|
||||
class ProbFormula
|
||||
{
|
||||
public:
|
||||
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
||||
: functor_(f), logVars_(lvs), range_(range),
|
||||
countedLogVar_() { }
|
||||
countedLogVar_(), group_(Util::maxUnsigned()) { }
|
||||
|
||||
ProbFormula (Symbol f, unsigned r) : functor_(f), range_(r) { }
|
||||
ProbFormula (Symbol f, unsigned r)
|
||||
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
@ -29,9 +31,9 @@ class ProbFormula
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
|
||||
unsigned group (void) const { return groupId_; }
|
||||
unsigned group (void) const { return group_; }
|
||||
|
||||
void setGroup (unsigned g) { groupId_ = g; }
|
||||
void setGroup (unsigned g) { group_ = g; }
|
||||
|
||||
bool sameSkeletonAs (const ProbFormula&) const;
|
||||
|
||||
@ -49,23 +51,58 @@ class ProbFormula
|
||||
|
||||
void rename (LogVar, LogVar);
|
||||
|
||||
bool operator== (const ProbFormula& f) const;
|
||||
|
||||
friend ostream& operator<< (ostream &out, const ProbFormula& f);
|
||||
|
||||
static unsigned getNewGroup (void);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
|
||||
|
||||
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
|
||||
|
||||
private:
|
||||
Symbol functor_;
|
||||
LogVars logVars_;
|
||||
unsigned range_;
|
||||
LogVar countedLogVar_;
|
||||
unsigned groupId_;
|
||||
static int freeGroup_;
|
||||
Symbol functor_;
|
||||
LogVars logVars_;
|
||||
unsigned range_;
|
||||
LogVar countedLogVar_;
|
||||
unsigned group_;
|
||||
static int freeGroup_;
|
||||
};
|
||||
|
||||
typedef vector<ProbFormula> ProbFormulas;
|
||||
|
||||
|
||||
class ObservedFormula
|
||||
{
|
||||
public:
|
||||
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
|
||||
|
||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||
{
|
||||
constr_.addTuple (tuple);
|
||||
}
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
|
||||
unsigned arity (void) const { return arity_; }
|
||||
|
||||
unsigned evidence (void) const { return evidence_; }
|
||||
|
||||
ConstraintTree& constr (void) { return constr_; }
|
||||
|
||||
bool isAtom (void) const { return arity_ == 0; }
|
||||
|
||||
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
|
||||
|
||||
private:
|
||||
Symbol functor_;
|
||||
unsigned arity_;
|
||||
unsigned evidence_;
|
||||
ConstraintTree constr_;
|
||||
};
|
||||
|
||||
typedef vector<ObservedFormula> ObservedFormulas;
|
||||
|
||||
#endif // HORUS_PROBFORMULA_H
|
||||
|
||||
|
@ -21,7 +21,7 @@ Solver::printPosterioriOf (VarId vid)
|
||||
const States& states = var->states();
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
|
||||
cout << setprecision (PRECISION) << posterioriDist[i];
|
||||
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
@ -45,7 +45,7 @@ Solver::printJointDistributionOf (const VarIds& vids)
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
||||
cout << "P(" << jointStrings[i] << ") = " ;
|
||||
cout << setprecision (PRECISION) << jointDist[i];
|
||||
cout << setprecision (Constants::PRECISION) << jointDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
|
@ -11,17 +11,20 @@ using namespace std;
|
||||
class Solver
|
||||
{
|
||||
public:
|
||||
Solver (const GraphicalModel* gm)
|
||||
{
|
||||
gm_ = gm;
|
||||
}
|
||||
virtual ~Solver() {} // to ensure that subclass destructor is called
|
||||
virtual void runSolver (void) = 0;
|
||||
virtual Params getPosterioriOf (VarId) = 0;
|
||||
virtual Params getJointDistributionOf (const VarIds&) = 0;
|
||||
Solver (const GraphicalModel* gm) : gm_(gm) { }
|
||||
|
||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||
|
||||
virtual void runSolver (void) = 0;
|
||||
|
||||
virtual Params getPosterioriOf (VarId) = 0;
|
||||
|
||||
virtual Params getJointDistributionOf (const VarIds&) = 0;
|
||||
|
||||
void printAllPosterioris (void);
|
||||
|
||||
void printPosterioriOf (VarId vid);
|
||||
|
||||
void printJointDistributionOf (const VarIds& vids);
|
||||
|
||||
private:
|
||||
|
5
packages/CLPBN/clpbn/bp/TODO
Normal file
5
packages/CLPBN/clpbn/bp/TODO
Normal file
@ -0,0 +1,5 @@
|
||||
TODO
|
||||
|
||||
- add way to calculate combinations and factorials with large numbers
|
||||
- refactor sumOut in parfactor -> is really ugly code
|
||||
- Indexer: start receiving ranges as constant reference
|
@ -1,4 +1,7 @@
|
||||
#include <limits>
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Indexer.h"
|
||||
@ -6,16 +9,15 @@
|
||||
|
||||
|
||||
namespace Globals {
|
||||
bool logDomain = false;
|
||||
bool logDomain = false;
|
||||
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
||||
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
|
||||
};
|
||||
|
||||
|
||||
namespace InfAlgorithms {
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
||||
InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::CBP;
|
||||
}
|
||||
|
||||
|
||||
namespace BpOptions {
|
||||
@ -28,8 +30,7 @@ unsigned maxIter = 1000;
|
||||
}
|
||||
|
||||
|
||||
unordered_map<VarId,VariableInfo> GraphicalModel::varsInfo_;
|
||||
unordered_map<unsigned,Distribution*> GraphicalModel::distsInfo_;
|
||||
unordered_map<VarId, VarInfo> GraphicalModel::varsInfo_;
|
||||
|
||||
vector<NetInfo> Statistics::netInfo_;
|
||||
vector<CompressInfo> Statistics::compressInfo_;
|
||||
@ -58,76 +59,6 @@ fromLog (Params& v)
|
||||
|
||||
|
||||
|
||||
void
|
||||
normalize (Params& v)
|
||||
{
|
||||
double sum;
|
||||
if (Globals::logDomain) {
|
||||
sum = addIdenty();
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
logSum (sum, v[i]);
|
||||
}
|
||||
assert (sum != -numeric_limits<double>::infinity());
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] -= sum;
|
||||
}
|
||||
} else {
|
||||
sum = 0.0;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum += v[i];
|
||||
}
|
||||
assert (sum != 0.0);
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, double expoent)
|
||||
{
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, unsigned expoent)
|
||||
{
|
||||
if (expoent == 1) {
|
||||
return;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
pow (double p, unsigned expoent)
|
||||
{
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
factorial (double num)
|
||||
{
|
||||
@ -153,52 +84,21 @@ nrCombinations (unsigned n, unsigned r)
|
||||
|
||||
|
||||
|
||||
double
|
||||
getL1Distance (const Params& v1, const Params& v2)
|
||||
unsigned
|
||||
expectedSize (const Ranges& ranges)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double dist = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (exp(v1[i]) - exp(v2[i]));
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (v1[i] - v2[i]);
|
||||
}
|
||||
unsigned prod = 1;
|
||||
for (unsigned i = 0; i < ranges.size(); i++) {
|
||||
prod *= ranges[i];
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
getMaxNorm (const Params& v1, const Params& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double max = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (exp(v1[i]) - exp(v2[i]));
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (v1[i] - v2[i]);
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
}
|
||||
return max;
|
||||
return prod;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
getNumberOfDigits (int number) {
|
||||
getNumberOfDigits (int number)
|
||||
{
|
||||
unsigned count = 1;
|
||||
while (number >= 10) {
|
||||
number /= 10;
|
||||
@ -257,6 +157,168 @@ getJointStateStrings (const VarNodes& vars)
|
||||
|
||||
|
||||
|
||||
void printHeader (string header, std::ostream& os)
|
||||
{
|
||||
printAsteriskLine (os);
|
||||
os << header << endl;
|
||||
printAsteriskLine (os);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void printSubHeader (string header, std::ostream& os)
|
||||
{
|
||||
printDashedLine (os);
|
||||
os << header << endl;
|
||||
printDashedLine (os);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void printAsteriskLine (std::ostream& os)
|
||||
{
|
||||
os << "********************************" ;
|
||||
os << "********************************" ;
|
||||
os << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void printDashedLine (std::ostream& os)
|
||||
{
|
||||
os << "--------------------------------" ;
|
||||
os << "--------------------------------" ;
|
||||
os << endl;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace LogAware {
|
||||
|
||||
void
|
||||
normalize (Params& v)
|
||||
{
|
||||
double sum;
|
||||
if (Globals::logDomain) {
|
||||
sum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum = Util::logSum (sum, v[i]);
|
||||
}
|
||||
assert (sum != -numeric_limits<double>::infinity());
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] -= sum;
|
||||
}
|
||||
} else {
|
||||
sum = 0.0;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum += v[i];
|
||||
}
|
||||
assert (sum != 0.0);
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] /= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
getL1Distance (const Params& v1, const Params& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double dist = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (exp(v1[i]) - exp(v2[i]));
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (v1[i] - v2[i]);
|
||||
}
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
getMaxNorm (const Params& v1, const Params& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double max = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (exp(v1[i]) - exp(v2[i]));
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (v1[i] - v2[i]);
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
|
||||
double
|
||||
pow (double p, unsigned expoent)
|
||||
{
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
pow (double p, double expoent)
|
||||
{
|
||||
// assumes that `expoent' is never in log domain
|
||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, unsigned expoent)
|
||||
{
|
||||
if (expoent == 1) {
|
||||
return;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (Params& v, double expoent)
|
||||
{
|
||||
// assumes that `expoent' is never in log domain
|
||||
if (Globals::logDomain) {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = std::pow (v[i], expoent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -286,8 +348,11 @@ Statistics::getPrimaryNetworksCounting (void)
|
||||
|
||||
|
||||
void
|
||||
Statistics::updateStatistics (unsigned size, bool loopy,
|
||||
unsigned nIters, double time)
|
||||
Statistics::updateStatistics (
|
||||
unsigned size,
|
||||
bool loopy,
|
||||
unsigned nIters,
|
||||
double time)
|
||||
{
|
||||
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
|
||||
}
|
||||
@ -318,11 +383,12 @@ Statistics::writeStatisticsToFile (const char* fileName)
|
||||
|
||||
|
||||
void
|
||||
Statistics::updateCompressingStatistics (unsigned nGroundVars,
|
||||
unsigned nGroundFactors,
|
||||
unsigned nClusterVars,
|
||||
unsigned nClusterFactors,
|
||||
unsigned nWithoutNeighs) {
|
||||
Statistics::updateCompressingStatistics (
|
||||
unsigned nGroundVars,
|
||||
unsigned nGroundFactors,
|
||||
unsigned nClusterVars,
|
||||
unsigned nClusterFactors,
|
||||
unsigned nWithoutNeighs) {
|
||||
compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors,
|
||||
nClusterVars, nClusterFactors, nWithoutNeighs));
|
||||
}
|
||||
@ -334,7 +400,7 @@ Statistics::getStatisticString (void)
|
||||
{
|
||||
stringstream ss2, ss3, ss4, ss1;
|
||||
ss1 << "running mode: " ;
|
||||
switch (InfAlgorithms::infAlgorithm) {
|
||||
switch (Globals::infAlgorithm) {
|
||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
|
||||
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
|
||||
@ -342,18 +408,23 @@ Statistics::getStatisticString (void)
|
||||
}
|
||||
ss1 << "message schedule: " ;
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_FIXED: ss1 << "sequential fixed" << endl; break;
|
||||
case BpOptions::Schedule::SEQ_RANDOM: ss1 << "sequential random" << endl; break;
|
||||
case BpOptions::Schedule::PARALLEL: ss1 << "parallel" << endl; break;
|
||||
case BpOptions::Schedule::MAX_RESIDUAL: ss1 << "max residual" << endl; break;
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
ss1 << "sequential fixed" << endl;
|
||||
break;
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
ss1 << "sequential random" << endl;
|
||||
break;
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
ss1 << "parallel" << endl;
|
||||
break;
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
ss1 << "max residual" << endl;
|
||||
break;
|
||||
}
|
||||
ss1 << "max iterations: " << BpOptions::maxIter << endl;
|
||||
ss1 << "accuracy " << BpOptions::accuracy << endl;
|
||||
ss1 << endl << endl;
|
||||
|
||||
ss2 << "---------------------------------------------------" << endl;
|
||||
ss2 << " Network information" << endl;
|
||||
ss2 << "---------------------------------------------------" << endl;
|
||||
Util::printSubHeader ("Network information", ss2);
|
||||
ss2 << left;
|
||||
ss2 << setw (15) << "Network Size" ;
|
||||
ss2 << setw (9) << "Loopy" ;
|
||||
@ -387,9 +458,7 @@ Statistics::getStatisticString (void)
|
||||
|
||||
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
|
||||
if (compressInfo_.size() > 0) {
|
||||
ss3 << "---------------------------------------------------" << endl;
|
||||
ss3 << " Compression information" << endl;
|
||||
ss3 << "---------------------------------------------------" << endl;
|
||||
Util::printSubHeader ("Compress information", ss3);
|
||||
ss3 << left;
|
||||
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
||||
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
||||
|
@ -1,53 +1,131 @@
|
||||
#ifndef HORUS_UTIL_H
|
||||
#define HORUS_UTIL_H
|
||||
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
void toLog (Params&);
|
||||
void fromLog (Params&);
|
||||
void normalize (Params&);
|
||||
void logSum (double&, double);
|
||||
void multiply (Params&, const Params&);
|
||||
void multiply (Params&, const Params&, unsigned);
|
||||
void add (Params&, const Params&);
|
||||
void add (Params&, const Params&, unsigned);
|
||||
void pow (Params&, double);
|
||||
void pow (Params&, unsigned);
|
||||
double pow (double, unsigned);
|
||||
double factorial (double);
|
||||
unsigned nrCombinations (unsigned, unsigned);
|
||||
double getL1Distance (const Params&, const Params&);
|
||||
double getMaxNorm (const Params&, const Params&);
|
||||
unsigned getNumberOfDigits (int);
|
||||
bool isInteger (const string&);
|
||||
string parametersToString (const Params&, unsigned = PRECISION);
|
||||
vector<string> getJointStateStrings (const VarNodes&);
|
||||
double tl (double);
|
||||
double fl (double);
|
||||
double multIdenty();
|
||||
double addIdenty();
|
||||
double withEvidence();
|
||||
double noEvidence();
|
||||
double one();
|
||||
double zero();
|
||||
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||
|
||||
template <typename T> bool contains (const vector<T>&, const T&);
|
||||
|
||||
template <typename T> bool contains (const set<T>&, const T&);
|
||||
|
||||
template <typename K, typename V> bool contains (
|
||||
const unordered_map<K, V>&, const K&);
|
||||
|
||||
template <typename T> std::string toString (const T&);
|
||||
|
||||
void toLog (Params&);
|
||||
|
||||
void fromLog (Params&);
|
||||
|
||||
double logSum (double, double);
|
||||
|
||||
void multiply (Params&, const Params&);
|
||||
|
||||
void multiply (Params&, const Params&, unsigned);
|
||||
|
||||
void add (Params&, const Params&);
|
||||
|
||||
void add (Params&, const Params&, unsigned);
|
||||
|
||||
double factorial (double);
|
||||
|
||||
unsigned nrCombinations (unsigned, unsigned);
|
||||
|
||||
unsigned expectedSize (const Ranges&);
|
||||
|
||||
unsigned getNumberOfDigits (int);
|
||||
|
||||
bool isInteger (const string&);
|
||||
|
||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||
|
||||
vector<string> getJointStateStrings (const VarNodes&);
|
||||
|
||||
void printHeader (string, std::ostream& os = std::cout);
|
||||
|
||||
void printSubHeader (string, std::ostream& os = std::cout);
|
||||
|
||||
void printAsteriskLine (std::ostream& os = std::cout);
|
||||
|
||||
void printDashedLine (std::ostream& os = std::cout);
|
||||
|
||||
unsigned maxUnsigned (void);
|
||||
|
||||
};
|
||||
|
||||
|
||||
template <class T>
|
||||
std::string toString (const T& t)
|
||||
|
||||
template <typename T> void
|
||||
Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||
{
|
||||
v.insert (v.end(), elements.begin(), elements.end());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||
{
|
||||
for (unsigned i = 0; i < elements.size(); i++) {
|
||||
q.push (elements[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const vector<T>& v, const T& e)
|
||||
{
|
||||
return std::find (v.begin(), v.end(), e) != v.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const set<T>& s, const T& e)
|
||||
{
|
||||
return s.find (e) != s.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename K, typename V> bool
|
||||
Util::contains (
|
||||
const unordered_map<K, V>& m, const K& k)
|
||||
{
|
||||
return m.find (k) != m.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> std::string
|
||||
Util::toString (const T& t)
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << t;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
@ -62,28 +140,31 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
const double INF = -numeric_limits<double>::infinity();
|
||||
};
|
||||
|
||||
|
||||
inline void
|
||||
Util::logSum (double& x, double y)
|
||||
inline double
|
||||
Util::logSum (double x, double y)
|
||||
{
|
||||
x = log (exp (x) + exp (y)); return;
|
||||
return log (exp (x) + exp (y));
|
||||
assert (isfinite (x) && isfinite (y));
|
||||
// If one value is much smaller than the other, keep the larger value.
|
||||
if (x < (y - log (1e200))) {
|
||||
x = y;
|
||||
return;
|
||||
return y;
|
||||
}
|
||||
if (y < (x - log (1e200))) {
|
||||
return;
|
||||
return x;
|
||||
}
|
||||
double diff = x - y;
|
||||
assert (isfinite (diff) && isfinite (x) && isfinite (y));
|
||||
if (!isfinite (exp (diff))) { // difference is too large
|
||||
x = x > y ? x : y;
|
||||
} else { // otherwise return the sum.
|
||||
x = y + log (static_cast<double>(1.0) + exp (diff));
|
||||
if (!isfinite (exp (diff))) {
|
||||
// difference is too large
|
||||
return x > y ? x : y;
|
||||
}
|
||||
// otherwise return the sum.
|
||||
return y + log (static_cast<double>(1.0) + exp (diff));
|
||||
}
|
||||
|
||||
|
||||
@ -140,52 +221,87 @@ Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
||||
|
||||
|
||||
|
||||
inline double
|
||||
Util::tl (double v)
|
||||
inline unsigned
|
||||
Util::maxUnsigned (void)
|
||||
{
|
||||
return Globals::logDomain ? log(v) : v;
|
||||
return numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::fl (double v)
|
||||
{
|
||||
return Globals::logDomain ? exp(v) : v;
|
||||
}
|
||||
|
||||
|
||||
namespace LogAware {
|
||||
|
||||
inline double
|
||||
Util::multIdenty() {
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::addIdenty()
|
||||
{
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::withEvidence()
|
||||
one()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::noEvidence() {
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::one()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::zero() {
|
||||
zero() {
|
||||
return Globals::logDomain ? INF : 0.0 ;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
addIdenty()
|
||||
{
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
multIdenty()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
withEvidence()
|
||||
{
|
||||
return Globals::logDomain ? 0.0 : 1.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
noEvidence() {
|
||||
return Globals::logDomain ? INF : 0.0;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
tl (double v)
|
||||
{
|
||||
return Globals::logDomain ? log (v) : v;
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
fl (double v)
|
||||
{
|
||||
return Globals::logDomain ? exp (v) : v;
|
||||
}
|
||||
|
||||
|
||||
void normalize (Params&);
|
||||
|
||||
double getL1Distance (const Params&, const Params&);
|
||||
|
||||
double getMaxNorm (const Params&, const Params&);
|
||||
|
||||
double pow (double, unsigned);
|
||||
|
||||
double pow (double, double);
|
||||
|
||||
void pow (Params&, unsigned);
|
||||
|
||||
void pow (Params&, double);
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct NetInfo
|
||||
{
|
||||
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
|
||||
@ -224,11 +340,17 @@ class Statistics
|
||||
{
|
||||
public:
|
||||
static unsigned getSolvedNetworksCounting (void);
|
||||
|
||||
static void incrementPrimaryNetworksCounting (void);
|
||||
|
||||
static unsigned getPrimaryNetworksCounting (void);
|
||||
|
||||
static void updateStatistics (unsigned, bool, unsigned, double);
|
||||
|
||||
static void printStatistics (void);
|
||||
|
||||
static void writeStatisticsToFile (const char*);
|
||||
|
||||
static void updateCompressingStatistics (
|
||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
||||
|
||||
|
@ -56,7 +56,7 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
||||
introduceEvidence();
|
||||
chooseEliminationOrder (vids);
|
||||
processFactorList (vids);
|
||||
Params params = factorList_.back()->getParameters();
|
||||
Params params = factorList_.back()->params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
@ -98,7 +98,7 @@ VarElimSolver::introduceEvidence (void)
|
||||
varFactors_.find (varNodes[i]->varId())->second;
|
||||
for (unsigned j = 0; j < idxs.size(); j++) {
|
||||
Factor* factor = factorList_[idxs[j]];
|
||||
if (factor->nrVariables() == 1) {
|
||||
if (factor->nrArguments() == 1) {
|
||||
factorList_[idxs[j]] = 0;
|
||||
} else {
|
||||
factorList_[idxs[j]]->absorveEvidence (
|
||||
@ -121,8 +121,8 @@ VarElimSolver::chooseEliminationOrder (const VarIds& vids)
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
VarId vid = varNodes[i]->varId();
|
||||
if (std::find (vids.begin(), vids.end(), vid) == vids.end()
|
||||
&& !varNodes[i]->hasEvidence()) {
|
||||
if (Util::contains (vids, vid) == false &&
|
||||
varNodes[i]->hasEvidence() == false) {
|
||||
elimOrder_.push_back (vid);
|
||||
}
|
||||
}
|
||||
@ -154,7 +154,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
||||
}
|
||||
}
|
||||
|
||||
finalFactor->reorderVariables (unobservedVids);
|
||||
finalFactor->reorderArguments (unobservedVids);
|
||||
finalFactor->normalize();
|
||||
factorList_.push_back (finalFactor);
|
||||
}
|
||||
@ -179,10 +179,10 @@ VarElimSolver::eliminate (VarId elimVar)
|
||||
factorList_[idx] = 0;
|
||||
}
|
||||
}
|
||||
if (result != 0 && result->nrVariables() != 1) {
|
||||
if (result != 0 && result->nrArguments() != 1) {
|
||||
result->sumOut (vn->varId());
|
||||
factorList_.push_back (result);
|
||||
const VarIds& resultVarIds = result->getVarIds();
|
||||
const VarIds& resultVarIds = result->arguments();
|
||||
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
||||
vector<unsigned>& idxs =
|
||||
varFactors_.find (resultVarIds[i])->second;
|
||||
|
@ -16,18 +16,28 @@ class VarElimSolver : public Solver
|
||||
{
|
||||
public:
|
||||
VarElimSolver (const BayesNet&);
|
||||
|
||||
VarElimSolver (const FactorGraph&);
|
||||
|
||||
~VarElimSolver (void);
|
||||
void runSolver (void) { }
|
||||
Params getPosterioriOf (VarId);
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
void runSolver (void) { }
|
||||
|
||||
Params getPosterioriOf (VarId);
|
||||
|
||||
Params getJointDistributionOf (const VarIds&);
|
||||
|
||||
private:
|
||||
void createFactorList (void);
|
||||
|
||||
void introduceEvidence (void);
|
||||
|
||||
void chooseEliminationOrder (const VarIds&);
|
||||
|
||||
void processFactorList (const VarIds&);
|
||||
|
||||
void eliminate (VarId);
|
||||
|
||||
void printActiveFactors (void);
|
||||
|
||||
const BayesNet* bayesNet_;
|
||||
|
@ -40,8 +40,8 @@ VarNode::isValidState (int stateIndex)
|
||||
bool
|
||||
VarNode::isValidState (const string& stateName)
|
||||
{
|
||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
||||
return find (states.begin(), states.end(), stateName) != states.end();
|
||||
States states = GraphicalModel::getVarInformation (varId_).states;
|
||||
return Util::contains (states, stateName);
|
||||
}
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ VarNode::setEvidence (int ev)
|
||||
void
|
||||
VarNode::setEvidence (const string& ev)
|
||||
{
|
||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
||||
States states = GraphicalModel::getVarInformation (varId_).states;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
if (states[i] == ev) {
|
||||
evidence_ = i;
|
||||
@ -74,7 +74,7 @@ string
|
||||
VarNode::label (void) const
|
||||
{
|
||||
if (GraphicalModel::variablesHaveInformation()) {
|
||||
return GraphicalModel::getVariableInformation (varId_).label;
|
||||
return GraphicalModel::getVarInformation (varId_).label;
|
||||
}
|
||||
stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
@ -87,7 +87,7 @@ States
|
||||
VarNode::states (void) const
|
||||
{
|
||||
if (GraphicalModel::variablesHaveInformation()) {
|
||||
return GraphicalModel::getVariableInformation (varId_).states;
|
||||
return GraphicalModel::getVarInformation (varId_).states;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < nrStates_; i++) {
|
||||
|
@ -1,6 +1,10 @@
|
||||
#ifndef HORUS_VARNODE_H
|
||||
#define HORUS_VARNODE_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
@ -9,25 +13,28 @@ class VarNode
|
||||
{
|
||||
public:
|
||||
VarNode (const VarNode*);
|
||||
VarNode (VarId, unsigned, int = NO_EVIDENCE);
|
||||
virtual ~VarNode (void) {};
|
||||
|
||||
bool isValidState (int);
|
||||
bool isValidState (const string&);
|
||||
void setEvidence (int);
|
||||
void setEvidence (const string&);
|
||||
string label (void) const;
|
||||
States states (void) const;
|
||||
VarNode (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||
|
||||
unsigned varId (void) const { return varId_; }
|
||||
unsigned nrStates (void) const { return nrStates_; }
|
||||
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
virtual ~VarNode (void) { };
|
||||
|
||||
unsigned varId (void) const { return varId_; }
|
||||
|
||||
unsigned nrStates (void) const { return nrStates_; }
|
||||
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
|
||||
operator unsigned () const { return index_; }
|
||||
|
||||
bool hasEvidence (void) const
|
||||
{
|
||||
return evidence_ != Constants::NO_EVIDENCE;
|
||||
}
|
||||
|
||||
bool operator== (const VarNode& var) const
|
||||
{
|
||||
cout << "equal operator called" << endl;
|
||||
@ -42,11 +49,23 @@ class VarNode
|
||||
return varId_ != var.varId();
|
||||
}
|
||||
|
||||
bool isValidState (int);
|
||||
|
||||
bool isValidState (const string&);
|
||||
|
||||
void setEvidence (int);
|
||||
|
||||
void setEvidence (const string&);
|
||||
|
||||
string label (void) const;
|
||||
|
||||
States states (void) const;
|
||||
|
||||
private:
|
||||
VarId varId_;
|
||||
unsigned nrStates_;
|
||||
int evidence_;
|
||||
unsigned index_;
|
||||
VarId varId_;
|
||||
unsigned nrStates_;
|
||||
int evidence_;
|
||||
unsigned index_;
|
||||
|
||||
};
|
||||
|
||||
|
@ -13,8 +13,8 @@ function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
@ -22,7 +22,7 @@ fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
||||
clpbn_horus:set_horus_flag(use_logarithms, true),
|
||||
$extra_flag1, $extra_flag2,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
|
@ -13,8 +13,8 @@ function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
@ -22,7 +22,7 @@ fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
||||
clpbn_horus:set_horus_flag(use_logarithms, true),
|
||||
$extra_flag1, $extra_flag2,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
@ -37,6 +37,8 @@ function run_all_graphs
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_3 $1 town_3 $3 $4 $5
|
||||
return
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
|
@ -13,8 +13,8 @@ function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
@ -22,7 +22,7 @@ fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
||||
clpbn_horus:set_horus_flag(use_logarithms, true),
|
||||
$extra_flag1, $extra_flag2,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
|
29
packages/CLPBN/clpbn/bp/benchmarks/city/town_3.yap
Normal file
29
packages/CLPBN/clpbn/bp/benchmarks/city/town_3.yap
Normal file
@ -0,0 +1,29 @@
|
||||
:- source.
|
||||
:- style_check(all).
|
||||
:- yap_flag(unknown,error).
|
||||
:- yap_flag(write_strings,on).
|
||||
:- use_module(library(clpbn)).
|
||||
:- set_clpbn_flag(solver, bp).
|
||||
:- [-schema].
|
||||
|
||||
lives(_joe, nyc).
|
||||
|
||||
run_query(Guilty) :-
|
||||
guilty(joe, Guilty),
|
||||
witness(nyc, t),
|
||||
runall(X, ev(X)).
|
||||
|
||||
|
||||
runall(G, Wrapper) :-
|
||||
findall(G, Wrapper, L),
|
||||
execute_all(L).
|
||||
|
||||
|
||||
execute_all([]).
|
||||
execute_all(G.L) :-
|
||||
call(G),
|
||||
execute_all(L).
|
||||
|
||||
|
||||
ev(descn(p2, t)).
|
||||
ev(descn(p3, t)).
|
@ -9,7 +9,7 @@ OUT_FILE_NAME=results.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
# yap -g "['../../../../examples/School/school_32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
||||
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
||||
|
||||
function run_solver
|
||||
{
|
||||
@ -17,11 +17,11 @@ if [ $2 = bp ]
|
||||
then
|
||||
if [ $4 = ve ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(elim_heuristic,$5\)
|
||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_horus:set_horus_flag\(elim_heuristic,$5\)
|
||||
else
|
||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||
fi
|
||||
else
|
||||
extra_flag1=true
|
||||
@ -29,7 +29,7 @@ else
|
||||
fi
|
||||
/usr/bin/time -o "$OUT_FILE_NAME" -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF &>> "ignore.$OUT_FILE_NAME"
|
||||
:- [pos:train].
|
||||
:- ['../../../../examples/School/school_32'].
|
||||
:- ['../../../../examples/School/sch32'].
|
||||
:- use_module(library(clpbn/learning/em)).
|
||||
:- use_module(library(clpbn/bp)).
|
||||
[$1].
|
||||
@ -57,12 +57,11 @@ function run_all_graphs
|
||||
#run_solver missing50 $1 missing50 $3 $4 $5
|
||||
}
|
||||
|
||||
run_solver missing5 ve missing5 $3 $4 $5
|
||||
exit
|
||||
run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||
|
||||
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||
#run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
||||
#run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
||||
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
||||
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||
exit
|
||||
|
||||
|
||||
|
18
packages/CLPBN/clpbn/bp/examples/allopstest.yap
Normal file
18
packages/CLPBN/clpbn/bp/examples/allopstest.yap
Normal file
@ -0,0 +1,18 @@
|
||||
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- set_clpbn_flag(solver,fove).
|
||||
|
||||
|
||||
c(x1,y1,z1).
|
||||
c(x1,y1,z2).
|
||||
c(x2,y2,z1).
|
||||
c(x3,y2,z1).
|
||||
|
||||
bayes p(X)::[t,f] ; [0.2, 0.4] ; [c(X,_,_)].
|
||||
|
||||
bayes q(Y)::[t,f] ; [0.5, 0.6] ; [c(_,Y,_)].
|
||||
|
||||
bayes s(Z)::[t,f] , p(X) , q(Y) ; [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] ; [c(X,Y,Z)].
|
||||
|
||||
% bayes series::[t,f] , attends(X) ; [0.5, 0.6, 0.7, 0.8] ; [c(X,_)].
|
78022
packages/CLPBN/clpbn/bp/examples/city_test.fg
Normal file
78022
packages/CLPBN/clpbn/bp/examples/city_test.fg
Normal file
File diff suppressed because it is too large
Load Diff
@ -25,6 +25,8 @@ markov attends(P)::[t,f] , hot(W)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P,W)].
|
||||
|
||||
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
|
||||
|
||||
:- clpbn_horus:set_horus_flag(use_logarithms,true).
|
||||
|
||||
?- series(X).
|
||||
|
||||
|
||||
|
20
packages/CLPBN/clpbn/bp/examples/fail.yap
Normal file
20
packages/CLPBN/clpbn/bp/examples/fail.yap
Normal file
@ -0,0 +1,20 @@
|
||||
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- set_clpbn_flag(solver,ve).
|
||||
%:- set_clpbn_flag(solver,fove).
|
||||
|
||||
|
||||
t(ann).
|
||||
t(dave).
|
||||
|
||||
% p(ann,t).
|
||||
|
||||
bayes p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
|
||||
|
||||
% use standard Prolog queries: provide evidence first.
|
||||
|
||||
?- p(ann,t), p(ann,X).
|
||||
|
||||
% ?- p(ann,X).
|
||||
|
32
packages/CLPBN/clpbn/bp/examples/workshop_attrs.yap
Normal file
32
packages/CLPBN/clpbn/bp/examples/workshop_attrs.yap
Normal file
@ -0,0 +1,32 @@
|
||||
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_clpbn_flag(solver,ve).
|
||||
%:- set_clpbn_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
|
||||
:- set_clpbn_flag(solver,fove).
|
||||
|
||||
c(p1).
|
||||
c(p2).
|
||||
c(p3).
|
||||
c(p4).
|
||||
c(p5).
|
||||
|
||||
|
||||
markov attends(P)::[t,f] , attr1::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||
|
||||
markov attends(P)::[t,f] , attr2::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||
|
||||
markov attends(P)::[t,f] , attr3::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||
|
||||
markov attends(P)::[t,f] , attr4::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||
|
||||
markov attends(P)::[t,f] , attr5::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||
|
||||
markov attends(P)::[t,f] , attr6::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||
|
||||
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P)].
|
||||
|
||||
%:- clpbn_horus:set_horus_flag(use_logarithms,true).
|
||||
|
||||
?- series(X).
|
||||
|
@ -1,9 +1,9 @@
|
||||
|
||||
/************************************************
|
||||
/*******************************************************
|
||||
|
||||
(GC) First Order Variable Elimination Interface
|
||||
First Order Variable Elimination Interface
|
||||
|
||||
**************************************************/
|
||||
********************************************************/
|
||||
|
||||
:- module(clpbn_fove,
|
||||
[fove/3,
|
||||
@ -23,75 +23,88 @@
|
||||
|
||||
|
||||
:- use_module(library(pfl),
|
||||
[factor/6,
|
||||
skolem/2
|
||||
[factor/5,
|
||||
skolem/2,
|
||||
get_pfl_parameters/2
|
||||
]).
|
||||
|
||||
|
||||
:- use_module(horus).
|
||||
:- use_module(horus,
|
||||
[create_lifted_network/3,
|
||||
set_parfactors_params/2,
|
||||
run_lifted_solver/3,
|
||||
free_parfactors/1
|
||||
]).
|
||||
|
||||
:- set_horus_flag(use_logarithms, false).
|
||||
%:- set_horus_flag(use_logarithms, true).
|
||||
|
||||
fove([[]], _, _) :- !.
|
||||
fove([QueryVars], AllVars, Output) :-
|
||||
writeln(queryVars:QueryVars),
|
||||
writeln(allVars:AllVars),
|
||||
init_fove_solver(_, AllVars, _, ParfactorGraph),
|
||||
run_fove_solver([QueryVars], LPs, ParfactorGraph),
|
||||
finalize_fove_solver(ParfactorGraph),
|
||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||
init_fove_solver(_, AllVars, _, ParfactorList),
|
||||
run_fove_solver([QueryVars], LPs, ParfactorList),
|
||||
finalize_fove_solver(ParfactorList),
|
||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||
|
||||
init_fove_solver(_, AllAttVars, _, fove(ParfactorGraph, DistIds)) :-
|
||||
writeln(allattvars:AllAttVars), writeln(''),
|
||||
get_parfactors(Parfactors),
|
||||
get_dist_ids(Parfactors, DistIds0),
|
||||
sort(DistIds0, DistIds),
|
||||
get_observed_vars(AllAttVars, ObservedVars),
|
||||
|
||||
init_fove_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
|
||||
get_parfactors(Parfactors),
|
||||
get_dist_ids(Parfactors, DistIds0),
|
||||
sort(DistIds0, DistIds),
|
||||
get_observed_vars(AllAttVars, ObservedVars),
|
||||
writeln(factors:Parfactors:'\n'),
|
||||
writeln(evidence:ObservedVars:'\n'),
|
||||
create_lifted_network(Parfactors,ObservedVars,ParfactorGraph).
|
||||
writeln(evidence:ObservedVars:'\n'),
|
||||
create_lifted_network(Parfactors,ObservedVars,ParfactorList).
|
||||
|
||||
|
||||
:- table get_parfactors/1.
|
||||
|
||||
%
|
||||
% enumerate all parfactors and enumerate their domain as tuples.
|
||||
%
|
||||
% output is list of pf(
|
||||
% Id: an unique number
|
||||
% Ks: a list of keys, also known as the pf formula [a(X),b(Y),c(X,Y)]
|
||||
% Vs: the list of free variables [X,Y]
|
||||
% Phi: the table following usual CLP(BN) convention
|
||||
% Tuples: tuples with all ground bindings for variables in Vs, of the form [fv(x,y)]
|
||||
% Id: an unique number
|
||||
% Ks: a list of keys, also known as the pf formula [a(X),b(Y),c(X,Y)]
|
||||
% Vs: the list of free variables [X,Y]
|
||||
% Phi: the table following usual CLP(BN) convention
|
||||
% Tuples: ground bindings for variables in Vs, of the form [fv(x,y)]
|
||||
%
|
||||
get_parfactors(Factors) :-
|
||||
findall(F, is_factor(F), Factors).
|
||||
findall(F, is_factor(F), Factors).
|
||||
|
||||
|
||||
is_factor(pf(Id, Ks, Rs, Phi, Tuples)) :-
|
||||
<<<<<<< HEAD
|
||||
factor(_Type, Id, Ks, Vs, Table, Constraints),
|
||||
get_ranges(Ks,Rs),
|
||||
Table \= avg,
|
||||
gen_table(Table, Phi),
|
||||
all_tuples(Constraints, Vs, Tuples).
|
||||
=======
|
||||
factor(Id, Ks, Vs, Table, Constraints),
|
||||
get_ranges(Ks,Rs),
|
||||
Table \= avg,
|
||||
gen_table(Table, Phi),
|
||||
all_tuples(Constraints, Vs, Tuples).
|
||||
>>>>>>> 911b241ad663a911af52babcf5d702c5239b4350
|
||||
|
||||
|
||||
get_ranges([],[]).
|
||||
get_ranges(K.Ks, Range.Rs) :- !,
|
||||
skolem(K,Domain),
|
||||
length(Domain,Range),
|
||||
get_ranges(Ks, Rs).
|
||||
skolem(K,Domain),
|
||||
length(Domain,Range),
|
||||
get_ranges(Ks, Rs).
|
||||
|
||||
|
||||
gen_table(Table, Phi) :-
|
||||
( is_list(Table)
|
||||
->
|
||||
Phi = Table
|
||||
;
|
||||
call(user:Table, Phi)
|
||||
).
|
||||
( is_list(Table)
|
||||
->
|
||||
Phi = Table
|
||||
;
|
||||
call(user:Table, Phi)
|
||||
).
|
||||
|
||||
|
||||
all_tuples(Constraints, Tuple, Tuples) :-
|
||||
setof(Tuple, Constraints^run(Constraints), Tuples).
|
||||
setof(Tuple, Constraints^run(Constraints), Tuples).
|
||||
|
||||
|
||||
run([]).
|
||||
@ -107,45 +120,41 @@ get_dist_ids(pf(Id, _, _, _, _).Parfactors, Id.DistIds) :-
|
||||
|
||||
get_observed_vars([], []).
|
||||
get_observed_vars(V.AllAttVars, [K:E|ObservedVars]) :-
|
||||
writeln('checking ev for':V),
|
||||
clpbn:get_atts(V,[key(K)]),
|
||||
( clpbn:get_atts(V,[evidence(E)]) ; pfl:evidence(K,E) ), !,
|
||||
writeln('evidence!!!':K:E),
|
||||
get_observed_vars(AllAttVars, ObservedVars).
|
||||
clpbn:get_atts(V,[key(K)]),
|
||||
( clpbn:get_atts(V,[evidence(E)]) ; pfl:evidence(K,E) ), !,
|
||||
get_observed_vars(AllAttVars, ObservedVars).
|
||||
get_observed_vars(V.AllAttVars, ObservedVars) :-
|
||||
clpbn:get_atts(V,[key(K)]), !,
|
||||
writeln('no evidence for':V:K),
|
||||
get_observed_vars(AllAttVars, ObservedVars).
|
||||
clpbn:get_atts(V,[key(K)]), !,
|
||||
get_observed_vars(AllAttVars, ObservedVars).
|
||||
|
||||
|
||||
get_query_vars([], []).
|
||||
get_query_vars(E1.L1, E2.L2) :-
|
||||
get_query_vars_2(E1,E2),
|
||||
get_query_vars(L1, L2).
|
||||
get_query_vars(L1, L2).
|
||||
|
||||
|
||||
get_query_vars_2([], []).
|
||||
get_query_vars_2(V.AttVars, [RV|RVs]) :-
|
||||
clpbn:get_atts(V,[key(RV)]), !,
|
||||
get_query_vars_2(AttVars, RVs).
|
||||
clpbn:get_atts(V,[key(RV)]), !,
|
||||
get_query_vars_2(AttVars, RVs).
|
||||
|
||||
|
||||
get_dists_parameters([], []).
|
||||
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
||||
get_pfl_parameters(Id, Params),
|
||||
get_dists_parameters(Ids, DistsInfo).
|
||||
get_pfl_parameters(Id, Params),
|
||||
get_dists_parameters(Ids, DistsInfo).
|
||||
|
||||
|
||||
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorGraph, DistIds)) :-
|
||||
% TODO set_parfactor_graph_params
|
||||
writeln(distIds:DistIds),
|
||||
%get_dists_parameters(DistIds, DistParams),
|
||||
%writeln(distParams:DistParams),
|
||||
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
writeln(distParams:DistsParams),
|
||||
set_parfactors_params(ParfactorList, DistsParams),
|
||||
get_query_vars(QueryVarsAtts, QueryVars),
|
||||
writeln(queryVars:QueryVars),
|
||||
run_lifted_solver(ParfactorGraph, QueryVars, Solutions).
|
||||
writeln(queryVars:QueryVars), writeln(''),
|
||||
run_lifted_solver(ParfactorList, QueryVars, Solutions).
|
||||
|
||||
|
||||
finalize_fove_solver(fove(ParfactorGraph, _)) :-
|
||||
free_parfactor_graph(ParfactorGraph).
|
||||
finalize_fove_solver(fove(ParfactorList, _)) :-
|
||||
free_parfactors(ParfactorList).
|
||||
|
||||
|
@ -1,17 +1,23 @@
|
||||
|
||||
/*******************************************************
|
||||
|
||||
Interface with C++
|
||||
|
||||
********************************************************/
|
||||
|
||||
:- module(clpbn_horus,
|
||||
[
|
||||
create_lifted_network/3,
|
||||
create_ground_network/2,
|
||||
set_parfactor_graph_params/2,
|
||||
set_bayes_net_params/2,
|
||||
run_lifted_solver/3,
|
||||
run_other_solvers/3,
|
||||
set_extra_vars_info/2,
|
||||
set_horus_flag/2,
|
||||
free_bayesian_network/1,
|
||||
free_parfactor_graph/1
|
||||
]).
|
||||
[create_lifted_network/3,
|
||||
create_ground_network/2,
|
||||
set_parfactors_params/2,
|
||||
set_bayes_net_params/2,
|
||||
run_lifted_solver/3,
|
||||
run_ground_solver/3,
|
||||
set_extra_vars_info/2,
|
||||
set_horus_flag/2,
|
||||
free_parfactors/1,
|
||||
free_bayesian_network/1
|
||||
]).
|
||||
|
||||
|
||||
patch_things_up :-
|
||||
assert_static(clpbn_horus:set_horus_flag(_,_)).
|
||||
@ -23,3 +29,24 @@ warning :-
|
||||
|
||||
|
||||
|
||||
%:- set_horus_flag(inf_alg, ve).
|
||||
:- set_horus_flag(inf_alg, bn_bp).
|
||||
%:- set_horus_flag(inf_alg, fg_bp).
|
||||
%: -set_horus_flag(inf_alg, cbp).
|
||||
|
||||
:- set_horus_flag(schedule, seq_fixed).
|
||||
%:- set_horus_flag(schedule, seq_random).
|
||||
%:- set_horus_flag(schedule, parallel).
|
||||
%:- set_horus_flag(schedule, max_residual).
|
||||
|
||||
:- set_horus_flag(accuracy, 0.0001).
|
||||
|
||||
:- set_horus_flag(max_iter, 1000).
|
||||
|
||||
:- set_horus_flag(order_factor_variables, false).
|
||||
%:- set_horus_flag(order_factor_variables, true).
|
||||
|
||||
|
||||
:- set_horus_flag(use_logarithms, false).
|
||||
% :- set_horus_flag(use_logarithms, true).
|
||||
|
||||
|
Reference in New Issue
Block a user