Merge branch 'master' of https://github.com/tacgomes/yap6.3
Conflicts: packages/CLPBN/clpbn/bp.yap packages/CLPBN/clpbn/fove.yap packages/CLPBN/clpbn/horus.yap
This commit is contained in:
commit
6ccd458ea5
@ -1,9 +1,9 @@
|
|||||||
|
|
||||||
/************************************************
|
/*******************************************************
|
||||||
|
|
||||||
Belief Propagation in CLP(BN)
|
Belief Propagation and Variable Elimination Interface
|
||||||
|
|
||||||
**************************************************/
|
********************************************************/
|
||||||
|
|
||||||
:- module(clpbn_bp,
|
:- module(clpbn_bp,
|
||||||
[bp/3,
|
[bp/3,
|
||||||
@ -33,8 +33,6 @@
|
|||||||
:- use_module(library(clpbn/horus)).
|
:- use_module(library(clpbn/horus)).
|
||||||
|
|
||||||
:- use_module(library(atts)).
|
:- use_module(library(atts)).
|
||||||
:- use_module(library(lists)).
|
|
||||||
:- use_module(library(charsio)).
|
|
||||||
|
|
||||||
:- attribute id/1.
|
:- attribute id/1.
|
||||||
|
|
||||||
@ -51,15 +49,21 @@
|
|||||||
|
|
||||||
:- set_horus_flag(accuracy, 0.0001).
|
:- set_horus_flag(accuracy, 0.0001).
|
||||||
|
|
||||||
:- set_horus_flag(max_iter, 1000).
|
:- use_module(library(charsio),
|
||||||
|
[term_to_atom/2]).
|
||||||
|
|
||||||
:- set_horus_flag(use_logarithms, false).
|
|
||||||
%:- set_horus_flag(use_logarithms, true).
|
|
||||||
|
|
||||||
:- set_horus_flag(order_factor_variables, false).
|
:- use_module(horus,
|
||||||
%:- set_horus_flag(order_factor_variables, true).
|
[create_ground_network/2,
|
||||||
|
set_bayes_net_params/2,
|
||||||
|
run_ground_solver/3,
|
||||||
|
set_extra_vars_info/2,
|
||||||
|
free_bayesian_network/1
|
||||||
|
]).
|
||||||
|
|
||||||
|
|
||||||
|
:- attribute id/1.
|
||||||
|
|
||||||
|
|
||||||
bp([[]],_,_) :- !.
|
bp([[]],_,_) :- !.
|
||||||
bp([QueryVars], AllVars, Output) :-
|
bp([QueryVars], AllVars, Output) :-
|
||||||
@ -70,23 +74,25 @@ bp([QueryVars], AllVars, Output) :-
|
|||||||
|
|
||||||
|
|
||||||
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||||
|
%writeln('init_bp_solver'),
|
||||||
check_for_agg_vars(AllVars0, AllVars),
|
check_for_agg_vars(AllVars0, AllVars),
|
||||||
writeln('clpbn_vars:'),
|
%writeln('clpbn_vars:'), print_clpbn_vars(AllVars),
|
||||||
print_clpbn_vars(AllVars),
|
|
||||||
assign_ids(AllVars, 0),
|
assign_ids(AllVars, 0),
|
||||||
get_vars_info(AllVars, VarsInfo, DistIds0),
|
get_vars_info(AllVars, VarsInfo, DistIds0),
|
||||||
sort(DistIds0, DistIds),
|
sort(DistIds0, DistIds),
|
||||||
create_ground_network(VarsInfo, BayesNet).
|
create_ground_network(VarsInfo, BayesNet),
|
||||||
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
||||||
%set_extra_vars_info(BayesNet, ExtraVarsInfo).
|
%set_extra_vars_info(BayesNet, ExtraVarsInfo),
|
||||||
|
%writeln(extravarsinfo:ExtraVarsInfo),
|
||||||
|
true.
|
||||||
|
|
||||||
|
|
||||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||||
|
%writeln('-> run_bp_solver'),
|
||||||
get_dists_parameters(DistIds, DistsParams),
|
get_dists_parameters(DistIds, DistsParams),
|
||||||
set_bayes_net_params(Network, DistsParams),
|
set_bayes_net_params(Network, DistsParams),
|
||||||
flatten_1_element_sublists(QueryVars, QueryVars1),
|
vars_to_ids(QueryVars, QueryVarsIds),
|
||||||
vars_to_ids(QueryVars1, QueryVarsIds),
|
run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||||
run_other_solvers(Network, QueryVarsIds, Solutions).
|
|
||||||
|
|
||||||
|
|
||||||
finalize_bp_solver(bp(Network, _)) :-
|
finalize_bp_solver(bp(Network, _)) :-
|
||||||
@ -130,7 +136,7 @@ vars_to_ids([V|Vars], [VarId|Ids]) :-
|
|||||||
get_extra_vars_info([], []).
|
get_extra_vars_info([], []).
|
||||||
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
|
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
|
||||||
get_atts(V, [id(VarId)]), !,
|
get_atts(V, [id(VarId)]), !,
|
||||||
clpbn:get_atts(V, [key(Key),dist(DistId, _)]),
|
clpbn:get_atts(V, [key(Key), dist(DistId, _)]),
|
||||||
term_to_atom(Key, Label),
|
term_to_atom(Key, Label),
|
||||||
get_dist_domain(DistId, Domain0),
|
get_dist_domain(DistId, Domain0),
|
||||||
numbers_to_atoms(Domain0, Domain),
|
numbers_to_atoms(Domain0, Domain),
|
||||||
@ -154,13 +160,6 @@ numbers_to_atoms([Number|L0], [Atom|L]) :-
|
|||||||
numbers_to_atoms(L0, L).
|
numbers_to_atoms(L0, L).
|
||||||
|
|
||||||
|
|
||||||
flatten_1_element_sublists([],[]).
|
|
||||||
flatten_1_element_sublists([[H|[]]|T],[H|R]) :- !,
|
|
||||||
flatten_1_element_sublists(T,R).
|
|
||||||
flatten_1_element_sublists([H|T],[H|R]) :-
|
|
||||||
flatten_1_element_sublists(T,R).
|
|
||||||
|
|
||||||
|
|
||||||
print_clpbn_vars(Var.AllVars) :-
|
print_clpbn_vars(Var.AllVars) :-
|
||||||
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
|
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
|
||||||
parents_to_keys(Parents, ParentKeys),
|
parents_to_keys(Parents, ParentKeys),
|
||||||
|
@ -88,15 +88,22 @@ BayesNet::readFromBifFormat (const char* fileName)
|
|||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
params = reorderParameters (params, node->nrStates());
|
params = reorderParameters (params, node->nrStates());
|
||||||
Distribution* dist = new Distribution (params);
|
|
||||||
node->setDistribution (dist);
|
|
||||||
addDistribution (dist);
|
|
||||||
}
|
|
||||||
|
|
||||||
setIndexes();
|
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
distributionsToLogs();
|
Util::toLog (params);
|
||||||
}
|
}
|
||||||
|
node->setParams (params);
|
||||||
|
}
|
||||||
|
setIndexes();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
BayesNode*
|
||||||
|
BayesNet::addNode (BayesNode* n)
|
||||||
|
{
|
||||||
|
varMap_.insert (make_pair (n->varId(), nodes_.size()));
|
||||||
|
nodes_.push_back (n);
|
||||||
|
return nodes_.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -114,15 +121,6 @@ BayesNet::addNode (string label, const States& states)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
|
||||||
BayesNet::addNode (VarId vid, unsigned dsize, int evidence, Distribution* dist)
|
|
||||||
{
|
|
||||||
varMap_.insert (make_pair (vid, nodes_.size()));
|
|
||||||
nodes_.push_back (new BayesNode (vid, dsize, evidence, dist));
|
|
||||||
return nodes_.back();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
BayesNode*
|
||||||
BayesNet::getBayesNode (VarId vid) const
|
BayesNet::getBayesNode (VarId vid) const
|
||||||
@ -176,29 +174,6 @@ BayesNet::getVariableNodes (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::addDistribution (Distribution* dist)
|
|
||||||
{
|
|
||||||
dists_.push_back (dist);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Distribution*
|
|
||||||
BayesNet::getDistribution (unsigned distId) const
|
|
||||||
{
|
|
||||||
Distribution* dist = 0;
|
|
||||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
|
||||||
if (dists_[i]->id == (int) distId) {
|
|
||||||
dist = dists_[i];
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const BnNodeSet&
|
const BnNodeSet&
|
||||||
BayesNet::getBayesNodes (void) const
|
BayesNet::getBayesNodes (void) const
|
||||||
{
|
{
|
||||||
@ -299,7 +274,7 @@ BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
|
|||||||
/*
|
/*
|
||||||
cout << "\t\ttop\tbottom" << endl;
|
cout << "\t\ttop\tbottom" << endl;
|
||||||
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
|
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
|
||||||
cout << "----------------------------------------------------------" ;
|
Util::printDashedLine();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
for (unsigned i = 0; i < states.size(); i++) {
|
||||||
cout << nodes_[i]->label() << ":\t\t" ;
|
cout << nodes_[i]->label() << ":\t\t" ;
|
||||||
@ -350,10 +325,8 @@ BayesNet::constructGraph (BayesNet* bn,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
|
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
|
||||||
BayesNode* mrnNode = bn->addNode (nodes_[i]->varId(),
|
BayesNode* mrnNode = new BayesNode (nodes_[i]);
|
||||||
nodes_[i]->nrStates(),
|
bn->addNode (mrnNode);
|
||||||
nodes_[i]->getEvidence(),
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
mrnNodes.push_back (mrnNode);
|
mrnNodes.push_back (mrnNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -388,26 +361,6 @@ BayesNet::setIndexes (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::distributionsToLogs (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
|
||||||
Util::toLog (dists_[i]->params);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::freeDistributions (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
|
||||||
delete dists_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::printGraphicalModel (void) const
|
BayesNet::printGraphicalModel (void) const
|
||||||
{
|
{
|
||||||
@ -504,8 +457,8 @@ BayesNet::exportToBifFormat (const char* fileName) const
|
|||||||
out << "\t<GIVEN>" << parents[j]->label();
|
out << "\t<GIVEN>" << parents[j]->label();
|
||||||
out << "</GIVEN>" << endl;
|
out << "</GIVEN>" << endl;
|
||||||
}
|
}
|
||||||
Params params = revertParameterReorder (nodes_[i]->getParameters(),
|
Params params = revertParameterReorder (
|
||||||
nodes_[i]->nrStates());
|
nodes_[i]->params(), nodes_[i]->nrStates());
|
||||||
out << "\t<TABLE>" ;
|
out << "\t<TABLE>" ;
|
||||||
for (unsigned j = 0; j < params.size(); j++) {
|
for (unsigned j = 0; j < params.size(); j++) {
|
||||||
out << " " << params[j];
|
out << " " << params[j];
|
||||||
|
@ -13,16 +13,11 @@
|
|||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class Distribution;
|
|
||||||
|
|
||||||
struct ScheduleInfo
|
struct ScheduleInfo
|
||||||
{
|
{
|
||||||
ScheduleInfo (BayesNode* n, bool vfp, bool vfc)
|
ScheduleInfo (BayesNode* n, bool vfp, bool vfc) :
|
||||||
{
|
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||||
node = n;
|
|
||||||
visitedFromParent = vfp;
|
|
||||||
visitedFromChild = vfc;
|
|
||||||
}
|
|
||||||
BayesNode* node;
|
BayesNode* node;
|
||||||
bool visitedFromParent;
|
bool visitedFromParent;
|
||||||
bool visitedFromChild;
|
bool visitedFromChild;
|
||||||
@ -31,67 +26,81 @@ struct ScheduleInfo
|
|||||||
|
|
||||||
struct StateInfo
|
struct StateInfo
|
||||||
{
|
{
|
||||||
StateInfo (void)
|
StateInfo (void) : visited(false), markedOnTop(false),
|
||||||
{
|
markedOnBottom(false) { }
|
||||||
visited = true;
|
|
||||||
markedOnTop = false;
|
|
||||||
markedOnBottom = false;
|
|
||||||
}
|
|
||||||
bool visited;
|
bool visited;
|
||||||
bool markedOnTop;
|
bool markedOnTop;
|
||||||
bool markedOnBottom;
|
bool markedOnBottom;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef vector<Distribution*> DistSet;
|
|
||||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
||||||
|
|
||||||
|
|
||||||
class BayesNet : public GraphicalModel
|
class BayesNet : public GraphicalModel
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BayesNet (void) {};
|
BayesNet (void) { };
|
||||||
|
|
||||||
~BayesNet (void);
|
~BayesNet (void);
|
||||||
|
|
||||||
void readFromBifFormat (const char*);
|
void readFromBifFormat (const char*);
|
||||||
|
|
||||||
|
BayesNode* addNode (BayesNode*);
|
||||||
|
|
||||||
BayesNode* addNode (string, const States&);
|
BayesNode* addNode (string, const States&);
|
||||||
// BayesNode* addNode (VarId, unsigned, int, BnNodeSet&, Distribution*);
|
|
||||||
BayesNode* addNode (VarId, unsigned, int, Distribution*);
|
|
||||||
BayesNode* getBayesNode (VarId) const;
|
BayesNode* getBayesNode (VarId) const;
|
||||||
|
|
||||||
BayesNode* getBayesNode (string) const;
|
BayesNode* getBayesNode (string) const;
|
||||||
|
|
||||||
VarNode* getVariableNode (VarId) const;
|
VarNode* getVariableNode (VarId) const;
|
||||||
|
|
||||||
VarNodes getVariableNodes (void) const;
|
VarNodes getVariableNodes (void) const;
|
||||||
void addDistribution (Distribution*);
|
|
||||||
Distribution* getDistribution (unsigned) const;
|
|
||||||
const BnNodeSet& getBayesNodes (void) const;
|
const BnNodeSet& getBayesNodes (void) const;
|
||||||
|
|
||||||
unsigned nrNodes (void) const;
|
unsigned nrNodes (void) const;
|
||||||
|
|
||||||
BnNodeSet getRootNodes (void) const;
|
BnNodeSet getRootNodes (void) const;
|
||||||
|
|
||||||
BnNodeSet getLeafNodes (void) const;
|
BnNodeSet getLeafNodes (void) const;
|
||||||
|
|
||||||
BayesNet* getMinimalRequesiteNetwork (VarId) const;
|
BayesNet* getMinimalRequesiteNetwork (VarId) const;
|
||||||
|
|
||||||
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
|
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
|
||||||
void constructGraph (
|
|
||||||
BayesNet*, const vector<StateInfo*>&) const;
|
void constructGraph (BayesNet*, const vector<StateInfo*>&) const;
|
||||||
|
|
||||||
bool isPolyTree (void) const;
|
bool isPolyTree (void) const;
|
||||||
|
|
||||||
void setIndexes (void);
|
void setIndexes (void);
|
||||||
void distributionsToLogs (void);
|
|
||||||
void freeDistributions (void);
|
|
||||||
void printGraphicalModel (void) const;
|
void printGraphicalModel (void) const;
|
||||||
|
|
||||||
void exportToGraphViz (const char*, bool = true,
|
void exportToGraphViz (const char*, bool = true,
|
||||||
const VarIds& = VarIds()) const;
|
const VarIds& = VarIds()) const;
|
||||||
|
|
||||||
void exportToBifFormat (const char*) const;
|
void exportToBifFormat (const char*) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BayesNet);
|
DISALLOW_COPY_AND_ASSIGN (BayesNet);
|
||||||
|
|
||||||
bool containsUndirectedCycle (void) const;
|
bool containsUndirectedCycle (void) const;
|
||||||
|
|
||||||
bool containsUndirectedCycle (int, int, vector<bool>&)const;
|
bool containsUndirectedCycle (int, int, vector<bool>&)const;
|
||||||
|
|
||||||
vector<int> getAdjacentNodes (int) const;
|
vector<int> getAdjacentNodes (int) const;
|
||||||
|
|
||||||
Params reorderParameters (const Params&, unsigned) const;
|
Params reorderParameters (const Params&, unsigned) const;
|
||||||
|
|
||||||
Params revertParameterReorder (const Params&, unsigned) const;
|
Params revertParameterReorder (const Params&, unsigned) const;
|
||||||
|
|
||||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
void scheduleParents (const BayesNode*, Scheduling&) const;
|
||||||
|
|
||||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
||||||
|
|
||||||
BnNodeSet nodes_;
|
BnNodeSet nodes_;
|
||||||
DistSet dists_;
|
|
||||||
|
|
||||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||||
IndexMap varMap_;
|
IndexMap varMap_;
|
||||||
|
@ -8,29 +8,10 @@
|
|||||||
#include "BayesNode.h"
|
#include "BayesNode.h"
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (VarId vid,
|
void
|
||||||
unsigned dsize,
|
BayesNode::setParams (const Params& params)
|
||||||
int evidence,
|
|
||||||
Distribution* dist)
|
|
||||||
: VarNode (vid, dsize, evidence)
|
|
||||||
{
|
{
|
||||||
dist_ = dist;
|
params_ = params;
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (VarId vid,
|
|
||||||
unsigned dsize,
|
|
||||||
int evidence,
|
|
||||||
const BnNodeSet& parents,
|
|
||||||
Distribution* dist)
|
|
||||||
: VarNode (vid, dsize, evidence)
|
|
||||||
{
|
|
||||||
parents_ = parents;
|
|
||||||
dist_ = dist;
|
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
|
||||||
parents[i]->addChild (this);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -54,31 +35,6 @@ BayesNode::addChild (BayesNode* node)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNode::setDistribution (Distribution* dist)
|
|
||||||
{
|
|
||||||
assert (dist);
|
|
||||||
dist_ = dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Distribution*
|
|
||||||
BayesNode::getDistribution (void)
|
|
||||||
{
|
|
||||||
return dist_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const Params&
|
|
||||||
BayesNode::getParameters (void)
|
|
||||||
{
|
|
||||||
return dist_->params;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BayesNode::getRow (int rowIndex) const
|
BayesNode::getRow (int rowIndex) const
|
||||||
{
|
{
|
||||||
@ -86,7 +42,7 @@ BayesNode::getRow (int rowIndex) const
|
|||||||
int offset = rowSize * rowIndex;
|
int offset = rowSize * rowIndex;
|
||||||
Params row (rowSize);
|
Params row (rowSize);
|
||||||
for (int i = 0; i < rowSize; i++) {
|
for (int i = 0; i < rowSize; i++) {
|
||||||
row[i] = dist_->params[offset + i] ;
|
row[i] = params_[offset + i] ;
|
||||||
}
|
}
|
||||||
return row;
|
return row;
|
||||||
}
|
}
|
||||||
@ -119,13 +75,13 @@ BayesNode::hasNeighbors (void) const
|
|||||||
int
|
int
|
||||||
BayesNode::getCptSize (void)
|
BayesNode::getCptSize (void)
|
||||||
{
|
{
|
||||||
return dist_->params.size();
|
return params_.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
BayesNode::getIndexOfParent (const BayesNode* parent) const
|
BayesNode::indexOfParent (const BayesNode* parent) const
|
||||||
{
|
{
|
||||||
for (unsigned int i = 0; i < parents_.size(); i++) {
|
for (unsigned int i = 0; i < parents_.size(); i++) {
|
||||||
if (parents_[i] == parent) {
|
if (parents_[i] == parent) {
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "VarNode.h"
|
#include "VarNode.h"
|
||||||
#include "Distribution.h"
|
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -13,49 +12,70 @@ using namespace std;
|
|||||||
class BayesNode : public VarNode
|
class BayesNode : public VarNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BayesNode (const VarNode& v) : VarNode (v) {}
|
|
||||||
BayesNode (VarId, unsigned, int, Distribution*);
|
|
||||||
BayesNode (VarId, unsigned, int, const BnNodeSet&, Distribution*);
|
|
||||||
|
|
||||||
void setParents (const BnNodeSet&);
|
BayesNode (const VarNode& v) : VarNode (v) { }
|
||||||
void addChild (BayesNode*);
|
|
||||||
void setDistribution (Distribution*);
|
BayesNode (const BayesNode* n) :
|
||||||
Distribution* getDistribution (void);
|
VarNode (n->varId(), n->nrStates(), n->getEvidence()),
|
||||||
const Params& getParameters (void);
|
params_(n->params()), distId_(n->distId()) { }
|
||||||
Params getRow (int) const;
|
|
||||||
bool isRoot (void);
|
BayesNode (VarId vid, unsigned nrStates, int ev,
|
||||||
bool isLeaf (void);
|
const Params& ps, unsigned id)
|
||||||
bool hasNeighbors (void) const;
|
: VarNode (vid, nrStates, ev) , params_(ps), distId_(id) { }
|
||||||
int getCptSize (void);
|
|
||||||
int getIndexOfParent (const BayesNode*) const;
|
|
||||||
string cptEntryToString (int, const vector<unsigned>&) const;
|
|
||||||
|
|
||||||
const BnNodeSet& getParents (void) const { return parents_; }
|
const BnNodeSet& getParents (void) const { return parents_; }
|
||||||
|
|
||||||
const BnNodeSet& getChilds (void) const { return childs_; }
|
const BnNodeSet& getChilds (void) const { return childs_; }
|
||||||
|
|
||||||
|
const Params& params (void) const { return params_; }
|
||||||
|
|
||||||
|
unsigned distId (void) const { return distId_; }
|
||||||
|
|
||||||
unsigned getRowSize (void) const
|
unsigned getRowSize (void) const
|
||||||
{
|
{
|
||||||
return dist_->params.size() / nrStates();
|
return params_.size() / nrStates();
|
||||||
}
|
}
|
||||||
|
|
||||||
double getProbability (int row, unsigned col)
|
double getProbability (int row, unsigned col)
|
||||||
{
|
{
|
||||||
int idx = (row * getRowSize()) + col;
|
int idx = (row * getRowSize()) + col;
|
||||||
return dist_->params[idx];
|
return params_[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void setParams (const Params& params);
|
||||||
|
|
||||||
|
void setParents (const BnNodeSet&);
|
||||||
|
|
||||||
|
void addChild (BayesNode*);
|
||||||
|
|
||||||
|
const Params& getParameters (void);
|
||||||
|
|
||||||
|
Params getRow (int) const;
|
||||||
|
|
||||||
|
bool isRoot (void);
|
||||||
|
|
||||||
|
bool isLeaf (void);
|
||||||
|
|
||||||
|
bool hasNeighbors (void) const;
|
||||||
|
|
||||||
|
int getCptSize (void);
|
||||||
|
|
||||||
|
int indexOfParent (const BayesNode*) const;
|
||||||
|
|
||||||
|
string cptEntryToString (int, const vector<unsigned>&) const;
|
||||||
|
|
||||||
|
friend ostream& operator << (ostream&, const BayesNode&);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
||||||
|
|
||||||
States getDomainHeaders (void) const;
|
States getDomainHeaders (void) const;
|
||||||
friend ostream& operator << (ostream&, const BayesNode&);
|
|
||||||
|
|
||||||
BnNodeSet parents_;
|
BnNodeSet parents_;
|
||||||
BnNodeSet childs_;
|
BnNodeSet childs_;
|
||||||
Distribution* dist_;
|
Params params_;
|
||||||
|
unsigned distId_;
|
||||||
};
|
};
|
||||||
|
|
||||||
ostream& operator << (ostream&, const BayesNode&);
|
|
||||||
|
|
||||||
#endif // HORUS_BAYESNODE_H
|
#endif // HORUS_BAYESNODE_H
|
||||||
|
|
||||||
|
@ -34,12 +34,12 @@ void
|
|||||||
BnBpSolver::runSolver (void)
|
BnBpSolver::runSolver (void)
|
||||||
{
|
{
|
||||||
clock_t start;
|
clock_t start;
|
||||||
if (COLLECT_STATISTICS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
start = clock();
|
start = clock();
|
||||||
}
|
}
|
||||||
initializeSolver();
|
initializeSolver();
|
||||||
runLoopySolver();
|
runLoopySolver();
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
cout << "Belief propagation converged in " ;
|
cout << "Belief propagation converged in " ;
|
||||||
@ -51,18 +51,13 @@ BnBpSolver::runSolver (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsigned size = bayesNet_->nrNodes();
|
unsigned size = bayesNet_->nrNodes();
|
||||||
if (COLLECT_STATISTICS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
unsigned nIters = 0;
|
unsigned nIters = 0;
|
||||||
bool loopy = bayesNet_->isPolyTree() == false;
|
bool loopy = bayesNet_->isPolyTree() == false;
|
||||||
if (loopy) nIters = nIters_;
|
if (loopy) nIters = nIters_;
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||||
}
|
}
|
||||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
|
||||||
bayesNet_->exportToGraphViz (ss.str().c_str());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -80,7 +75,7 @@ BnBpSolver::getPosterioriOf (VarId vid)
|
|||||||
Params
|
Params
|
||||||
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||||
{
|
{
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "calculating joint distribution on: " ;
|
cout << "calculating joint distribution on: " ;
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||||
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
|
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
|
||||||
@ -112,7 +107,7 @@ BnBpSolver::initializeSolver (void)
|
|||||||
|
|
||||||
BnNodeSet roots = bayesNet_->getRootNodes();
|
BnNodeSet roots = bayesNet_->getRootNodes();
|
||||||
for (unsigned i = 0; i < roots.size(); i++) {
|
for (unsigned i = 0; i < roots.size(); i++) {
|
||||||
const Params& params = roots[i]->getParameters();
|
const Params& params = roots[i]->params();
|
||||||
Params& piVals = ninf(roots[i])->getPiValues();
|
Params& piVals = ninf(roots[i])->getPiValues();
|
||||||
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
|
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
|
||||||
piVals[ri] = params[ri];
|
piVals[ri] = params[ri];
|
||||||
@ -143,11 +138,11 @@ BnBpSolver::initializeSolver (void)
|
|||||||
Params& piVals = ninf(nodes[i])->getPiValues();
|
Params& piVals = ninf(nodes[i])->getPiValues();
|
||||||
Params& ldVals = ninf(nodes[i])->getLambdaValues();
|
Params& ldVals = ninf(nodes[i])->getLambdaValues();
|
||||||
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
|
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
|
||||||
piVals[xi] = Util::noEvidence();
|
piVals[xi] = LogAware::noEvidence();
|
||||||
ldVals[xi] = Util::noEvidence();
|
ldVals[xi] = LogAware::noEvidence();
|
||||||
}
|
}
|
||||||
piVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
piVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
|
||||||
ldVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
ldVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -161,13 +156,8 @@ BnBpSolver::runLoopySolver()
|
|||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||||
|
|
||||||
nIters_++;
|
nIters_++;
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "****************************************" ;
|
Util::printHeader ("Iteration " + nIters_);
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << endl;
|
|
||||||
cout << " Iteration " << nIters_ << endl;
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,7 +189,7 @@ BnBpSolver::runLoopySolver()
|
|||||||
break;
|
break;
|
||||||
|
|
||||||
}
|
}
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -228,7 +218,7 @@ BnBpSolver::converged (void) const
|
|||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
double residual = links_[i]->getResidual();
|
double residual = links_[i]->getResidual();
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << links_[i]->toString() + " residual change = " ;
|
cout << links_[i]->toString() + " residual change = " ;
|
||||||
cout << residual << endl;
|
cout << residual << endl;
|
||||||
}
|
}
|
||||||
@ -256,7 +246,7 @@ BnBpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "current residuals:" << endl;
|
cout << "current residuals:" << endl;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); it ++) {
|
it != sortedOrder_.end(); it ++) {
|
||||||
@ -300,9 +290,8 @@ BnBpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "----------------------------------------" ;
|
Util::printDashedLine();
|
||||||
cout << "----------------------------------------" << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -313,7 +302,7 @@ void
|
|||||||
BnBpSolver::updatePiValues (BayesNode* x)
|
BnBpSolver::updatePiValues (BayesNode* x)
|
||||||
{
|
{
|
||||||
// π(Xi)
|
// π(Xi)
|
||||||
if (DL >= 3) {
|
if (Constants::DEBUG >= 3) {
|
||||||
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
|
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
|
||||||
}
|
}
|
||||||
Params& piValues = ninf(x)->getPiValues();
|
Params& piValues = ninf(x)->getPiValues();
|
||||||
@ -329,11 +318,11 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
|||||||
|
|
||||||
Params messageProducts (indexer.size());
|
Params messageProducts (indexer.size());
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
for (unsigned k = 0; k < indexer.size(); k++) {
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
calcs1 = new stringstream;
|
calcs1 = new stringstream;
|
||||||
calcs2 = new stringstream;
|
calcs2 = new stringstream;
|
||||||
}
|
}
|
||||||
double messageProduct = Util::multIdenty();
|
double messageProduct = LogAware::multIdenty();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
||||||
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
|
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
|
||||||
@ -341,7 +330,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
|||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
||||||
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
|
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
if (i != 0) *calcs1 << " + " ;
|
if (i != 0) *calcs1 << " + " ;
|
||||||
if (i != 0) *calcs2 << " + " ;
|
if (i != 0) *calcs2 << " + " ;
|
||||||
*calcs1 << parentLinks[i]->toString (indexer[i]);
|
*calcs1 << parentLinks[i]->toString (indexer[i]);
|
||||||
@ -350,7 +339,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
messageProducts[k] = messageProduct;
|
messageProducts[k] = messageProduct;
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " mp" << k;
|
cout << " mp" << k;
|
||||||
cout << " = " << (*calcs1).str();
|
cout << " = " << (*calcs1).str();
|
||||||
if (parentLinks.size() == 1) {
|
if (parentLinks.size() == 1) {
|
||||||
@ -366,27 +355,27 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||||
double sum = Util::addIdenty();
|
double sum = LogAware::addIdenty();
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
calcs1 = new stringstream;
|
calcs1 = new stringstream;
|
||||||
calcs2 = new stringstream;
|
calcs2 = new stringstream;
|
||||||
}
|
}
|
||||||
indexer.reset();
|
indexer.reset();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
for (unsigned k = 0; k < indexer.size(); k++) {
|
||||||
Util::logSum (sum,
|
sum = Util::logSum (sum,
|
||||||
x->getProbability(xi, indexer.linearIndex()) + messageProducts[k]);
|
x->getProbability(xi, indexer) + messageProducts[k]);
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
for (unsigned k = 0; k < indexer.size(); k++) {
|
||||||
sum += x->getProbability (xi, indexer.linearIndex()) * messageProducts[k];
|
sum += x->getProbability (xi, indexer) * messageProducts[k];
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
if (k != 0) *calcs1 << " + " ;
|
if (k != 0) *calcs1 << " + " ;
|
||||||
if (k != 0) *calcs2 << " + " ;
|
if (k != 0) *calcs2 << " + " ;
|
||||||
*calcs1 << x->cptEntryToString (xi, indexer.indices());
|
*calcs1 << x->cptEntryToString (xi, indexer.indices());
|
||||||
*calcs1 << ".mp" << k;
|
*calcs1 << ".mp" << k;
|
||||||
*calcs2 << Util::fl (x->getProbability (xi, indexer.linearIndex()));
|
*calcs2 << LogAware::fl (x->getProbability (xi, indexer));
|
||||||
*calcs2 << "*" << messageProducts[k];
|
*calcs2 << "*" << messageProducts[k];
|
||||||
}
|
}
|
||||||
++ indexer;
|
++ indexer;
|
||||||
@ -394,7 +383,7 @@ BnBpSolver::updatePiValues (BayesNode* x)
|
|||||||
}
|
}
|
||||||
|
|
||||||
piValues[xi] = sum;
|
piValues[xi] = sum;
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
|
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
cout << "[" << x->states()[xi] << "]" ;
|
||||||
cout << " = " << (*calcs1).str();
|
cout << " = " << (*calcs1).str();
|
||||||
@ -412,7 +401,7 @@ void
|
|||||||
BnBpSolver::updateLambdaValues (BayesNode* x)
|
BnBpSolver::updateLambdaValues (BayesNode* x)
|
||||||
{
|
{
|
||||||
// λ(Xi)
|
// λ(Xi)
|
||||||
if (DL >= 3) {
|
if (Constants::DEBUG >= 3) {
|
||||||
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
|
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
|
||||||
}
|
}
|
||||||
Params& lambdaValues = ninf(x)->getLambdaValues();
|
Params& lambdaValues = ninf(x)->getLambdaValues();
|
||||||
@ -421,11 +410,11 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
|
|||||||
stringstream* calcs2 = 0;
|
stringstream* calcs2 = 0;
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
calcs1 = new stringstream;
|
calcs1 = new stringstream;
|
||||||
calcs2 = new stringstream;
|
calcs2 = new stringstream;
|
||||||
}
|
}
|
||||||
double product = Util::multIdenty();
|
double product = LogAware::multIdenty();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
for (unsigned i = 0; i < childLinks.size(); i++) {
|
||||||
product += childLinks[i]->getMessage()[xi];
|
product += childLinks[i]->getMessage()[xi];
|
||||||
@ -433,7 +422,7 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
|
|||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
for (unsigned i = 0; i < childLinks.size(); i++) {
|
||||||
product *= childLinks[i]->getMessage()[xi];
|
product *= childLinks[i]->getMessage()[xi];
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
if (i != 0) *calcs1 << "." ;
|
if (i != 0) *calcs1 << "." ;
|
||||||
if (i != 0) *calcs2 << "*" ;
|
if (i != 0) *calcs2 << "*" ;
|
||||||
*calcs1 << childLinks[i]->toString (xi);
|
*calcs1 << childLinks[i]->toString (xi);
|
||||||
@ -442,7 +431,7 @@ BnBpSolver::updateLambdaValues (BayesNode* x)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
lambdaValues[xi] = product;
|
lambdaValues[xi] = product;
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
|
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
cout << "[" << x->states()[xi] << "]" ;
|
||||||
cout << " = " << (*calcs1).str();
|
cout << " = " << (*calcs1).str();
|
||||||
@ -474,7 +463,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
|
|||||||
const Params& zPiValues = ninf(z)->getPiValues();
|
const Params& zPiValues = ninf(z)->getPiValues();
|
||||||
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
|
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
|
||||||
double product = zPiValues[zi];
|
double product = zPiValues[zi];
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
calcs1 = new stringstream;
|
calcs1 = new stringstream;
|
||||||
calcs2 = new stringstream;
|
calcs2 = new stringstream;
|
||||||
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
|
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
|
||||||
@ -491,7 +480,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
|
|||||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
||||||
if (zChildLinks[i]->getSource() != x) {
|
if (zChildLinks[i]->getSource() != x) {
|
||||||
product *= zChildLinks[i]->getMessage()[zi];
|
product *= zChildLinks[i]->getMessage()[zi];
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
*calcs1 << "." << zChildLinks[i]->toString (zi);
|
*calcs1 << "." << zChildLinks[i]->toString (zi);
|
||||||
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
||||||
}
|
}
|
||||||
@ -499,7 +488,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
zxPiNextMessage[zi] = product;
|
zxPiNextMessage[zi] = product;
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " " << link->toString();
|
cout << " " << link->toString();
|
||||||
cout << "[" << z->states()[zi] << "]" ;
|
cout << "[" << z->states()[zi] << "]" ;
|
||||||
cout << " = " << (*calcs1).str();
|
cout << " = " << (*calcs1).str();
|
||||||
@ -513,7 +502,7 @@ BnBpSolver::calculatePiMessage (BpLink* link)
|
|||||||
delete calcs2;
|
delete calcs2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Util::normalize (zxPiNextMessage);
|
LogAware::normalize (zxPiNextMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -530,7 +519,7 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
|||||||
Params& yxLambdaNextMessage = link->getNextMessage();
|
Params& yxLambdaNextMessage = link->getNextMessage();
|
||||||
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
|
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
|
||||||
const Params& yLambdaValues = ninf(y)->getLambdaValues();
|
const Params& yLambdaValues = ninf(y)->getLambdaValues();
|
||||||
int parentIndex = y->getIndexOfParent (x);
|
int parentIndex = y->indexOfParent (x);
|
||||||
stringstream* calcs1 = 0;
|
stringstream* calcs1 = 0;
|
||||||
stringstream* calcs2 = 0;
|
stringstream* calcs2 = 0;
|
||||||
|
|
||||||
@ -548,11 +537,11 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
|||||||
while (indexer[parentIndex] != 0) {
|
while (indexer[parentIndex] != 0) {
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
calcs1 = new stringstream;
|
calcs1 = new stringstream;
|
||||||
calcs2 = new stringstream;
|
calcs2 = new stringstream;
|
||||||
}
|
}
|
||||||
double messageProduct = Util::multIdenty();
|
double messageProduct = LogAware::multIdenty();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
||||||
if (yParentLinks[i]->getSource() != x) {
|
if (yParentLinks[i]->getSource() != x) {
|
||||||
@ -562,9 +551,9 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
|||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
||||||
if (yParentLinks[i]->getSource() != x) {
|
if (yParentLinks[i]->getSource() != x) {
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
if (messageProduct != Util::multIdenty()) *calcs1 << "*" ;
|
if (messageProduct != LogAware::multIdenty()) *calcs1 << "*" ;
|
||||||
if (messageProduct != Util::multIdenty()) *calcs2 << "*" ;
|
if (messageProduct != LogAware::multIdenty()) *calcs2 << "*" ;
|
||||||
*calcs1 << yParentLinks[i]->toString (indexer[i]);
|
*calcs1 << yParentLinks[i]->toString (indexer[i]);
|
||||||
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
|
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
|
||||||
}
|
}
|
||||||
@ -574,7 +563,7 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
|||||||
}
|
}
|
||||||
messageProducts[k] = messageProduct;
|
messageProducts[k] = messageProduct;
|
||||||
++ indexer;
|
++ indexer;
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " mp" << k;
|
cout << " mp" << k;
|
||||||
cout << " = " << (*calcs1).str();
|
cout << " = " << (*calcs1).str();
|
||||||
if (yParentLinks.size() == 1) {
|
if (yParentLinks.size() == 1) {
|
||||||
@ -591,55 +580,54 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
calcs1 = new stringstream;
|
calcs1 = new stringstream;
|
||||||
calcs2 = new stringstream;
|
calcs2 = new stringstream;
|
||||||
}
|
}
|
||||||
double outerSum = Util::addIdenty();
|
double outerSum = LogAware::addIdenty();
|
||||||
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
|
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
||||||
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
||||||
}
|
}
|
||||||
double innerSum = Util::addIdenty();
|
double innerSum = LogAware::addIdenty();
|
||||||
indexer.reset();
|
indexer.reset();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned k = 0; k < N; k++) {
|
for (unsigned k = 0; k < N; k++) {
|
||||||
while (indexer[parentIndex] != xi) {
|
while (indexer[parentIndex] != xi) {
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
Util::logSum (innerSum, y->getProbability (
|
innerSum = Util::logSum (innerSum,
|
||||||
yi, indexer.linearIndex()) + messageProducts[k]);
|
y->getProbability (yi, indexer) + messageProducts[k]);
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
|
outerSum = Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
|
||||||
} else {
|
} else {
|
||||||
for (unsigned k = 0; k < N; k++) {
|
for (unsigned k = 0; k < N; k++) {
|
||||||
while (indexer[parentIndex] != xi) {
|
while (indexer[parentIndex] != xi) {
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
if (k != 0) *calcs1 << " + " ;
|
if (k != 0) *calcs1 << " + " ;
|
||||||
if (k != 0) *calcs2 << " + " ;
|
if (k != 0) *calcs2 << " + " ;
|
||||||
*calcs1 << y->cptEntryToString (yi, indexer.indices());
|
*calcs1 << y->cptEntryToString (yi, indexer.indices());
|
||||||
*calcs1 << ".mp" << k;
|
*calcs1 << ".mp" << k;
|
||||||
*calcs2 << y->getProbability (yi, indexer.linearIndex());
|
*calcs2 << y->getProbability (yi, indexer);
|
||||||
*calcs2 << "*" << messageProducts[k];
|
*calcs2 << "*" << messageProducts[k];
|
||||||
}
|
}
|
||||||
innerSum += y->getProbability (
|
innerSum += y->getProbability (yi, indexer) * messageProducts[k];
|
||||||
yi, indexer.linearIndex()) * messageProducts[k];
|
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
outerSum += innerSum * yLambdaValues[yi];
|
outerSum += innerSum * yLambdaValues[yi];
|
||||||
}
|
}
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
|
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
|
||||||
*calcs1 << "[" << y->states()[yi] << "]";
|
*calcs1 << "[" << y->states()[yi] << "]";
|
||||||
*calcs2 << "}*" << yLambdaValues[yi];
|
*calcs2 << "}*" << yLambdaValues[yi];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
yxLambdaNextMessage[xi] = outerSum;
|
yxLambdaNextMessage[xi] = outerSum;
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " " << link->toString();
|
cout << " " << link->toString();
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
cout << "[" << x->states()[xi] << "]" ;
|
||||||
cout << " = " << (*calcs1).str();
|
cout << " = " << (*calcs1).str();
|
||||||
@ -649,7 +637,7 @@ BnBpSolver::calculateLambdaMessage (BpLink* link)
|
|||||||
delete calcs2;
|
delete calcs2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Util::normalize (yxLambdaNextMessage);
|
LogAware::normalize (yxLambdaNextMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -674,7 +662,7 @@ BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
assert (jointVars[i]->hasEvidence() == false);
|
||||||
VarIds reqVars = {jointVarIds[i]};
|
VarIds reqVars = {jointVarIds[i]};
|
||||||
reqVars.insert (reqVars.end(), observedVids.begin(), observedVids.end());
|
Util::addToVector (reqVars, observedVids);
|
||||||
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
|
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
|
||||||
Params newBeliefs;
|
Params newBeliefs;
|
||||||
VarNodes observedVars;
|
VarNodes observedVars;
|
||||||
@ -720,8 +708,7 @@ BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
|||||||
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
|
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
|
||||||
cout << setw (16) << "belief" ;
|
cout << setw (16) << "belief" ;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << "--------------------------------" ;
|
Util::printDashedLine();
|
||||||
cout << "--------------------------------" ;
|
|
||||||
cout << endl;
|
cout << endl;
|
||||||
const States& states = var->states();
|
const States& states = var->states();
|
||||||
const Params& piVals = ninf(var)->getPiValues();
|
const Params& piVals = ninf(var)->getPiValues();
|
||||||
@ -731,7 +718,7 @@ BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
|||||||
cout << setw (10) << states[xi];
|
cout << setw (10) << states[xi];
|
||||||
cout << setw (19) << piVals[xi];
|
cout << setw (19) << piVals[xi];
|
||||||
cout << setw (19) << ldVals[xi];
|
cout << setw (19) << ldVals[xi];
|
||||||
cout.precision (PRECISION);
|
cout.precision (Constants::PRECISION);
|
||||||
cout << setw (16) << beliefs[xi];
|
cout << setw (16) << beliefs[xi];
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
@ -754,8 +741,8 @@ BnBpSolver::printAllMessageStatus (void) const
|
|||||||
BpNodeInfo::BpNodeInfo (BayesNode* node)
|
BpNodeInfo::BpNodeInfo (BayesNode* node)
|
||||||
{
|
{
|
||||||
node_ = node;
|
node_ = node;
|
||||||
piVals_.resize (node->nrStates(), Util::one());
|
piVals_.resize (node->nrStates(), LogAware::one());
|
||||||
ldVals_.resize (node->nrStates(), Util::one());
|
ldVals_.resize (node->nrStates(), LogAware::one());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,11 +27,11 @@ class BpLink
|
|||||||
destin_ = d;
|
destin_ = d;
|
||||||
orientation_ = o;
|
orientation_ = o;
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
if (orientation_ == LinkOrientation::DOWN) {
|
||||||
v1_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
|
v1_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
|
||||||
v2_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
|
v2_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
|
||||||
} else {
|
} else {
|
||||||
v1_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
|
v1_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
|
||||||
v2_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
|
v2_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
|
||||||
}
|
}
|
||||||
currMsg_ = &v1_;
|
currMsg_ = &v1_;
|
||||||
nextMsg_ = &v2_;
|
nextMsg_ = &v2_;
|
||||||
@ -39,6 +39,22 @@ class BpLink
|
|||||||
msgSended_ = false;
|
msgSended_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BayesNode* getSource (void) const { return source_; }
|
||||||
|
|
||||||
|
BayesNode* getDestination (void) const { return destin_; }
|
||||||
|
|
||||||
|
LinkOrientation getOrientation (void) const { return orientation_; }
|
||||||
|
|
||||||
|
const Params& getMessage (void) const { return *currMsg_; }
|
||||||
|
|
||||||
|
Params& getNextMessage (void) { return *nextMsg_;}
|
||||||
|
|
||||||
|
bool messageWasSended (void) const { return msgSended_; }
|
||||||
|
|
||||||
|
double getResidual (void) const { return residual_; }
|
||||||
|
|
||||||
|
void clearResidual (void) { residual_ = 0;}
|
||||||
|
|
||||||
void updateMessage (void)
|
void updateMessage (void)
|
||||||
{
|
{
|
||||||
swap (currMsg_, nextMsg_);
|
swap (currMsg_, nextMsg_);
|
||||||
@ -47,7 +63,7 @@ class BpLink
|
|||||||
|
|
||||||
void updateResidual (void)
|
void updateResidual (void)
|
||||||
{
|
{
|
||||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
||||||
}
|
}
|
||||||
|
|
||||||
string toString (void) const
|
string toString (void) const
|
||||||
@ -75,15 +91,6 @@ class BpLink
|
|||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
BayesNode* getSource (void) const { return source_; }
|
|
||||||
BayesNode* getDestination (void) const { return destin_; }
|
|
||||||
LinkOrientation getOrientation (void) const { return orientation_; }
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
|
||||||
Params& getNextMessage (void) { return *nextMsg_; }
|
|
||||||
bool messageWasSended (void) const { return msgSended_; }
|
|
||||||
double getResidual (void) const { return residual_; }
|
|
||||||
void clearResidual (void) { residual_ = 0;}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BayesNode* source_;
|
BayesNode* source_;
|
||||||
BayesNode* destin_;
|
BayesNode* destin_;
|
||||||
@ -96,7 +103,6 @@ class BpLink
|
|||||||
double residual_;
|
double residual_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
typedef vector<BpLink*> BpLinkSet;
|
typedef vector<BpLink*> BpLinkSet;
|
||||||
|
|
||||||
|
|
||||||
@ -105,28 +111,37 @@ class BpNodeInfo
|
|||||||
public:
|
public:
|
||||||
BpNodeInfo (BayesNode*);
|
BpNodeInfo (BayesNode*);
|
||||||
|
|
||||||
Params getBeliefs (void) const;
|
|
||||||
bool receivedBottomInfluence (void) const;
|
|
||||||
|
|
||||||
Params& getPiValues (void) { return piVals_; }
|
Params& getPiValues (void) { return piVals_; }
|
||||||
|
|
||||||
Params& getLambdaValues (void) { return ldVals_; }
|
Params& getLambdaValues (void) { return ldVals_; }
|
||||||
|
|
||||||
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
|
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
|
||||||
|
|
||||||
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
|
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
|
||||||
|
|
||||||
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
|
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
|
||||||
|
|
||||||
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
|
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
|
||||||
|
|
||||||
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
|
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
|
||||||
|
|
||||||
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
|
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
|
||||||
|
|
||||||
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
|
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
|
||||||
|
|
||||||
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
|
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
|
||||||
|
|
||||||
|
Params getBeliefs (void) const;
|
||||||
|
|
||||||
|
bool receivedBottomInfluence (void) const;
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
|
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
|
||||||
|
|
||||||
const BayesNode* node_;
|
const BayesNode* node_;
|
||||||
Params piVals_; // pi values
|
Params piVals_;
|
||||||
Params ldVals_; // lambda values
|
Params ldVals_;
|
||||||
BpLinkSet inParentLinks_;
|
BpLinkSet inParentLinks_;
|
||||||
BpLinkSet inChildLinks_;
|
BpLinkSet inChildLinks_;
|
||||||
BpLinkSet outParentLinks_;
|
BpLinkSet outParentLinks_;
|
||||||
@ -139,32 +154,43 @@ class BnBpSolver : public Solver
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BnBpSolver (const BayesNet&);
|
BnBpSolver (const BayesNet&);
|
||||||
|
|
||||||
~BnBpSolver (void);
|
~BnBpSolver (void);
|
||||||
|
|
||||||
void runSolver (void);
|
void runSolver (void);
|
||||||
Params getPosterioriOf (VarId);
|
Params getPosterioriOf (VarId);
|
||||||
Params getJointDistributionOf (const VarIds&);
|
Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
|
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
|
||||||
|
|
||||||
void initializeSolver (void);
|
void initializeSolver (void);
|
||||||
|
|
||||||
void runLoopySolver (void);
|
void runLoopySolver (void);
|
||||||
|
|
||||||
void maxResidualSchedule (void);
|
void maxResidualSchedule (void);
|
||||||
|
|
||||||
bool converged (void) const;
|
bool converged (void) const;
|
||||||
|
|
||||||
void updatePiValues (BayesNode*);
|
void updatePiValues (BayesNode*);
|
||||||
|
|
||||||
void updateLambdaValues (BayesNode*);
|
void updateLambdaValues (BayesNode*);
|
||||||
|
|
||||||
void calculateLambdaMessage (BpLink*);
|
void calculateLambdaMessage (BpLink*);
|
||||||
|
|
||||||
void calculatePiMessage (BpLink*);
|
void calculatePiMessage (BpLink*);
|
||||||
|
|
||||||
Params getJointByJunctionNode (const VarIds&);
|
Params getJointByJunctionNode (const VarIds&);
|
||||||
|
|
||||||
Params getJointByConditioning (const VarIds&) const;
|
Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
void printPiLambdaValues (const BayesNode*) const;
|
void printPiLambdaValues (const BayesNode*) const;
|
||||||
|
|
||||||
void printAllMessageStatus (void) const;
|
void printAllMessageStatus (void) const;
|
||||||
|
|
||||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
||||||
{
|
{
|
||||||
if (DL >= 3) {
|
if (Constants::DEBUG >= 3) {
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
cout << "calculating & updating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||||
@ -180,7 +206,7 @@ class BnBpSolver : public Solver
|
|||||||
|
|
||||||
void calculateMessage (BpLink* link, bool calcResidual = true)
|
void calculateMessage (BpLink* link, bool calcResidual = true)
|
||||||
{
|
{
|
||||||
if (DL >= 3) {
|
if (Constants::DEBUG >= 3) {
|
||||||
cout << "calculating " << link->toString() << endl;
|
cout << "calculating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||||
@ -195,7 +221,7 @@ class BnBpSolver : public Solver
|
|||||||
|
|
||||||
void updateMessage (BpLink* link)
|
void updateMessage (BpLink* link)
|
||||||
{
|
{
|
||||||
if (DL >= 3) {
|
if (Constants::DEBUG >= 3) {
|
||||||
cout << "updating " << link->toString() << endl;
|
cout << "updating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
link->updateMessage();
|
link->updateMessage();
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
|
|
||||||
#include "CFactorGraph.h"
|
#include "CFactorGraph.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "Distribution.h"
|
|
||||||
|
|
||||||
|
|
||||||
bool CFactorGraph::checkForIdenticalFactors = true;
|
bool CFactorGraph::checkForIdenticalFactors = true;
|
||||||
@ -73,27 +72,34 @@ CFactorGraph::setInitialColors (void)
|
|||||||
|
|
||||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
||||||
if (checkForIdenticalFactors) {
|
if (checkForIdenticalFactors) {
|
||||||
for (unsigned i = 0, s = facNodes.size(); i < s; i++) {
|
unsigned groupCount = 1;
|
||||||
Distribution* dist1 = facNodes[i]->getDistribution();
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
for (unsigned j = 0; j < i; j++) {
|
Factor* f1 = facNodes[i]->factor();
|
||||||
Distribution* dist2 = facNodes[j]->getDistribution();
|
if (f1->distId() != Util::maxUnsigned()) {
|
||||||
if (dist1 != dist2 && dist1->params == dist2->params) {
|
continue;
|
||||||
if (facNodes[i]->factor()->getRanges() ==
|
}
|
||||||
facNodes[j]->factor()->getRanges()) {
|
f1->setDistId (groupCount);
|
||||||
facNodes[i]->factor()->setDistribution (dist2);
|
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
||||||
|
Factor* f2 = facNodes[j]->factor();
|
||||||
|
if (f2->distId() != Util::maxUnsigned()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (f1->size() == f2->size() &&
|
||||||
|
f1->ranges() == f2->ranges() &&
|
||||||
|
f1->params() == f2->params()) {
|
||||||
|
f2->setDistId (groupCount);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
groupCount ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// create the initial factor colors
|
// create the initial factor colors
|
||||||
DistColorMap distColors;
|
DistColorMap distColors;
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
const Distribution* dist = facNodes[i]->getDistribution();
|
unsigned distId = facNodes[i]->factor()->distId();
|
||||||
DistColorMap::iterator it = distColors.find (dist);
|
DistColorMap::iterator it = distColors.find (distId);
|
||||||
if (it == distColors.end()) {
|
if (it == distColors.end()) {
|
||||||
it = distColors.insert (make_pair (dist, getFreeColor())).first;
|
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
||||||
}
|
}
|
||||||
setColor (facNodes[i], it->second);
|
setColor (facNodes[i], it->second);
|
||||||
}
|
}
|
||||||
@ -164,7 +170,8 @@ CFactorGraph::createGroups (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CFactorGraph::createClusters (const VarSignMap& varGroups,
|
CFactorGraph::createClusters (
|
||||||
|
const VarSignMap& varGroups,
|
||||||
const FacSignMap& factorGroups)
|
const FacSignMap& factorGroups)
|
||||||
{
|
{
|
||||||
varClusters_.reserve (varGroups.size());
|
varClusters_.reserve (varGroups.size());
|
||||||
@ -249,7 +256,7 @@ CFactorGraph::getCompressedFactorGraph (void)
|
|||||||
myGroundVars.push_back (v);
|
myGroundVars.push_back (v);
|
||||||
}
|
}
|
||||||
Factor* newFactor = new Factor (myGroundVars,
|
Factor* newFactor = new Factor (myGroundVars,
|
||||||
facClusters_[i]->getGroundFactors()[0]->getDistribution());
|
facClusters_[i]->getGroundFactors()[0]->params());
|
||||||
FgFacNode* fn = new FgFacNode (newFactor);
|
FgFacNode* fn = new FgFacNode (newFactor);
|
||||||
facClusters_[i]->setRepresentativeFactor (fn);
|
facClusters_[i]->setRepresentativeFactor (fn);
|
||||||
fg->addFactor (fn);
|
fg->addFactor (fn);
|
||||||
@ -293,7 +300,8 @@ CFactorGraph::getGroundEdgeCount (
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CFactorGraph::printGroups (const VarSignMap& varGroups,
|
CFactorGraph::printGroups (
|
||||||
|
const VarSignMap& varGroups,
|
||||||
const FacSignMap& factorGroups) const
|
const FacSignMap& factorGroups) const
|
||||||
{
|
{
|
||||||
unsigned count = 1;
|
unsigned count = 1;
|
||||||
|
@ -16,11 +16,15 @@ class SignatureHash;
|
|||||||
|
|
||||||
|
|
||||||
typedef long Color;
|
typedef long Color;
|
||||||
typedef unordered_map<unsigned, vector<Color> > VarColorMap;
|
|
||||||
typedef unordered_map<const Distribution*, Color> DistColorMap;
|
typedef unordered_map<unsigned, vector<Color>> VarColorMap;
|
||||||
|
|
||||||
|
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||||
|
|
||||||
typedef vector<VarCluster*> VarClusterSet;
|
typedef vector<VarCluster*> VarClusterSet;
|
||||||
typedef vector<FacCluster*> FacClusterSet;
|
typedef vector<FacCluster*> FacClusterSet;
|
||||||
|
|
||||||
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
||||||
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
||||||
|
|
||||||
@ -28,10 +32,8 @@ typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
|||||||
|
|
||||||
struct Signature
|
struct Signature
|
||||||
{
|
{
|
||||||
Signature (unsigned size)
|
Signature (unsigned size) : colors(size) { }
|
||||||
{
|
|
||||||
colors.resize (size);
|
|
||||||
}
|
|
||||||
bool operator< (const Signature& sig) const
|
bool operator< (const Signature& sig) const
|
||||||
{
|
{
|
||||||
if (colors.size() < sig.colors.size()) {
|
if (colors.size() < sig.colors.size()) {
|
||||||
@ -49,6 +51,7 @@ struct Signature
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator== (const Signature& sig) const
|
bool operator== (const Signature& sig) const
|
||||||
{
|
{
|
||||||
if (colors.size() != sig.colors.size()) {
|
if (colors.size() != sig.colors.size()) {
|
||||||
@ -61,12 +64,14 @@ struct Signature
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<Color> colors;
|
vector<Color> colors;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
struct SignatureHash {
|
struct SignatureHash
|
||||||
|
{
|
||||||
size_t operator() (const Signature &sig) const
|
size_t operator() (const Signature &sig) const
|
||||||
{
|
{
|
||||||
size_t val = hash<size_t>()(sig.colors.size());
|
size_t val = hash<size_t>()(sig.colors.size());
|
||||||
@ -141,10 +146,12 @@ class FacCluster
|
|||||||
{
|
{
|
||||||
return representFactor_;
|
return representFactor_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setRepresentativeFactor (FgFacNode* fn)
|
void setRepresentativeFactor (FgFacNode* fn)
|
||||||
{
|
{
|
||||||
representFactor_ = fn;
|
representFactor_ = fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& getGroundFactors (void) const
|
const FgFacSet& getGroundFactors (void) const
|
||||||
{
|
{
|
||||||
return groundFactors_;
|
return groundFactors_;
|
||||||
@ -162,10 +169,12 @@ class CFactorGraph
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CFactorGraph (const FactorGraph&);
|
CFactorGraph (const FactorGraph&);
|
||||||
|
|
||||||
~CFactorGraph (void);
|
~CFactorGraph (void);
|
||||||
|
|
||||||
FactorGraph* getCompressedFactorGraph (void);
|
const VarClusterSet& getVarClusters (void) { return varClusters_; }
|
||||||
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
|
||||||
|
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
||||||
|
|
||||||
FgVarNode* getEquivalentVariable (VarId vid)
|
FgVarNode* getEquivalentVariable (VarId vid)
|
||||||
{
|
{
|
||||||
@ -173,20 +182,15 @@ class CFactorGraph
|
|||||||
return vc->getRepresentativeVariable();
|
return vc->getRepresentativeVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
const VarClusterSet& getVarClusters (void) { return varClusters_; }
|
FactorGraph* getCompressedFactorGraph (void);
|
||||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
|
||||||
|
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||||
|
|
||||||
static bool checkForIdenticalFactors;
|
static bool checkForIdenticalFactors;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void setInitialColors (void);
|
Color getFreeColor (void)
|
||||||
void createGroups (void);
|
{
|
||||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
|
||||||
const Signature& getSignature (const FgVarNode*);
|
|
||||||
const Signature& getSignature (const FgFacNode*);
|
|
||||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
|
||||||
|
|
||||||
Color getFreeColor (void) {
|
|
||||||
++ freeColor_;
|
++ freeColor_;
|
||||||
return freeColor_ - 1;
|
return freeColor_ - 1;
|
||||||
}
|
}
|
||||||
@ -214,6 +218,18 @@ class CFactorGraph
|
|||||||
return vid2VarCluster_.find (vid)->second;
|
return vid2VarCluster_.find (vid)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void setInitialColors (void);
|
||||||
|
|
||||||
|
void createGroups (void);
|
||||||
|
|
||||||
|
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||||
|
|
||||||
|
const Signature& getSignature (const FgVarNode*);
|
||||||
|
|
||||||
|
const Signature& getSignature (const FgFacNode*);
|
||||||
|
|
||||||
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
Color freeColor_;
|
Color freeColor_;
|
||||||
vector<Color> varColors_;
|
vector<Color> varColors_;
|
||||||
vector<Color> factorColors_;
|
vector<Color> factorColors_;
|
||||||
|
@ -20,24 +20,24 @@ CbpSolver::getPosterioriOf (VarId vid)
|
|||||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->nrStates(), Util::noEvidence());
|
probs.resize (var->nrStates(), LogAware::noEvidence());
|
||||||
probs[var->getEvidence()] = Util::withEvidence();
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->nrStates(), Util::multIdenty());
|
probs.resize (var->nrStates(), LogAware::multIdenty());
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const SpLinkSet& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::add (probs, l->getPoweredMessage());
|
Util::add (probs, l->getPoweredMessage());
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
Util::fromLog (probs);
|
Util::fromLog (probs);
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::multiply (probs, l->getPoweredMessage());
|
Util::multiply (probs, l->getPoweredMessage());
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return probs;
|
return probs;
|
||||||
@ -62,7 +62,7 @@ void
|
|||||||
CbpSolver::initializeSolver (void)
|
CbpSolver::initializeSolver (void)
|
||||||
{
|
{
|
||||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||||
if (COLLECT_STATISTICS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
nGroundVars = factorGraph_->getVarNodes().size();
|
nGroundVars = factorGraph_->getVarNodes().size();
|
||||||
nGroundFacs = factorGraph_->getFactorNodes().size();
|
nGroundFacs = factorGraph_->getFactorNodes().size();
|
||||||
const FgVarSet& vars = factorGraph_->getVarNodes();
|
const FgVarSet& vars = factorGraph_->getVarNodes();
|
||||||
@ -82,7 +82,7 @@ CbpSolver::initializeSolver (void)
|
|||||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
||||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
factorGraph_ = lfg_->getCompressedFactorGraph();
|
||||||
|
|
||||||
if (COLLECT_STATISTICS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
unsigned nClusterVars = factorGraph_->getVarNodes().size();
|
unsigned nClusterVars = factorGraph_->getVarNodes().size();
|
||||||
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
|
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
|
||||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
||||||
@ -123,7 +123,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
calculateMessage (links_[i]);
|
calculateMessage (links_[i]);
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
if (DL >= 2 && DL < 5) {
|
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||||
cout << "calculating " << links_[i]->toString() << endl;
|
cout << "calculating " << links_[i]->toString() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -131,7 +131,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned c = 0; c < links_.size(); c++) {
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << endl << "current residuals:" << endl;
|
cout << endl << "current residuals:" << endl;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); it ++) {
|
it != sortedOrder_.end(); it ++) {
|
||||||
@ -142,7 +142,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
SpLink* link = *it;
|
SpLink* link = *it;
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
}
|
}
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
if (link->getResidual() < BpOptions::accuracy) {
|
||||||
@ -159,7 +159,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
for (unsigned j = 0; j < links.size(); j++) {
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
if (links[j]->getVariable() != link->getVariable()) {
|
if (links[j]->getVariable() != link->getVariable()) {
|
||||||
if (DL >= 2 && DL < 5) {
|
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||||
cout << " calculating " << links[j]->toString() << endl;
|
cout << " calculating " << links[j]->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateMessage (links[j]);
|
calculateMessage (links[j]);
|
||||||
@ -174,7 +174,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getVariable() != link->getVariable()) {
|
if (links[i]->getVariable() != link->getVariable()) {
|
||||||
if (DL >= 2 && DL < 5) {
|
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||||
cout << " calculating " << links[i]->toString() << endl;
|
cout << " calculating " << links[i]->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateMessage (links[i]);
|
calculateMessage (links[i]);
|
||||||
@ -196,15 +196,15 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
const FgFacNode* dst = link->getFactor();
|
const FgFacNode* dst = link->getFactor();
|
||||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->nrStates(), Util::noEvidence());
|
msg.resize (src->nrStates(), LogAware::noEvidence());
|
||||||
double value = link->getMessage()[src->getEvidence()];
|
double value = link->getMessage()[src->getEvidence()];
|
||||||
msg[src->getEvidence()] = Util::pow (value, l->getNumberOfEdges() - 1);
|
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1);
|
||||||
} else {
|
} else {
|
||||||
msg = link->getMessage();
|
msg = link->getMessage();
|
||||||
Util::pow (msg, l->getNumberOfEdges() - 1);
|
LogAware::pow (msg, l->getNumberOfEdges() - 1);
|
||||||
}
|
}
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " " << "init: " << Util::parametersToString (msg) << endl;
|
cout << " " << "init: " << msg << endl;
|
||||||
}
|
}
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
const SpLinkSet& links = ninf(src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
@ -219,16 +219,16 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::multiply (msg, l->getPoweredMessage());
|
Util::multiply (msg, l->getPoweredMessage());
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
cout << l->getPoweredMessage() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " result = " << Util::parametersToString (msg) << endl;
|
cout << " result = " << msg << endl;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
}
|
}
|
||||||
@ -241,12 +241,9 @@ CbpSolver::printLinkInformation (void) const
|
|||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
||||||
cout << l->toString() << ":" << endl;
|
cout << l->toString() << ":" << endl;
|
||||||
cout << " curr msg = " ;
|
cout << " curr msg = " << l->getMessage() << endl;
|
||||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
cout << " next msg = " << l->getNextMessage() << endl;
|
||||||
cout << " next msg = " ;
|
cout << " powered = " << l->getPoweredMessage() << endl;
|
||||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
|
||||||
cout << " powered = " ;
|
|
||||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
cout << " residual = " << l->getResidual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,20 +12,21 @@ class CbpSolverLink : public SpLink
|
|||||||
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
||||||
{
|
{
|
||||||
edgeCount_ = c;
|
edgeCount_ = c;
|
||||||
poweredMsg_.resize (vn->nrStates(), Util::one());
|
poweredMsg_.resize (vn->nrStates(), LogAware::one());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||||
|
|
||||||
|
const Params& getPoweredMessage (void) const { return poweredMsg_; }
|
||||||
|
|
||||||
void updateMessage (void)
|
void updateMessage (void)
|
||||||
{
|
{
|
||||||
poweredMsg_ = *nextMsg_;
|
poweredMsg_ = *nextMsg_;
|
||||||
swap (currMsg_, nextMsg_);
|
swap (currMsg_, nextMsg_);
|
||||||
msgSended_ = true;
|
msgSended_ = true;
|
||||||
Util::pow (poweredMsg_, edgeCount_);
|
LogAware::pow (poweredMsg_, edgeCount_);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
|
||||||
const Params& getPoweredMessage (void) const { return poweredMsg_; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Params poweredMsg_;
|
Params poweredMsg_;
|
||||||
unsigned edgeCount_;
|
unsigned edgeCount_;
|
||||||
@ -37,9 +38,11 @@ class CbpSolver : public FgBpSolver
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
|
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
|
||||||
|
|
||||||
~CbpSolver (void);
|
~CbpSolver (void);
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -50,7 +53,6 @@ class CbpSolver : public FgBpSolver
|
|||||||
Params getVar2FactorMsg (const SpLink*) const;
|
Params getVar2FactorMsg (const SpLink*) const;
|
||||||
void printLinkInformation (void) const;
|
void printLinkInformation (void) const;
|
||||||
|
|
||||||
|
|
||||||
CFactorGraph* lfg_;
|
CFactorGraph* lfg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "ConstraintTree.h"
|
#include "ConstraintTree.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CTNode::addChild (CTNode* child, bool updateLevels)
|
CTNode::addChild (CTNode* child, bool updateLevels)
|
||||||
{
|
{
|
||||||
@ -42,6 +43,26 @@ CTNode::removeChild (CTNode* child)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CTNode::removeAndDeleteChild (CTNode* child)
|
||||||
|
{
|
||||||
|
removeChild (child);
|
||||||
|
CTNode::deleteSubtree (child);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CTNode::removeAndDeleteAllChilds (void)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < childs_.size(); i++) {
|
||||||
|
deleteSubtree (childs_[i]);
|
||||||
|
}
|
||||||
|
childs_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SymbolSet
|
SymbolSet
|
||||||
CTNode::childSymbols (void) const
|
CTNode::childSymbols (void) const
|
||||||
{
|
{
|
||||||
@ -66,6 +87,32 @@ CTNode::updateChildLevels (CTNode* n, unsigned level)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
CTNode*
|
||||||
|
CTNode::copySubtree (const CTNode* n)
|
||||||
|
{
|
||||||
|
CTNode* newNode = new CTNode (*n);
|
||||||
|
const CTNodes& childs = n->childs();
|
||||||
|
for (unsigned i = 0; i < childs.size(); i++) {
|
||||||
|
newNode->addChild (copySubtree (childs[i]));
|
||||||
|
}
|
||||||
|
return newNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CTNode::deleteSubtree (CTNode* n)
|
||||||
|
{
|
||||||
|
assert (n);
|
||||||
|
const CTNodes& childs = n->childs();
|
||||||
|
for (unsigned i = 0; i < childs.size(); i++) {
|
||||||
|
deleteSubtree (childs[i]);
|
||||||
|
}
|
||||||
|
delete n;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream& operator<< (ostream &out, const CTNode& n)
|
ostream& operator<< (ostream &out, const CTNode& n)
|
||||||
{
|
{
|
||||||
// out << "(" << n.level() << ") " ;
|
// out << "(" << n.level() << ") " ;
|
||||||
@ -75,6 +122,17 @@ ostream& operator<< (ostream &out, const CTNode& n)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ConstraintTree::ConstraintTree (unsigned nrLvs)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < nrLvs; i++) {
|
||||||
|
logVars_.push_back (LogVar (i));
|
||||||
|
}
|
||||||
|
root_ = new CTNode (0, 0);
|
||||||
|
logVarSet_ = LogVarSet (logVars_);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ConstraintTree::ConstraintTree (const LogVars& logVars)
|
ConstraintTree::ConstraintTree (const LogVars& logVars)
|
||||||
{
|
{
|
||||||
root_ = new CTNode (0, 0);
|
root_ = new CTNode (0, 0);
|
||||||
@ -99,7 +157,7 @@ ConstraintTree::ConstraintTree (const LogVars& logVars,
|
|||||||
|
|
||||||
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
||||||
{
|
{
|
||||||
root_ = copySubtree (ct.root_);
|
root_ = CTNode::copySubtree (ct.root_);
|
||||||
logVars_ = ct.logVars_;
|
logVars_ = ct.logVars_;
|
||||||
logVarSet_ = ct.logVarSet_;
|
logVarSet_ = ct.logVarSet_;
|
||||||
}
|
}
|
||||||
@ -108,7 +166,7 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
|||||||
|
|
||||||
ConstraintTree::~ConstraintTree (void)
|
ConstraintTree::~ConstraintTree (void)
|
||||||
{
|
{
|
||||||
deleteSubtree (root_);
|
CTNode::deleteSubtree (root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -201,20 +259,27 @@ ConstraintTree::moveToBottom (const LogVars& lvs)
|
|||||||
void
|
void
|
||||||
ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
||||||
{
|
{
|
||||||
|
if (logVarSet_.empty()) {
|
||||||
|
delete root_;
|
||||||
|
root_ = CTNode::copySubtree (ct->root());
|
||||||
|
logVars_ = ct->logVars();
|
||||||
|
logVarSet_ = ct->logVarSet();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
LogVarSet intersect = logVarSet_ & ct->logVarSet_;
|
LogVarSet intersect = logVarSet_ & ct->logVarSet_;
|
||||||
if (intersect.empty()) {
|
if (intersect.empty()) {
|
||||||
const CTNodes& childs = ct->root()->childs();
|
const CTNodes& childs = ct->root()->childs();
|
||||||
CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back()));
|
CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back()));
|
||||||
for (unsigned i = 0; i < leafs.size(); i++) {
|
for (unsigned i = 0; i < leafs.size(); i++) {
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
for (unsigned j = 0; j < childs.size(); j++) {
|
||||||
leafs[i]->addChild (copySubtree (childs[j]));
|
leafs[i]->addChild (CTNode::copySubtree (childs[j]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logVars_.insert (logVars_.end(), ct->logVars_.begin(), ct->logVars_.end());
|
Util::addToVector (logVars_, ct->logVars_);
|
||||||
logVarSet_ |= ct->logVarSet_;
|
logVarSet_ |= ct->logVarSet_;
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
moveToBottom (intersect.elements());
|
moveToBottom (intersect.elements());
|
||||||
ct->moveToTop (intersect.elements());
|
ct->moveToTop (intersect.elements());
|
||||||
|
|
||||||
@ -222,25 +287,27 @@ ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
|||||||
CTNodes nodes = getNodesAtLevel (level);
|
CTNodes nodes = getNodesAtLevel (level);
|
||||||
|
|
||||||
Tuples tuples;
|
Tuples tuples;
|
||||||
CTNodes continuationNodes;
|
CTNodes continNodes;
|
||||||
getTuples (ct->root(),
|
getTuples (ct->root(),
|
||||||
Tuples(),
|
Tuples(),
|
||||||
intersect.size(),
|
intersect.size(),
|
||||||
tuples,
|
tuples,
|
||||||
continuationNodes);
|
continNodes);
|
||||||
|
|
||||||
for (unsigned i = 0; i < tuples.size(); i++) {
|
for (unsigned i = 0; i < tuples.size(); i++) {
|
||||||
bool tupleFounded = false;
|
bool tupleFounded = false;
|
||||||
for (unsigned j = 0; j < nodes.size(); j++) {
|
for (unsigned j = 0; j < nodes.size(); j++) {
|
||||||
tupleFounded |= join (nodes[j], tuples[i], 0, continuationNodes[i]);
|
tupleFounded |= join (nodes[j], tuples[i], 0, continNodes[i]);
|
||||||
}
|
}
|
||||||
if (assertWhenNotFound) {
|
if (assertWhenNotFound) {
|
||||||
assert (tupleFounded);
|
assert (tupleFounded);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LogVarSet newLvs = ct->logVarSet_ - intersect;
|
|
||||||
logVars_.insert (logVars_.end(), newLvs.begin(), newLvs.end());
|
LogVars newLvs (ct->logVars().begin() + intersect.size(),
|
||||||
logVarSet_ |= newLvs;
|
ct->logVars().end());
|
||||||
|
Util::addToVector (logVars_, newLvs);
|
||||||
|
logVarSet_ |= LogVarSet (newLvs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -280,6 +347,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
|
|||||||
void
|
void
|
||||||
ConstraintTree::applySubstitution (const Substitution& theta)
|
ConstraintTree::applySubstitution (const Substitution& theta)
|
||||||
{
|
{
|
||||||
|
LogVars discardedLvs = theta.getDiscardedLogVars();
|
||||||
|
for (unsigned i = 0; i < discardedLvs.size(); i++) {
|
||||||
|
remove(discardedLvs[i]);
|
||||||
|
}
|
||||||
for (unsigned i = 0; i < logVars_.size(); i++) {
|
for (unsigned i = 0; i < logVars_.size(); i++) {
|
||||||
logVars_[i] = theta.newNameFor (logVars_[i]);
|
logVars_[i] = theta.newNameFor (logVars_[i]);
|
||||||
}
|
}
|
||||||
@ -308,11 +379,7 @@ ConstraintTree::remove (const LogVarSet& X)
|
|||||||
unsigned level = getLevel (X.front()) - 1;
|
unsigned level = getLevel (X.front()) - 1;
|
||||||
CTNodes nodes = getNodesAtLevel (level);
|
CTNodes nodes = getNodesAtLevel (level);
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
CTNodes childs = nodes[i]->childs();
|
nodes[i]->removeAndDeleteAllChilds();
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
|
||||||
nodes[i]->removeChild (childs[j]);
|
|
||||||
deleteSubtree (childs[j]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
logVars_.resize (logVars_.size() - X.size());
|
logVars_.resize (logVars_.size() - X.size());
|
||||||
logVarSet_ -= X;
|
logVarSet_ -= X;
|
||||||
@ -545,16 +612,16 @@ ConstraintTree::split (
|
|||||||
for (unsigned i = 0; i < commNodes.size(); i++) {
|
for (unsigned i = 0; i < commNodes.size(); i++) {
|
||||||
commCt->root()->addChild (commNodes[i]);
|
commCt->root()->addChild (commNodes[i]);
|
||||||
}
|
}
|
||||||
//cout << commCt->tupleSet() << " + " ;
|
// cout << commCt->tupleSet() << " + " ;
|
||||||
//cout << exclCt->tupleSet() << " = " ;
|
// cout << exclCt->tupleSet() << " = " ;
|
||||||
//cout << tupleSet() << endl << endl;
|
// cout << tupleSet() << endl << endl;
|
||||||
// if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) {
|
// if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) {
|
||||||
// exportToGraphViz ("_fail.dot", true);
|
// exportToGraphViz ("_fail.dot", true);
|
||||||
// commCt->exportToGraphViz ("_fail_comm.dot", true);
|
// commCt->exportToGraphViz ("_fail_comm.dot", true);
|
||||||
// exclCt->exportToGraphViz ("_fail_excl.dot", true);
|
// exclCt->exportToGraphViz ("_fail_excl.dot", true);
|
||||||
// }
|
// }
|
||||||
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
// assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
||||||
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
// assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
||||||
return {commCt, exclCt};
|
return {commCt, exclCt};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -601,36 +668,32 @@ ConstraintTree::jointCountNormalize (
|
|||||||
LogVar X_new1,
|
LogVar X_new1,
|
||||||
LogVar X_new2)
|
LogVar X_new2)
|
||||||
{
|
{
|
||||||
exportToGraphViz ("C.dot", true);
|
|
||||||
commCt->exportToGraphViz ("C_comm.dot", true);
|
|
||||||
exclCt->exportToGraphViz ("C_exlc.dot", true);
|
|
||||||
unsigned N = getConditionalCount (X);
|
unsigned N = getConditionalCount (X);
|
||||||
cout << "My tuples: " << tupleSet() << endl;
|
// cout << "My tuples: " << tupleSet() << endl;
|
||||||
cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
// cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
||||||
cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
||||||
cout << "Counted Lv: " << X << endl;
|
// cout << "Counted Lv: " << X << endl;
|
||||||
cout << "Original N: " << N << endl;
|
// cout << "X_new1: " << X_new1 << endl;
|
||||||
cout << endl;
|
// cout << "X_new2: " << X_new2 << endl;
|
||||||
|
// cout << "Original N: " << N << endl;
|
||||||
|
// cout << endl;
|
||||||
|
|
||||||
ConstraintTrees normCts1 = commCt->countNormalize (X);
|
ConstraintTrees normCts1 = commCt->countNormalize (X);
|
||||||
vector<unsigned> counts1 (normCts1.size());
|
vector<unsigned> counts1 (normCts1.size());
|
||||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||||
counts1[i] = normCts1[i]->getConditionalCount (X);
|
counts1[i] = normCts1[i]->getConditionalCount (X);
|
||||||
cout << "normCts1[" << i << "] #" << counts1[i] ;
|
// cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||||
cout << " " << normCts1[i]->tupleSet() << endl;
|
// cout << " " << normCts1[i]->tupleSet() << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstraintTrees normCts2 = exclCt->countNormalize (X);
|
ConstraintTrees normCts2 = exclCt->countNormalize (X);
|
||||||
vector<unsigned> counts2 (normCts2.size());
|
vector<unsigned> counts2 (normCts2.size());
|
||||||
for (unsigned i = 0; i < normCts2.size(); i++) {
|
for (unsigned i = 0; i < normCts2.size(); i++) {
|
||||||
counts2[i] = normCts2[i]->getConditionalCount (X);
|
counts2[i] = normCts2[i]->getConditionalCount (X);
|
||||||
cout << "normCts2[" << i << "] #" << counts2[i] ;
|
// cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||||
cout << " " << normCts2[i]->tupleSet() << endl;
|
// cout << " " << normCts2[i]->tupleSet() << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
// cout << endl;
|
||||||
|
|
||||||
cout << "1###### " << normCts1.size() << endl;
|
|
||||||
cout << "2###### " << normCts2.size() << endl;
|
|
||||||
|
|
||||||
ConstraintTree* excl1 = 0;
|
ConstraintTree* excl1 = 0;
|
||||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||||
@ -638,7 +701,7 @@ ConstraintTree::jointCountNormalize (
|
|||||||
excl1 = normCts1[i];
|
excl1 = normCts1[i];
|
||||||
normCts1.erase (normCts1.begin() + i);
|
normCts1.erase (normCts1.begin() + i);
|
||||||
counts1.erase (counts1.begin() + i);
|
counts1.erase (counts1.begin() + i);
|
||||||
cout << ">joint-count(" << N << ",0)" << endl;
|
// cout << "joint-count(" << N << ",0)" << endl;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -649,22 +712,21 @@ ConstraintTree::jointCountNormalize (
|
|||||||
excl2 = normCts2[i];
|
excl2 = normCts2[i];
|
||||||
normCts2.erase (normCts2.begin() + i);
|
normCts2.erase (normCts2.begin() + i);
|
||||||
counts2.erase (counts2.begin() + i);
|
counts2.erase (counts2.begin() + i);
|
||||||
cout << ">>joint-count(0," << N << ")" << endl;
|
// cout << "joint-count(0," << N << ")" << endl;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cout << "3###### " << normCts1.size() << endl;
|
|
||||||
cout << "4###### " << normCts2.size() << endl;
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||||
unsigned j;
|
unsigned j;
|
||||||
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
|
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
|
||||||
cout << "joint-count(" << counts1[i] << "," << counts2[j] << ")" << endl;
|
// cout << "joint-count(" << counts1[i] ;
|
||||||
|
// cout << "," << counts2[j] << ")" << endl;
|
||||||
const CTNodes& childs = normCts2[j]->root_->childs();
|
const CTNodes& childs = normCts2[j]->root_->childs();
|
||||||
for (unsigned k = 0; k < childs.size(); k++) {
|
for (unsigned k = 0; k < childs.size(); k++) {
|
||||||
normCts1[i]->root_->addChild (childs[k]);
|
normCts1[i]->root_->addChild (CTNode::copySubtree (childs[k]));
|
||||||
}
|
}
|
||||||
|
delete normCts2[j];
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstraintTrees cts = normCts1;
|
ConstraintTrees cts = normCts1;
|
||||||
@ -683,11 +745,6 @@ ConstraintTree::jointCountNormalize (
|
|||||||
cts.push_back (excl2);
|
cts.push_back (excl2);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < cts.size(); i++) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << "aaacts_" << i + 1 << ".dot" ;
|
|
||||||
cts[i]->exportToGraphViz (ss.str().c_str(), true);
|
|
||||||
}
|
|
||||||
return cts;
|
return cts;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -735,11 +792,11 @@ ConstraintTree::expand (LogVar X)
|
|||||||
unsigned nrSymbols = getConditionalCount (X);
|
unsigned nrSymbols = getConditionalCount (X);
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
Symbols symbols;
|
Symbols symbols;
|
||||||
CTNodes childs = nodes[i]->childs();
|
const CTNodes& childs = nodes[i]->childs();
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
for (unsigned j = 0; j < childs.size(); j++) {
|
||||||
symbols.push_back (childs[j]->symbol());
|
symbols.push_back (childs[j]->symbol());
|
||||||
nodes[i]->removeChild (childs[j]);
|
|
||||||
}
|
}
|
||||||
|
nodes[i]->removeAndDeleteAllChilds();
|
||||||
CTNode* prev = nodes[i];
|
CTNode* prev = nodes[i];
|
||||||
assert (symbols.size() == nrSymbols);
|
assert (symbols.size() == nrSymbols);
|
||||||
for (unsigned j = 0; j < nrSymbols; j++) {
|
for (unsigned j = 0; j < nrSymbols; j++) {
|
||||||
@ -768,7 +825,7 @@ ConstraintTree::ground (LogVar X)
|
|||||||
ConstraintTrees cts;
|
ConstraintTrees cts;
|
||||||
const CTNodes& nodes = root_->childs();
|
const CTNodes& nodes = root_->childs();
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
CTNode* copy = copySubtree (nodes[i]);
|
CTNode* copy = CTNode::copySubtree (nodes[i]);
|
||||||
copy->setSymbol (nodes[i]->symbol());
|
copy->setSymbol (nodes[i]->symbol());
|
||||||
ConstraintTree* newCt = new ConstraintTree (logVars_);
|
ConstraintTree* newCt = new ConstraintTree (logVars_);
|
||||||
newCt->root()->addChild (copy);
|
newCt->root()->addChild (copy);
|
||||||
@ -884,7 +941,7 @@ ConstraintTree::join (
|
|||||||
if (currIdx == tuple.size() - 1) {
|
if (currIdx == tuple.size() - 1) {
|
||||||
const CTNodes& childs = appendNode->childs();
|
const CTNodes& childs = appendNode->childs();
|
||||||
for (unsigned i = 0; i < childs.size(); i++) {
|
for (unsigned i = 0; i < childs.size(); i++) {
|
||||||
n->addChild (copySubtree (childs[i]));
|
n->addChild (CTNode::copySubtree (childs[i]));
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -985,7 +1042,7 @@ ConstraintTree::countNormalize (
|
|||||||
{
|
{
|
||||||
if (n->level() == stopLevel) {
|
if (n->level() == stopLevel) {
|
||||||
return vector<pair<CTNode*, unsigned>>() = {
|
return vector<pair<CTNode*, unsigned>>() = {
|
||||||
make_pair (copySubtree (n), countTuples (n))
|
make_pair (CTNode::copySubtree (n), countTuples (n))
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1004,65 +1061,6 @@ ConstraintTree::countNormalize (
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
void
|
|
||||||
ConstraintTree::split (
|
|
||||||
CTNode* n1,
|
|
||||||
CTNode* n2,
|
|
||||||
CTNodes& nodes,
|
|
||||||
unsigned stopLevel)
|
|
||||||
{
|
|
||||||
CTNodes& childs1 = n1->childs();
|
|
||||||
CTNodes& childs2 = n2->childs();
|
|
||||||
// cout << string (n1->level() * 8, '-') << "Level = " << n1->level() + 1;
|
|
||||||
// cout << ", #I = " << childs1.size();
|
|
||||||
// cout << ", #J = " << childs2.size() << endl;
|
|
||||||
for (unsigned i = 0; i < childs1.size(); i++) {
|
|
||||||
for (unsigned j = 0; j < childs2.size(); j++) {
|
|
||||||
if (childs1[i]->symbol() != childs2[j]->symbol()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (childs1[i]->level() == stopLevel) {
|
|
||||||
CTNode* newNode = copySubtree (childs1[i]);
|
|
||||||
newNode->setSymbol (childs1[i]->symbol());
|
|
||||||
nodes.push_back (newNode);
|
|
||||||
childs1[i]->setSymbol (Symbol::invalid());
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
CTNodes lowerNodes;
|
|
||||||
split (childs1[i], childs2[j], lowerNodes, stopLevel);
|
|
||||||
if (lowerNodes.empty() == false) {
|
|
||||||
CTNode* me = new CTNode (childs1[i]->symbol(), childs1[i]->level());
|
|
||||||
for (unsigned k = 0; k < lowerNodes.size(); k++) {
|
|
||||||
me->addChild (lowerNodes[k]);
|
|
||||||
}
|
|
||||||
nodes.push_back (me);
|
|
||||||
}
|
|
||||||
if (childs1[i]->isLeaf()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < (int)childs1.size(); i++) {
|
|
||||||
// cout << string (n1->level() * 8, '-') << childs1[i];
|
|
||||||
if (childs1[i]->symbol() == Symbol::invalid()) {
|
|
||||||
// cout << " empty, removing..." ;
|
|
||||||
n1->removeChild (childs1[i]);
|
|
||||||
i --;
|
|
||||||
} else if (childs1[i]->isLeaf() &&
|
|
||||||
childs1[i]->level() != stopLevel) {
|
|
||||||
// cout << " leaf, removing..." ;
|
|
||||||
n1->removeChild (childs1[i]);
|
|
||||||
i --;
|
|
||||||
}
|
|
||||||
// cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
ConstraintTree::split (
|
ConstraintTree::split (
|
||||||
@ -1085,7 +1083,7 @@ ConstraintTree::split (
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (childs1[i]->level() == stopLevel) {
|
if (childs1[i]->level() == stopLevel) {
|
||||||
CTNode* newNode = copySubtree (childs1[i]);
|
CTNode* newNode = CTNode::copySubtree (childs1[i]);
|
||||||
nodes.push_back (newNode);
|
nodes.push_back (newNode);
|
||||||
childs1[i]->setSymbol (Symbol::invalid());
|
childs1[i]->setSymbol (Symbol::invalid());
|
||||||
} else {
|
} else {
|
||||||
@ -1103,11 +1101,11 @@ ConstraintTree::split (
|
|||||||
|
|
||||||
for (int i = 0; i < (int)childs1.size(); i++) {
|
for (int i = 0; i < (int)childs1.size(); i++) {
|
||||||
if (childs1[i]->symbol() == Symbol::invalid()) {
|
if (childs1[i]->symbol() == Symbol::invalid()) {
|
||||||
n1->removeChild (childs1[i]);
|
n1->removeAndDeleteChild (childs1[i]);
|
||||||
i --;
|
i --;
|
||||||
} else if (childs1[i]->isLeaf() &&
|
} else if (childs1[i]->isLeaf() &&
|
||||||
childs1[i]->level() != stopLevel) {
|
childs1[i]->level() != stopLevel) {
|
||||||
n1->removeChild (childs1[i]);
|
n1->removeAndDeleteChild (childs1[i]);
|
||||||
i --;
|
i --;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1141,29 +1139,3 @@ ConstraintTree::overlap (
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CTNode*
|
|
||||||
ConstraintTree::copySubtree (const CTNode* n)
|
|
||||||
{
|
|
||||||
CTNode* newNode = new CTNode (*n);
|
|
||||||
const CTNodes& childs = n->childs();
|
|
||||||
for (unsigned i = 0; i < childs.size(); i++) {
|
|
||||||
newNode->addChild (copySubtree (childs[i]));
|
|
||||||
}
|
|
||||||
return newNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ConstraintTree::deleteSubtree (CTNode* n)
|
|
||||||
{
|
|
||||||
assert (n);
|
|
||||||
const CTNodes& childs = n->childs();
|
|
||||||
for (unsigned i = 0; i < childs.size(); i++) {
|
|
||||||
deleteSubtree (childs[i]);
|
|
||||||
}
|
|
||||||
delete n;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -21,7 +21,6 @@ typedef vector<ConstraintTree*> ConstraintTrees;
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CTNode
|
class CTNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -48,9 +47,19 @@ class CTNode
|
|||||||
bool isLeaf (void) const { return childs_.empty(); }
|
bool isLeaf (void) const { return childs_.empty(); }
|
||||||
|
|
||||||
void addChild (CTNode*, bool = true);
|
void addChild (CTNode*, bool = true);
|
||||||
|
|
||||||
void removeChild (CTNode*);
|
void removeChild (CTNode*);
|
||||||
|
|
||||||
|
void removeAndDeleteChild (CTNode*);
|
||||||
|
|
||||||
|
void removeAndDeleteAllChilds (void);
|
||||||
|
|
||||||
SymbolSet childSymbols (void) const;
|
SymbolSet childSymbols (void) const;
|
||||||
|
|
||||||
|
static CTNode* copySubtree (const CTNode*);
|
||||||
|
|
||||||
|
static void deleteSubtree (CTNode*);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void updateChildLevels (CTNode*, unsigned);
|
void updateChildLevels (CTNode*, unsigned);
|
||||||
|
|
||||||
@ -59,17 +68,20 @@ class CTNode
|
|||||||
unsigned level_;
|
unsigned level_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream& operator<< (ostream &out, const CTNode&);
|
ostream& operator<< (ostream &out, const CTNode&);
|
||||||
|
|
||||||
|
|
||||||
class ConstraintTree
|
class ConstraintTree
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
ConstraintTree (unsigned);
|
||||||
|
|
||||||
ConstraintTree (const LogVars&);
|
ConstraintTree (const LogVars&);
|
||||||
|
|
||||||
ConstraintTree (const LogVars&, const Tuples&);
|
ConstraintTree (const LogVars&, const Tuples&);
|
||||||
|
|
||||||
ConstraintTree (const ConstraintTree&);
|
ConstraintTree (const ConstraintTree&);
|
||||||
|
|
||||||
~ConstraintTree (void);
|
~ConstraintTree (void);
|
||||||
|
|
||||||
CTNode* root (void) const { return root_; }
|
CTNode* root (void) const { return root_; }
|
||||||
@ -95,89 +107,90 @@ class ConstraintTree
|
|||||||
}
|
}
|
||||||
|
|
||||||
void addTuple (const Tuple&);
|
void addTuple (const Tuple&);
|
||||||
|
|
||||||
bool containsTuple (const Tuple&);
|
bool containsTuple (const Tuple&);
|
||||||
|
|
||||||
void moveToTop (const LogVars&);
|
void moveToTop (const LogVars&);
|
||||||
|
|
||||||
void moveToBottom (const LogVars&);
|
void moveToBottom (const LogVars&);
|
||||||
|
|
||||||
void join (ConstraintTree*, bool = false);
|
void join (ConstraintTree*, bool = false);
|
||||||
|
|
||||||
unsigned getLevel (LogVar) const;
|
unsigned getLevel (LogVar) const;
|
||||||
|
|
||||||
void rename (LogVar, LogVar);
|
void rename (LogVar, LogVar);
|
||||||
|
|
||||||
void applySubstitution (const Substitution&);
|
void applySubstitution (const Substitution&);
|
||||||
|
|
||||||
void project (const LogVarSet&);
|
void project (const LogVarSet&);
|
||||||
|
|
||||||
void remove (const LogVarSet&);
|
void remove (const LogVarSet&);
|
||||||
|
|
||||||
bool isSingleton (LogVar);
|
bool isSingleton (LogVar);
|
||||||
|
|
||||||
LogVarSet singletons (void);
|
LogVarSet singletons (void);
|
||||||
|
|
||||||
TupleSet tupleSet (unsigned = 0) const;
|
TupleSet tupleSet (unsigned = 0) const;
|
||||||
|
|
||||||
TupleSet tupleSet (const LogVars&);
|
TupleSet tupleSet (const LogVars&);
|
||||||
|
|
||||||
unsigned size (void) const;
|
unsigned size (void) const;
|
||||||
|
|
||||||
unsigned nrSymbols (LogVar);
|
unsigned nrSymbols (LogVar);
|
||||||
|
|
||||||
void exportToGraphViz (const char*, bool = false) const;
|
void exportToGraphViz (const char*, bool = false) const;
|
||||||
|
|
||||||
bool isCountNormalized (const LogVarSet&);
|
bool isCountNormalized (const LogVarSet&);
|
||||||
|
|
||||||
unsigned getConditionalCount (const LogVarSet&);
|
unsigned getConditionalCount (const LogVarSet&);
|
||||||
|
|
||||||
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
||||||
|
|
||||||
bool isCarteesianProduct (const LogVarSet&) const;
|
bool isCarteesianProduct (const LogVarSet&) const;
|
||||||
|
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||||
const Tuple&,
|
const Tuple&, unsigned);
|
||||||
unsigned);
|
|
||||||
|
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||||
const ConstraintTree*,
|
const ConstraintTree*, unsigned) const;
|
||||||
unsigned) const;
|
|
||||||
|
|
||||||
ConstraintTrees countNormalize (const LogVarSet&);
|
ConstraintTrees countNormalize (const LogVarSet&);
|
||||||
|
|
||||||
ConstraintTrees jointCountNormalize (
|
ConstraintTrees jointCountNormalize (
|
||||||
ConstraintTree*,
|
ConstraintTree*, ConstraintTree*, LogVar, LogVar, LogVar);
|
||||||
ConstraintTree*,
|
|
||||||
LogVar,
|
|
||||||
LogVar,
|
|
||||||
LogVar);
|
|
||||||
|
|
||||||
static bool identical (
|
static bool identical (
|
||||||
const ConstraintTree*,
|
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||||
const ConstraintTree*,
|
|
||||||
unsigned);
|
|
||||||
|
|
||||||
static bool overlap (
|
static bool overlap (
|
||||||
const ConstraintTree*,
|
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||||
const ConstraintTree*,
|
|
||||||
unsigned);
|
|
||||||
|
|
||||||
LogVars expand (LogVar);
|
LogVars expand (LogVar);
|
||||||
ConstraintTrees ground (LogVar);
|
ConstraintTrees ground (LogVar);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned countTuples (const CTNode*) const;
|
unsigned countTuples (const CTNode*) const;
|
||||||
|
|
||||||
CTNodes getNodesBelow (CTNode*) const;
|
CTNodes getNodesBelow (CTNode*) const;
|
||||||
|
|
||||||
CTNodes getNodesAtLevel (unsigned) const;
|
CTNodes getNodesAtLevel (unsigned) const;
|
||||||
|
|
||||||
void swapLogVar (LogVar);
|
void swapLogVar (LogVar);
|
||||||
|
|
||||||
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
|
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
|
||||||
|
|
||||||
bool indenticalSubtrees (
|
bool indenticalSubtrees (
|
||||||
const CTNode*,
|
const CTNode*, const CTNode*, bool) const;
|
||||||
const CTNode*,
|
|
||||||
bool) const;
|
|
||||||
|
|
||||||
void getTuples (
|
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
|
||||||
CTNode*,
|
|
||||||
Tuples,
|
|
||||||
unsigned,
|
|
||||||
Tuples&,
|
|
||||||
CTNodes&) const;
|
|
||||||
|
|
||||||
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||||
const CTNode*,
|
const CTNode*, unsigned);
|
||||||
unsigned);
|
|
||||||
|
|
||||||
static void split (
|
static void split (
|
||||||
CTNode*,
|
CTNode*, CTNode*, CTNodes&, unsigned);
|
||||||
CTNode*,
|
|
||||||
CTNodes&,
|
|
||||||
unsigned);
|
|
||||||
|
|
||||||
static bool overlap (const CTNode*, const CTNode*, unsigned);
|
static bool overlap (const CTNode*, const CTNode*, unsigned);
|
||||||
static CTNode* copySubtree (const CTNode*);
|
|
||||||
static void deleteSubtree (CTNode*);
|
|
||||||
|
|
||||||
CTNode* root_;
|
CTNode* root_;
|
||||||
LogVars logVars_;
|
LogVars logVars_;
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
#ifndef HORUS_DISTRIBUTION_H
|
|
||||||
#define HORUS_DISTRIBUTION_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
//TODO die die die die die
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
struct Distribution
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
Distribution (int id)
|
|
||||||
{
|
|
||||||
this->id = id;
|
|
||||||
}
|
|
||||||
|
|
||||||
Distribution (const Params& params, int id = -1)
|
|
||||||
{
|
|
||||||
this->id = id;
|
|
||||||
this->params = params;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateParameters (const Params& params)
|
|
||||||
{
|
|
||||||
this->params = params;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool shared (void)
|
|
||||||
{
|
|
||||||
return id != -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int id;
|
|
||||||
Params params;
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_DISTRIBUTION_H
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
|
|
||||||
|
@ -17,15 +17,15 @@ enum ElimHeuristic
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class EgNode : public VarNode {
|
class EgNode : public VarNode
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
EgNode (VarNode* var) : VarNode (var) { }
|
EgNode (VarNode* var) : VarNode (var) { }
|
||||||
void addNeighbor (EgNode* n)
|
|
||||||
{
|
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
||||||
neighs_.push_back (n);
|
|
||||||
}
|
|
||||||
|
|
||||||
const vector<EgNode*>& neighbors (void) const { return neighs_; }
|
const vector<EgNode*>& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<EgNode*> neighs_;
|
vector<EgNode*> neighs_;
|
||||||
};
|
};
|
||||||
@ -35,6 +35,7 @@ class ElimGraph
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ElimGraph (const BayesNet&);
|
ElimGraph (const BayesNet&);
|
||||||
|
|
||||||
~ElimGraph (void);
|
~ElimGraph (void);
|
||||||
|
|
||||||
void addEdge (EgNode* n1, EgNode* n2)
|
void addEdge (EgNode* n1, EgNode* n2)
|
||||||
@ -43,12 +44,18 @@ class ElimGraph
|
|||||||
n1->addNeighbor (n2);
|
n1->addNeighbor (n2);
|
||||||
n2->addNeighbor (n1);
|
n2->addNeighbor (n1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void addNode (EgNode*);
|
void addNode (EgNode*);
|
||||||
|
|
||||||
EgNode* getEgNode (VarId) const;
|
EgNode* getEgNode (VarId) const;
|
||||||
|
|
||||||
VarIds getEliminatingOrder (const VarIds&);
|
VarIds getEliminatingOrder (const VarIds&);
|
||||||
|
|
||||||
void printGraphicalModel (void) const;
|
void printGraphicalModel (void) const;
|
||||||
|
|
||||||
void exportToGraphViz (const char*, bool = true,
|
void exportToGraphViz (const char*, bool = true,
|
||||||
const VarIds& = VarIds()) const;
|
const VarIds& = VarIds()) const;
|
||||||
|
|
||||||
void setIndexes();
|
void setIndexes();
|
||||||
|
|
||||||
static void setEliminationHeuristic (ElimHeuristic h)
|
static void setEliminationHeuristic (ElimHeuristic h)
|
||||||
@ -58,13 +65,18 @@ class ElimGraph
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
EgNode* getLowestCostNode (void) const;
|
EgNode* getLowestCostNode (void) const;
|
||||||
unsigned getNeighborsCost (const EgNode*) const;
|
|
||||||
unsigned getWeightCost (const EgNode*) const;
|
|
||||||
unsigned getFillCost (const EgNode*) const;
|
|
||||||
unsigned getWeightedFillCost (const EgNode*) const;
|
|
||||||
void connectAllNeighbors (const EgNode*);
|
|
||||||
bool neighbors (const EgNode*, const EgNode*) const;
|
|
||||||
|
|
||||||
|
unsigned getNeighborsCost (const EgNode*) const;
|
||||||
|
|
||||||
|
unsigned getWeightCost (const EgNode*) const;
|
||||||
|
|
||||||
|
unsigned getFillCost (const EgNode*) const;
|
||||||
|
|
||||||
|
unsigned getWeightedFillCost (const EgNode*) const;
|
||||||
|
|
||||||
|
void connectAllNeighbors (const EgNode*);
|
||||||
|
|
||||||
|
bool neighbors (const EgNode*, const EgNode*) const;
|
||||||
|
|
||||||
vector<EgNode*> nodes_;
|
vector<EgNode*> nodes_;
|
||||||
vector<bool> marked_;
|
vector<bool> marked_;
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
|
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const Factor& g)
|
Factor::Factor (const Factor& g)
|
||||||
@ -18,206 +18,73 @@ Factor::Factor (const Factor& g)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (VarId vid, unsigned nStates)
|
Factor::Factor (VarId vid, unsigned nrStates)
|
||||||
{
|
{
|
||||||
varids_.push_back (vid);
|
args_.push_back (vid);
|
||||||
ranges_.push_back (nStates);
|
ranges_.push_back (nrStates);
|
||||||
dist_ = new Distribution (Params (nStates, 1.0));
|
params_.resize (nrStates, 1.0);
|
||||||
|
distId_ = Util::maxUnsigned();
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const VarNodes& vars)
|
Factor::Factor (const VarNodes& vars)
|
||||||
{
|
{
|
||||||
int nParams = 1;
|
int nrParams = 1;
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
varids_.push_back (vars[i]->varId());
|
args_.push_back (vars[i]->varId());
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
ranges_.push_back (vars[i]->nrStates());
|
||||||
nParams *= vars[i]->nrStates();
|
nrParams *= vars[i]->nrStates();
|
||||||
}
|
}
|
||||||
// create a uniform distribution
|
double val = 1.0 / nrParams;
|
||||||
double val = 1.0 / nParams;
|
params_.resize (nrParams, val);
|
||||||
dist_ = new Distribution (Params (nParams, val));
|
distId_ = Util::maxUnsigned();
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (VarId vid, unsigned nStates, const Params& params)
|
Factor::Factor (
|
||||||
|
VarId vid,
|
||||||
|
unsigned nrStates,
|
||||||
|
const Params& params)
|
||||||
{
|
{
|
||||||
varids_.push_back (vid);
|
args_.push_back (vid);
|
||||||
ranges_.push_back (nStates);
|
ranges_.push_back (nrStates);
|
||||||
dist_ = new Distribution (params);
|
params_ = params;
|
||||||
|
distId_ = Util::maxUnsigned();
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (VarNodes& vars, Distribution* dist)
|
Factor::Factor (
|
||||||
|
const VarNodes& vars,
|
||||||
|
const Params& params,
|
||||||
|
unsigned distId)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
varids_.push_back (vars[i]->varId());
|
args_.push_back (vars[i]->varId());
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
ranges_.push_back (vars[i]->nrStates());
|
||||||
}
|
}
|
||||||
dist_ = dist;
|
params_ = params;
|
||||||
|
distId_ = distId;
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const VarNodes& vars, const Params& params)
|
Factor::Factor (
|
||||||
{
|
const VarIds& vids,
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
varids_.push_back (vars[i]->varId());
|
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
|
||||||
}
|
|
||||||
dist_ = new Distribution (params);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const VarIds& vids,
|
|
||||||
const Ranges& ranges,
|
const Ranges& ranges,
|
||||||
const Params& params)
|
const Params& params)
|
||||||
{
|
{
|
||||||
varids_ = vids;
|
args_ = vids;
|
||||||
ranges_ = ranges;
|
ranges_ = ranges;
|
||||||
dist_ = new Distribution (params);
|
params_ = params;
|
||||||
}
|
distId_ = Util::maxUnsigned();
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
|
|
||||||
|
|
||||||
Factor::~Factor (void)
|
|
||||||
{
|
|
||||||
if (dist_->shared() == false) {
|
|
||||||
delete dist_;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::setParameters (const Params& params)
|
|
||||||
{
|
|
||||||
assert (dist_->params.size() == params.size());
|
|
||||||
dist_->params = params;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::copyFromFactor (const Factor& g)
|
|
||||||
{
|
|
||||||
varids_ = g.getVarIds();
|
|
||||||
ranges_ = g.getRanges();
|
|
||||||
dist_ = new Distribution (g.getParameters());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::multiply (const Factor& g)
|
|
||||||
{
|
|
||||||
if (varids_.size() == 0) {
|
|
||||||
copyFromFactor (g);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const VarIds& g_varids = g.getVarIds();
|
|
||||||
const Ranges& g_ranges = g.getRanges();
|
|
||||||
const Params& g_params = g.getParameters();
|
|
||||||
|
|
||||||
if (varids_ == g_varids) {
|
|
||||||
// optimization: if the factors contain the same set of variables,
|
|
||||||
// we can do a 1 to 1 operation on the parameters
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::add (dist_->params, g_params);
|
|
||||||
} else {
|
|
||||||
Util::multiply (dist_->params, g_params);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bool sharedVars = false;
|
|
||||||
vector<unsigned> gvarpos;
|
|
||||||
for (unsigned i = 0; i < g_varids.size(); i++) {
|
|
||||||
int idx = indexOf (g_varids[i]);
|
|
||||||
if (idx == -1) {
|
|
||||||
insertVariable (g_varids[i], g_ranges[i]);
|
|
||||||
gvarpos.push_back (varids_.size() - 1);
|
|
||||||
} else {
|
|
||||||
sharedVars = true;
|
|
||||||
gvarpos.push_back (idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (sharedVars == false) {
|
|
||||||
// optimization: if the original factors doesn't have common variables,
|
|
||||||
// we don't need to marry the states of the common variables
|
|
||||||
unsigned count = 0;
|
|
||||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
dist_->params[i] += g_params[count];
|
|
||||||
} else {
|
|
||||||
dist_->params[i] *= g_params[count];
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
if (count >= g_params.size()) {
|
|
||||||
count = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
StatesIndexer indexer (ranges_, false);
|
|
||||||
while (indexer.valid()) {
|
|
||||||
unsigned g_li = 0;
|
|
||||||
unsigned prod = 1;
|
|
||||||
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
|
||||||
g_li += indexer[gvarpos[j]] * prod;
|
|
||||||
prod *= g_ranges[j];
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
dist_->params[indexer] += g_params[g_li];
|
|
||||||
} else {
|
|
||||||
dist_->params[indexer] *= g_params[g_li];
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::insertVariable (VarId varId, unsigned nrStates)
|
|
||||||
{
|
|
||||||
assert (indexOf (varId) == -1);
|
|
||||||
Params oldParams = dist_->params;
|
|
||||||
dist_->params.clear();
|
|
||||||
dist_->params.reserve (oldParams.size() * nrStates);
|
|
||||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
|
||||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
|
||||||
dist_->params.push_back (oldParams[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
varids_.push_back (varId);
|
|
||||||
ranges_.push_back (nrStates);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::insertVariables (const VarIds& varIds, const Ranges& ranges)
|
|
||||||
{
|
|
||||||
Params oldParams = dist_->params;
|
|
||||||
unsigned nrStates = 1;
|
|
||||||
for (unsigned i = 0; i < varIds.size(); i++) {
|
|
||||||
assert (indexOf (varIds[i]) == -1);
|
|
||||||
varids_.push_back (varIds[i]);
|
|
||||||
ranges_.push_back (ranges[i]);
|
|
||||||
nrStates *= ranges[i];
|
|
||||||
}
|
|
||||||
dist_->params.clear();
|
|
||||||
dist_->params.reserve (oldParams.size() * nrStates);
|
|
||||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
|
||||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
|
||||||
dist_->params.push_back (oldParams[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -226,10 +93,10 @@ void
|
|||||||
Factor::sumOutAllExcept (VarId vid)
|
Factor::sumOutAllExcept (VarId vid)
|
||||||
{
|
{
|
||||||
assert (indexOf (vid) != -1);
|
assert (indexOf (vid) != -1);
|
||||||
while (varids_.back() != vid) {
|
while (args_.back() != vid) {
|
||||||
sumOutLastVariable();
|
sumOutLastVariable();
|
||||||
}
|
}
|
||||||
while (varids_.front() != vid) {
|
while (args_.front() != vid) {
|
||||||
sumOutFirstVariable();
|
sumOutFirstVariable();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -239,9 +106,10 @@ Factor::sumOutAllExcept (VarId vid)
|
|||||||
void
|
void
|
||||||
Factor::sumOutAllExcept (const VarIds& vids)
|
Factor::sumOutAllExcept (const VarIds& vids)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
for (int i = 0; i < (int)args_.size(); i++) {
|
||||||
if (std::find (vids.begin(), vids.end(), varids_[i]) == vids.end()) {
|
if (Util::contains (vids, args_[i]) == false) {
|
||||||
sumOut (varids_[i]);
|
sumOut (args_[i]);
|
||||||
|
i --;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -254,11 +122,11 @@ Factor::sumOut (VarId vid)
|
|||||||
int idx = indexOf (vid);
|
int idx = indexOf (vid);
|
||||||
assert (idx != -1);
|
assert (idx != -1);
|
||||||
|
|
||||||
if (vid == varids_.back()) {
|
if (vid == args_.back()) {
|
||||||
sumOutLastVariable(); // optimization
|
sumOutLastVariable(); // optimization
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (vid == varids_.front()) {
|
if (vid == args_.front()) {
|
||||||
sumOutFirstVariable(); // optimization
|
sumOutFirstVariable(); // optimization
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -271,7 +139,7 @@ Factor::sumOut (VarId vid)
|
|||||||
// on the left of `var', with the states of the remaining vars fixed
|
// on the left of `var', with the states of the remaining vars fixed
|
||||||
unsigned leftVarOffset = 1;
|
unsigned leftVarOffset = 1;
|
||||||
|
|
||||||
for (int i = varids_.size() - 1; i > idx; i--) {
|
for (int i = args_.size() - 1; i > idx; i--) {
|
||||||
varOffset *= ranges_[i];
|
varOffset *= ranges_[i];
|
||||||
leftVarOffset *= ranges_[i];
|
leftVarOffset *= ranges_[i];
|
||||||
}
|
}
|
||||||
@ -280,25 +148,24 @@ Factor::sumOut (VarId vid)
|
|||||||
unsigned offset = 0;
|
unsigned offset = 0;
|
||||||
unsigned count1 = 0;
|
unsigned count1 = 0;
|
||||||
unsigned count2 = 0;
|
unsigned count2 = 0;
|
||||||
unsigned newpsSize = dist_->params.size() / ranges_[idx];
|
unsigned newpsSize = params_.size() / ranges_[idx];
|
||||||
|
|
||||||
Params newps;
|
Params newps;
|
||||||
newps.reserve (newpsSize);
|
newps.reserve (newpsSize);
|
||||||
Params& params = dist_->params;
|
|
||||||
|
|
||||||
while (newps.size() < newpsSize) {
|
while (newps.size() < newpsSize) {
|
||||||
double sum = Util::addIdenty();
|
double sum = LogAware::addIdenty();
|
||||||
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::logSum (sum, params[offset]);
|
sum = Util::logSum (sum, params_[offset]);
|
||||||
} else {
|
} else {
|
||||||
sum += params[offset];
|
sum += params_[offset];
|
||||||
}
|
}
|
||||||
offset += varOffset;
|
offset += varOffset;
|
||||||
}
|
}
|
||||||
newps.push_back (sum);
|
newps.push_back (sum);
|
||||||
count1 ++;
|
count1 ++;
|
||||||
if (idx == (int)varids_.size() - 1) {
|
if (idx == (int)args_.size() - 1) {
|
||||||
offset = count1 * ranges_[idx];
|
offset = count1 * ranges_[idx];
|
||||||
} else {
|
} else {
|
||||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||||
@ -308,9 +175,9 @@ Factor::sumOut (VarId vid)
|
|||||||
offset = (leftVarOffset * count2) + count1;
|
offset = (leftVarOffset * count2) + count1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
varids_.erase (varids_.begin() + idx);
|
args_.erase (args_.begin() + idx);
|
||||||
ranges_.erase (ranges_.begin() + idx);
|
ranges_.erase (ranges_.begin() + idx);
|
||||||
dist_->params = newps;
|
params_ = newps;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -318,20 +185,19 @@ Factor::sumOut (VarId vid)
|
|||||||
void
|
void
|
||||||
Factor::sumOutFirstVariable (void)
|
Factor::sumOutFirstVariable (void)
|
||||||
{
|
{
|
||||||
Params& params = dist_->params;
|
|
||||||
unsigned nStates = ranges_.front();
|
unsigned nStates = ranges_.front();
|
||||||
unsigned sep = params.size() / nStates;
|
unsigned sep = params_.size() / nStates;
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = sep; i < params.size(); i++) {
|
for (unsigned i = sep; i < params_.size(); i++) {
|
||||||
Util::logSum (params[i % sep], params[i]);
|
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = sep; i < params.size(); i++) {
|
for (unsigned i = sep; i < params_.size(); i++) {
|
||||||
params[i % sep] += params[i];
|
params_[i % sep] += params_[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
params.resize (sep);
|
params_.resize (sep);
|
||||||
varids_.erase (varids_.begin());
|
args_.erase (args_.begin());
|
||||||
ranges_.erase (ranges_.begin());
|
ranges_.erase (ranges_.begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -340,143 +206,55 @@ Factor::sumOutFirstVariable (void)
|
|||||||
void
|
void
|
||||||
Factor::sumOutLastVariable (void)
|
Factor::sumOutLastVariable (void)
|
||||||
{
|
{
|
||||||
Params& params = dist_->params;
|
|
||||||
unsigned nStates = ranges_.back();
|
unsigned nStates = ranges_.back();
|
||||||
unsigned idx1 = 0;
|
unsigned idx1 = 0;
|
||||||
unsigned idx2 = 0;
|
unsigned idx2 = 0;
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
while (idx1 < params.size()) {
|
while (idx1 < params_.size()) {
|
||||||
params[idx2] = params[idx1];
|
params_[idx2] = params_[idx1];
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
for (unsigned j = 1; j < nStates; j++) {
|
for (unsigned j = 1; j < nStates; j++) {
|
||||||
Util::logSum (params[idx2], params[idx1]);
|
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
}
|
}
|
||||||
idx2 ++;
|
idx2 ++;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
while (idx1 < params.size()) {
|
while (idx1 < params_.size()) {
|
||||||
params[idx2] = params[idx1];
|
params_[idx2] = params_[idx1];
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
for (unsigned j = 1; j < nStates; j++) {
|
for (unsigned j = 1; j < nStates; j++) {
|
||||||
params[idx2] += params[idx1];
|
params_[idx2] += params_[idx1];
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
}
|
}
|
||||||
idx2 ++;
|
idx2 ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
params.resize (idx2);
|
params_.resize (idx2);
|
||||||
varids_.pop_back();
|
args_.pop_back();
|
||||||
ranges_.pop_back();
|
ranges_.pop_back();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::orderVariables (void)
|
Factor::multiply (Factor& g)
|
||||||
{
|
{
|
||||||
VarIds sortedVarIds = varids_;
|
if (args_.size() == 0) {
|
||||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
copyFromFactor (g);
|
||||||
reorderVariables (sortedVarIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::reorderVariables (const VarIds& newVarIds)
|
|
||||||
{
|
|
||||||
assert (newVarIds.size() == varids_.size());
|
|
||||||
if (newVarIds == varids_) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
TFactor<VarId>::multiply (g);
|
||||||
Ranges newRanges;
|
|
||||||
vector<unsigned> positions;
|
|
||||||
for (unsigned i = 0; i < newVarIds.size(); i++) {
|
|
||||||
unsigned idx = indexOf (newVarIds[i]);
|
|
||||||
newRanges.push_back (ranges_[idx]);
|
|
||||||
positions.push_back (idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned N = ranges_.size();
|
|
||||||
Params newParams (dist_->params.size());
|
|
||||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
|
||||||
unsigned li = i;
|
|
||||||
// calculate vector index corresponding to linear index
|
|
||||||
vector<unsigned> vi (N);
|
|
||||||
for (int k = N-1; k >= 0; k--) {
|
|
||||||
vi[k] = li % ranges_[k];
|
|
||||||
li /= ranges_[k];
|
|
||||||
}
|
|
||||||
// convert permuted vector index to corresponding linear index
|
|
||||||
unsigned prod = 1;
|
|
||||||
unsigned new_li = 0;
|
|
||||||
for (int k = N-1; k >= 0; k--) {
|
|
||||||
new_li += vi[positions[k]] * prod;
|
|
||||||
prod *= ranges_[positions[k]];
|
|
||||||
}
|
|
||||||
newParams[new_li] = dist_->params[i];
|
|
||||||
}
|
|
||||||
varids_ = newVarIds;
|
|
||||||
ranges_ = newRanges;
|
|
||||||
dist_->params = newParams;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::absorveEvidence (VarId vid, unsigned evidence)
|
Factor::reorderAccordingVarIds (void)
|
||||||
{
|
{
|
||||||
int idx = indexOf (vid);
|
VarIds sortedVarIds = args_;
|
||||||
assert (idx != -1);
|
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||||
|
reorderArguments (sortedVarIds);
|
||||||
Params oldParams = dist_->params;
|
|
||||||
dist_->params.clear();
|
|
||||||
dist_->params.reserve (oldParams.size() / ranges_[idx]);
|
|
||||||
StatesIndexer indexer (ranges_);
|
|
||||||
for (unsigned i = 0; i < evidence; i++) {
|
|
||||||
indexer.increment (idx);
|
|
||||||
}
|
|
||||||
while (indexer.valid()) {
|
|
||||||
dist_->params.push_back (oldParams[indexer]);
|
|
||||||
indexer.incrementExcluding (idx);
|
|
||||||
}
|
|
||||||
varids_.erase (varids_.begin() + idx);
|
|
||||||
ranges_.erase (ranges_.begin() + idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::normalize (void)
|
|
||||||
{
|
|
||||||
Util::normalize (dist_->params);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
Factor::contains (const VarIds& vars) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
if (indexOf (vars[i]) == -1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
Factor::indexOf (VarId vid) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
|
||||||
if (varids_[i] == vid) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -486,9 +264,9 @@ Factor::getLabel (void) const
|
|||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "f(" ;
|
ss << "f(" ;
|
||||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) ss << "," ;
|
if (i != 0) ss << "," ;
|
||||||
ss << VarNode (varids_[i], ranges_[i]).label();
|
ss << VarNode (args_[i], ranges_[i]).label();
|
||||||
}
|
}
|
||||||
ss << ")" ;
|
ss << ")" ;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
@ -500,13 +278,13 @@ void
|
|||||||
Factor::print (void) const
|
Factor::print (void) const
|
||||||
{
|
{
|
||||||
VarNodes vars;
|
VarNodes vars;
|
||||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
vars.push_back (new VarNode (varids_[i], ranges_[i]));
|
vars.push_back (new VarNode (args_[i], ranges_[i]));
|
||||||
}
|
}
|
||||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
cout << "f(" << jointStrings[i] << ")" ;
|
cout << "f(" << jointStrings[i] << ")" ;
|
||||||
cout << " = " << dist_->params[i] << endl;
|
cout << " = " << params_[i] << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
@ -515,3 +293,13 @@ Factor::print (void) const
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::copyFromFactor (const Factor& g)
|
||||||
|
{
|
||||||
|
args_ = g.arguments();
|
||||||
|
ranges_ = g.ranges();
|
||||||
|
params_ = g.params();
|
||||||
|
distId_ = g.distId();
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -3,64 +3,293 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "Distribution.h"
|
|
||||||
#include "VarNode.h"
|
#include "VarNode.h"
|
||||||
|
#include "Indexer.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class Distribution;
|
|
||||||
|
template <typename T>
|
||||||
|
class TFactor
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const vector<T>& arguments (void) const { return args_; }
|
||||||
|
|
||||||
|
vector<T>& arguments (void) { return args_; }
|
||||||
|
|
||||||
|
const Ranges& ranges (void) const { return ranges_; }
|
||||||
|
|
||||||
|
const Params& params (void) const { return params_; }
|
||||||
|
|
||||||
|
Params& params (void) { return params_; }
|
||||||
|
|
||||||
|
unsigned nrArguments (void) const { return args_.size(); }
|
||||||
|
|
||||||
|
unsigned size (void) const { return params_.size(); }
|
||||||
|
|
||||||
|
unsigned distId (void) const { return distId_; }
|
||||||
|
|
||||||
|
void setDistId (unsigned id) { distId_ = id; }
|
||||||
|
|
||||||
|
void setParams (const Params& newParams)
|
||||||
|
{
|
||||||
|
params_ = newParams;
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void normalize (void)
|
||||||
|
{
|
||||||
|
LogAware::normalize (params_);
|
||||||
|
}
|
||||||
|
|
||||||
|
int indexOf (const T& t) const
|
||||||
|
{
|
||||||
|
int idx = -1;
|
||||||
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i] == t) {
|
||||||
|
idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
const T& argument (unsigned idx) const
|
||||||
|
{
|
||||||
|
assert (idx < args_.size());
|
||||||
|
return args_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
T& argument (unsigned idx)
|
||||||
|
{
|
||||||
|
assert (idx < args_.size());
|
||||||
|
return args_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned range (unsigned idx) const
|
||||||
|
{
|
||||||
|
assert (idx < ranges_.size());
|
||||||
|
return ranges_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
void multiply (TFactor<T>& g)
|
||||||
|
{
|
||||||
|
const vector<T>& g_args = g.arguments();
|
||||||
|
const Ranges& g_ranges = g.ranges();
|
||||||
|
const Params& g_params = g.params();
|
||||||
|
if (args_ == g_args) {
|
||||||
|
// optimization: if the factors contain the same set of args,
|
||||||
|
// we can do a 1 to 1 operation on the parameters
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::add (params_, g_params);
|
||||||
|
} else {
|
||||||
|
Util::multiply (params_, g_params);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bool sharedArgs = false;
|
||||||
|
vector<unsigned> gvarpos;
|
||||||
|
for (unsigned i = 0; i < g_args.size(); i++) {
|
||||||
|
int idx = indexOf (g_args[i]);
|
||||||
|
if (idx == -1) {
|
||||||
|
insertArgument (g_args[i], g_ranges[i]);
|
||||||
|
gvarpos.push_back (args_.size() - 1);
|
||||||
|
} else {
|
||||||
|
sharedArgs = true;
|
||||||
|
gvarpos.push_back (idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (sharedArgs == false) {
|
||||||
|
// optimization: if the original factors doesn't have common args,
|
||||||
|
// we don't need to marry the states of the common args
|
||||||
|
unsigned count = 0;
|
||||||
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
params_[i] += g_params[count];
|
||||||
|
} else {
|
||||||
|
params_[i] *= g_params[count];
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
if (count >= g_params.size()) {
|
||||||
|
count = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
StatesIndexer indexer (ranges_, false);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
unsigned g_li = 0;
|
||||||
|
unsigned prod = 1;
|
||||||
|
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
||||||
|
g_li += indexer[gvarpos[j]] * prod;
|
||||||
|
prod *= g_ranges[j];
|
||||||
|
}
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
params_[indexer] += g_params[g_li];
|
||||||
|
} else {
|
||||||
|
params_[indexer] *= g_params[g_li];
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void absorveEvidence (const T& arg, unsigned evidence)
|
||||||
|
{
|
||||||
|
int idx = indexOf (arg);
|
||||||
|
assert (idx != -1);
|
||||||
|
assert (evidence < ranges_[idx]);
|
||||||
|
Params copy = params_;
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (copy.size() / ranges_[idx]);
|
||||||
|
StatesIndexer indexer (ranges_);
|
||||||
|
for (unsigned i = 0; i < evidence; i++) {
|
||||||
|
indexer.increment (idx);
|
||||||
|
}
|
||||||
|
while (indexer.valid()) {
|
||||||
|
params_.push_back (copy[indexer]);
|
||||||
|
indexer.incrementExcluding (idx);
|
||||||
|
}
|
||||||
|
args_.erase (args_.begin() + idx);
|
||||||
|
ranges_.erase (ranges_.begin() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void reorderArguments (const vector<T> newArgs)
|
||||||
|
{
|
||||||
|
assert (newArgs.size() == args_.size());
|
||||||
|
if (newArgs == args_) {
|
||||||
|
return; // already in the wanted order
|
||||||
|
}
|
||||||
|
Ranges newRanges;
|
||||||
|
vector<unsigned> positions;
|
||||||
|
for (unsigned i = 0; i < newArgs.size(); i++) {
|
||||||
|
unsigned idx = indexOf (newArgs[i]);
|
||||||
|
newRanges.push_back (ranges_[idx]);
|
||||||
|
positions.push_back (idx);
|
||||||
|
}
|
||||||
|
unsigned N = ranges_.size();
|
||||||
|
Params newParams (params_.size());
|
||||||
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
|
unsigned li = i;
|
||||||
|
// calculate vector index corresponding to linear index
|
||||||
|
vector<unsigned> vi (N);
|
||||||
|
for (int k = N-1; k >= 0; k--) {
|
||||||
|
vi[k] = li % ranges_[k];
|
||||||
|
li /= ranges_[k];
|
||||||
|
}
|
||||||
|
// convert permuted vector index to corresponding linear index
|
||||||
|
unsigned prod = 1;
|
||||||
|
unsigned new_li = 0;
|
||||||
|
for (int k = N - 1; k >= 0; k--) {
|
||||||
|
new_li += vi[positions[k]] * prod;
|
||||||
|
prod *= ranges_[positions[k]];
|
||||||
|
}
|
||||||
|
newParams[new_li] = params_[i];
|
||||||
|
}
|
||||||
|
args_ = newArgs;
|
||||||
|
ranges_ = newRanges;
|
||||||
|
params_ = newParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contains (const T& arg) const
|
||||||
|
{
|
||||||
|
return Util::contains (args_, arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contains (const vector<T>& args) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
|
if (contains (args[i]) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
vector<T> args_;
|
||||||
|
Ranges ranges_;
|
||||||
|
Params params_;
|
||||||
|
unsigned distId_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void insertArgument (const T& arg, unsigned range)
|
||||||
|
{
|
||||||
|
assert (indexOf (arg) == -1);
|
||||||
|
Params copy = params_;
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (copy.size() * range);
|
||||||
|
for (unsigned i = 0; i < copy.size(); i++) {
|
||||||
|
for (unsigned reps = 0; reps < range; reps++) {
|
||||||
|
params_.push_back (copy[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args_.push_back (arg);
|
||||||
|
ranges_.push_back (range);
|
||||||
|
}
|
||||||
|
|
||||||
|
void insertArguments (const vector<T>& args, const Ranges& ranges)
|
||||||
|
{
|
||||||
|
Params copy = params_;
|
||||||
|
unsigned nrStates = 1;
|
||||||
|
for (unsigned i = 0; i < args.size(); i++) {
|
||||||
|
assert (indexOf (args[i]) == -1);
|
||||||
|
args_.push_back (args[i]);
|
||||||
|
ranges_.push_back (ranges[i]);
|
||||||
|
nrStates *= ranges[i];
|
||||||
|
}
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (copy.size() * nrStates);
|
||||||
|
for (unsigned i = 0; i < copy.size(); i++) {
|
||||||
|
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||||
|
params_.push_back (copy[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class Factor
|
|
||||||
|
class Factor : public TFactor<VarId>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Factor (void) { }
|
Factor (void) { }
|
||||||
Factor (const Factor&);
|
|
||||||
Factor (VarId, unsigned);
|
|
||||||
Factor (const VarNodes&);
|
|
||||||
Factor (VarId, unsigned, const Params&);
|
|
||||||
Factor (VarNodes&, Distribution*);
|
|
||||||
Factor (const VarNodes&, const Params&);
|
|
||||||
Factor (const VarIds&, const Ranges&, const Params&);
|
|
||||||
~Factor (void);
|
|
||||||
|
|
||||||
void setParameters (const Params&);
|
Factor (const Factor&);
|
||||||
void copyFromFactor (const Factor& f);
|
|
||||||
void multiply (const Factor&);
|
Factor (VarId, unsigned);
|
||||||
void insertVariable (VarId, unsigned);
|
|
||||||
void insertVariables (const VarIds&, const Ranges&);
|
Factor (const VarNodes&);
|
||||||
|
|
||||||
|
Factor (VarId, unsigned, const Params&);
|
||||||
|
|
||||||
|
Factor (const VarNodes&, const Params&,
|
||||||
|
unsigned = Util::maxUnsigned());
|
||||||
|
|
||||||
|
Factor (const VarIds&, const Ranges&, const Params&);
|
||||||
|
|
||||||
void sumOutAllExcept (VarId);
|
void sumOutAllExcept (VarId);
|
||||||
|
|
||||||
void sumOutAllExcept (const VarIds&);
|
void sumOutAllExcept (const VarIds&);
|
||||||
|
|
||||||
void sumOut (VarId);
|
void sumOut (VarId);
|
||||||
|
|
||||||
void sumOutFirstVariable (void);
|
void sumOutFirstVariable (void);
|
||||||
|
|
||||||
void sumOutLastVariable (void);
|
void sumOutLastVariable (void);
|
||||||
void orderVariables (void);
|
|
||||||
void reorderVariables (const VarIds&);
|
void multiply (Factor&);
|
||||||
void absorveEvidence (VarId, unsigned);
|
|
||||||
void normalize (void);
|
void reorderAccordingVarIds (void);
|
||||||
bool contains (const VarIds&) const;
|
|
||||||
int indexOf (VarId) const;
|
|
||||||
string getLabel (void) const;
|
string getLabel (void) const;
|
||||||
|
|
||||||
void print (void) const;
|
void print (void) const;
|
||||||
|
|
||||||
const VarIds& getVarIds (void) const { return varids_; }
|
|
||||||
const Ranges& getRanges (void) const { return ranges_; }
|
|
||||||
const Params& getParameters (void) const { return dist_->params; }
|
|
||||||
Distribution* getDistribution (void) const { return dist_; }
|
|
||||||
unsigned nrVariables (void) const { return varids_.size(); }
|
|
||||||
unsigned nrParameters() const { return dist_->params.size(); }
|
|
||||||
|
|
||||||
void setDistribution (Distribution* dist)
|
|
||||||
{
|
|
||||||
dist_ = dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void copyFromFactor (const Factor& f);
|
||||||
|
|
||||||
VarIds varids_;
|
|
||||||
Ranges ranges_;
|
|
||||||
Distribution* dist_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_FACTOR_H
|
#endif // HORUS_FACTOR_H
|
||||||
|
@ -53,10 +53,10 @@ FactorGraph::FactorGraph (const BayesNet& bn)
|
|||||||
neighs.push_back (varNodes_[parents[j]->getIndex()]);
|
neighs.push_back (varNodes_[parents[j]->getIndex()]);
|
||||||
}
|
}
|
||||||
FgFacNode* fn = new FgFacNode (
|
FgFacNode* fn = new FgFacNode (
|
||||||
new Factor (neighs, nodes[i]->getDistribution()));
|
new Factor (neighs, nodes[i]->params(), nodes[i]->distId()));
|
||||||
if (orderFactorVariables) {
|
if (orderFactorVariables) {
|
||||||
sort (neighs.begin(), neighs.end(), CompVarId());
|
sort (neighs.begin(), neighs.end(), CompVarId());
|
||||||
fn->factor()->orderVariables();
|
fn->factor()->reorderAccordingVarIds();
|
||||||
}
|
}
|
||||||
addFactor (fn);
|
addFactor (fn);
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
@ -131,10 +131,10 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
|||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||||
unsigned nParams;
|
unsigned nParams;
|
||||||
is >> nParams;
|
is >> nParams;
|
||||||
if (facNodes_[i]->getParameters().size() != nParams) {
|
if (facNodes_[i]->params().size() != nParams) {
|
||||||
cerr << "error: invalid number of parameters for factor " ;
|
cerr << "error: invalid number of parameters for factor " ;
|
||||||
cerr << facNodes_[i]->getLabel() ;
|
cerr << facNodes_[i]->getLabel() ;
|
||||||
cerr << ", expected: " << facNodes_[i]->getParameters().size();
|
cerr << ", expected: " << facNodes_[i]->params().size();
|
||||||
cerr << ", given: " << nParams << endl;
|
cerr << ", given: " << nParams << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
@ -147,7 +147,7 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
|||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::toLog (params);
|
Util::toLog (params);
|
||||||
}
|
}
|
||||||
facNodes_[i]->factor()->setParameters (params);
|
facNodes_[i]->factor()->setParams (params);
|
||||||
}
|
}
|
||||||
is.close();
|
is.close();
|
||||||
setIndexes();
|
setIndexes();
|
||||||
@ -335,21 +335,6 @@ FactorGraph::setIndexes (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FactorGraph::freeDistributions (void)
|
|
||||||
{
|
|
||||||
set<Distribution*> dists;
|
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
|
||||||
dists.insert (facNodes_[i]->factor()->getDistribution());
|
|
||||||
}
|
|
||||||
for (set<Distribution*>::iterator it = dists.begin();
|
|
||||||
it != dists.end(); it++) {
|
|
||||||
delete *it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::printGraphicalModel (void) const
|
FactorGraph::printGraphicalModel (void) const
|
||||||
{
|
{
|
||||||
@ -440,7 +425,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
Params params = facNodes_[i]->getParameters();
|
Params params = facNodes_[i]->params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
@ -477,7 +462,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
|||||||
out << factorVars[j]->nrStates() << " " ;
|
out << factorVars[j]->nrStates() << " " ;
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
Params params = facNodes_[i]->factor()->getParameters();
|
Params params = facNodes_[i]->factor()->params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "GraphicalModel.h"
|
||||||
#include "Distribution.h"
|
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
@ -13,18 +12,21 @@ using namespace std;
|
|||||||
class BayesNet;
|
class BayesNet;
|
||||||
class FgFacNode;
|
class FgFacNode;
|
||||||
|
|
||||||
|
|
||||||
class FgVarNode : public VarNode
|
class FgVarNode : public VarNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
|
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
|
||||||
|
|
||||||
FgVarNode (const VarNode* v) : VarNode (v) { }
|
FgVarNode (const VarNode* v) : VarNode (v) { }
|
||||||
|
|
||||||
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
|
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
|
||||||
|
|
||||||
const FgFacSet& neighbors (void) const { return neighs_; }
|
const FgFacSet& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
||||||
// members
|
|
||||||
FgFacSet neighs_;
|
FgFacSet neighs_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -32,13 +34,18 @@ class FgVarNode : public VarNode
|
|||||||
class FgFacNode
|
class FgFacNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgFacNode (const FgFacNode* fn) {
|
FgFacNode (const FgFacNode* fn)
|
||||||
|
{
|
||||||
factor_ = new Factor (*fn->factor());
|
factor_ = new Factor (*fn->factor());
|
||||||
index_ = -1;
|
index_ = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
||||||
|
|
||||||
Factor* factor() const { return factor_; }
|
Factor* factor() const { return factor_; }
|
||||||
|
|
||||||
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
||||||
|
|
||||||
const FgVarSet& neighbors (void) const { return neighs_; }
|
const FgVarSet& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
int getIndex (void) const
|
int getIndex (void) const
|
||||||
@ -46,28 +53,28 @@ class FgFacNode
|
|||||||
assert (index_ != -1);
|
assert (index_ != -1);
|
||||||
return index_;
|
return index_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setIndex (int index)
|
void setIndex (int index)
|
||||||
{
|
{
|
||||||
index_ = index;
|
index_ = index;
|
||||||
}
|
}
|
||||||
Distribution* getDistribution (void)
|
|
||||||
|
const Params& params (void) const
|
||||||
{
|
{
|
||||||
return factor_->getDistribution();
|
return factor_->params();
|
||||||
}
|
|
||||||
const Params& getParameters (void) const
|
|
||||||
{
|
|
||||||
return factor_->getParameters();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
string getLabel (void)
|
string getLabel (void)
|
||||||
{
|
{
|
||||||
return factor_->getLabel();
|
return factor_->getLabel();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
|
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
|
||||||
|
|
||||||
Factor* factor_;
|
Factor* factor_;
|
||||||
int index_;
|
|
||||||
FgVarSet neighs_;
|
FgVarSet neighs_;
|
||||||
|
int index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -83,28 +90,16 @@ struct CompVarId
|
|||||||
class FactorGraph : public GraphicalModel
|
class FactorGraph : public GraphicalModel
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FactorGraph (void) {};
|
FactorGraph (void) { };
|
||||||
|
|
||||||
FactorGraph (const FactorGraph&);
|
FactorGraph (const FactorGraph&);
|
||||||
|
|
||||||
FactorGraph (const BayesNet&);
|
FactorGraph (const BayesNet&);
|
||||||
|
|
||||||
~FactorGraph (void);
|
~FactorGraph (void);
|
||||||
|
|
||||||
void readFromUaiFormat (const char*);
|
|
||||||
void readFromLibDaiFormat (const char*);
|
|
||||||
void addVariable (FgVarNode*);
|
|
||||||
void addFactor (FgFacNode*);
|
|
||||||
void addEdge (FgVarNode*, FgFacNode*);
|
|
||||||
void addEdge (FgFacNode*, FgVarNode*);
|
|
||||||
VarNode* getVariableNode (unsigned) const;
|
|
||||||
VarNodes getVariableNodes (void) const;
|
|
||||||
bool isTree (void) const;
|
|
||||||
void setIndexes (void);
|
|
||||||
void freeDistributions (void);
|
|
||||||
void printGraphicalModel (void) const;
|
|
||||||
void exportToGraphViz (const char*) const;
|
|
||||||
void exportToUaiFormat (const char*) const;
|
|
||||||
void exportToLibDaiFormat (const char*) const;
|
|
||||||
|
|
||||||
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
||||||
|
|
||||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
||||||
|
|
||||||
FgVarNode* getFgVarNode (VarId vid) const
|
FgVarNode* getFgVarNode (VarId vid) const
|
||||||
@ -117,13 +112,44 @@ class FactorGraph : public GraphicalModel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void readFromUaiFormat (const char*);
|
||||||
|
|
||||||
|
void readFromLibDaiFormat (const char*);
|
||||||
|
|
||||||
|
void addVariable (FgVarNode*);
|
||||||
|
|
||||||
|
void addFactor (FgFacNode*);
|
||||||
|
|
||||||
|
void addEdge (FgVarNode*, FgFacNode*);
|
||||||
|
|
||||||
|
void addEdge (FgFacNode*, FgVarNode*);
|
||||||
|
|
||||||
|
VarNode* getVariableNode (unsigned) const;
|
||||||
|
|
||||||
|
VarNodes getVariableNodes (void) const;
|
||||||
|
|
||||||
|
bool isTree (void) const;
|
||||||
|
|
||||||
|
void setIndexes (void);
|
||||||
|
|
||||||
|
void printGraphicalModel (void) const;
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*) const;
|
||||||
|
|
||||||
|
void exportToUaiFormat (const char*) const;
|
||||||
|
|
||||||
|
void exportToLibDaiFormat (const char*) const;
|
||||||
|
|
||||||
static bool orderFactorVariables;
|
static bool orderFactorVariables;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||||
|
|
||||||
bool containsCycle (void) const;
|
bool containsCycle (void) const;
|
||||||
|
|
||||||
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
||||||
vector<bool>&, vector<bool>&) const;
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
||||||
vector<bool>&, vector<bool>&) const;
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
|
@ -38,11 +38,11 @@ void
|
|||||||
FgBpSolver::runSolver (void)
|
FgBpSolver::runSolver (void)
|
||||||
{
|
{
|
||||||
clock_t start;
|
clock_t start;
|
||||||
if (COLLECT_STATISTICS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
start = clock();
|
start = clock();
|
||||||
}
|
}
|
||||||
runLoopySolver();
|
runLoopySolver();
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
cout << "Sum-Product converged in " ;
|
cout << "Sum-Product converged in " ;
|
||||||
@ -53,18 +53,13 @@ FgBpSolver::runSolver (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
unsigned size = factorGraph_->getVarNodes().size();
|
unsigned size = factorGraph_->getVarNodes().size();
|
||||||
if (COLLECT_STATISTICS) {
|
if (Constants::COLLECT_STATS) {
|
||||||
unsigned nIters = 0;
|
unsigned nIters = 0;
|
||||||
bool loopy = factorGraph_->isTree() == false;
|
bool loopy = factorGraph_->isTree() == false;
|
||||||
if (loopy) nIters = nIters_;
|
if (loopy) nIters = nIters_;
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||||
}
|
}
|
||||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
|
||||||
factorGraph_->exportToGraphViz (ss.str().c_str());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -76,22 +71,22 @@ FgBpSolver::getPosterioriOf (VarId vid)
|
|||||||
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->nrStates(), Util::noEvidence());
|
probs.resize (var->nrStates(), LogAware::noEvidence());
|
||||||
probs[var->getEvidence()] = Util::withEvidence();
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->nrStates(), Util::multIdenty());
|
probs.resize (var->nrStates(), LogAware::multIdenty());
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const SpLinkSet& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Util::add (probs, links[i]->getMessage());
|
Util::add (probs, links[i]->getMessage());
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
Util::fromLog (probs);
|
Util::fromLog (probs);
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Util::multiply (probs, links[i]->getMessage());
|
Util::multiply (probs, links[i]->getMessage());
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return probs;
|
return probs;
|
||||||
@ -102,9 +97,9 @@ FgBpSolver::getPosterioriOf (VarId vid)
|
|||||||
Params
|
Params
|
||||||
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||||
{
|
{
|
||||||
|
int idx = -1;
|
||||||
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
|
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
|
||||||
const FgFacSet& factorNodes = vn->neighbors();
|
const FgFacSet& factorNodes = vn->neighbors();
|
||||||
int idx = -1;
|
|
||||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
||||||
idx = i;
|
idx = i;
|
||||||
@ -114,18 +109,18 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
return getJointByConditioning (jointVarIds);
|
return getJointByConditioning (jointVarIds);
|
||||||
} else {
|
} else {
|
||||||
Factor r (*factorNodes[idx]->factor());
|
Factor res (*factorNodes[idx]->factor());
|
||||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Factor msg (links[i]->getVariable()->varId(),
|
Factor msg (links[i]->getVariable()->varId(),
|
||||||
links[i]->getVariable()->nrStates(),
|
links[i]->getVariable()->nrStates(),
|
||||||
getVar2FactorMsg (links[i]));
|
getVar2FactorMsg (links[i]));
|
||||||
r.multiply (msg);
|
res.multiply (msg);
|
||||||
}
|
}
|
||||||
r.sumOutAllExcept (jointVarIds);
|
res.sumOutAllExcept (jointVarIds);
|
||||||
r.reorderVariables (jointVarIds);
|
res.reorderArguments (jointVarIds);
|
||||||
r.normalize();
|
res.normalize();
|
||||||
Params jointDist = r.getParameters();
|
Params jointDist = res.params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (jointDist);
|
Util::fromLog (jointDist);
|
||||||
}
|
}
|
||||||
@ -144,13 +139,8 @@ FgBpSolver::runLoopySolver (void)
|
|||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||||
|
|
||||||
nIters_ ++;
|
nIters_ ++;
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "****************************************" ;
|
Util::printHeader (" Iteration " + nIters_);
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << endl;
|
|
||||||
cout << " Iteration " << nIters_ << endl;
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,7 +168,7 @@ FgBpSolver::runLoopySolver (void)
|
|||||||
maxResidualSchedule();
|
maxResidualSchedule();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -256,12 +246,12 @@ FgBpSolver::converged (void)
|
|||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
double residual = links_[i]->getResidual();
|
double residual = links_[i]->getResidual();
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
}
|
}
|
||||||
if (residual > BpOptions::accuracy) {
|
if (residual > BpOptions::accuracy) {
|
||||||
converged = false;
|
converged = false;
|
||||||
if (DL == 0) break;
|
if (Constants::DEBUG == 0) break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -283,7 +273,7 @@ FgBpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned c = 0; c < links_.size(); c++) {
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "current residuals:" << endl;
|
cout << "current residuals:" << endl;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); it ++) {
|
it != sortedOrder_.end(); it ++) {
|
||||||
@ -317,9 +307,8 @@ FgBpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "----------------------------------------" ;
|
Util::printDashedLine();
|
||||||
cout << "----------------------------------------" << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -339,7 +328,7 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
|||||||
msgSize *= links[i]->getVariable()->nrStates();
|
msgSize *= links[i]->getVariable()->nrStates();
|
||||||
}
|
}
|
||||||
unsigned repetitions = 1;
|
unsigned repetitions = 1;
|
||||||
Params msgProduct (msgSize, Util::multIdenty());
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
if (links[i]->getVariable() != dst) {
|
if (links[i]->getVariable() != dst) {
|
||||||
@ -354,7 +343,7 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
|||||||
} else {
|
} else {
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
if (links[i]->getVariable() != dst) {
|
if (links[i]->getVariable() != dst) {
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
cout << " message from " << links[i]->getVariable()->label();
|
||||||
cout << ": " << endl;
|
cout << ": " << endl;
|
||||||
}
|
}
|
||||||
@ -368,34 +357,29 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Factor result (src->factor()->getVarIds(),
|
Factor result (src->factor()->arguments(),
|
||||||
src->factor()->getRanges(),
|
src->factor()->ranges(),
|
||||||
msgProduct);
|
msgProduct);
|
||||||
result.multiply (*(src->factor()));
|
result.multiply (*(src->factor()));
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " message product: " ;
|
cout << " message product: " << msgProduct << endl;
|
||||||
cout << Util::parametersToString (msgProduct) << endl;
|
cout << " original factor: " << src->params() << endl;
|
||||||
cout << " original factor: " ;
|
cout << " factor product: " << result.params() << endl;
|
||||||
cout << Util::parametersToString (src->getParameters()) << endl;
|
|
||||||
cout << " factor product: " ;
|
|
||||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
|
||||||
}
|
}
|
||||||
result.sumOutAllExcept (dst->varId());
|
result.sumOutAllExcept (dst->varId());
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " marginalized: " ;
|
cout << " marginalized: " ;
|
||||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
cout << result.params() << endl;
|
||||||
}
|
}
|
||||||
const Params& resultParams = result.getParameters();
|
const Params& resultParams = result.params();
|
||||||
Params& message = link->getNextMessage();
|
Params& message = link->getNextMessage();
|
||||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
for (unsigned i = 0; i < resultParams.size(); i++) {
|
||||||
message[i] = resultParams[i];
|
message[i] = resultParams[i];
|
||||||
}
|
}
|
||||||
Util::normalize (message);
|
LogAware::normalize (message);
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " curr msg: " ;
|
cout << " curr msg: " << link->getMessage() << endl;
|
||||||
cout << Util::parametersToString (link->getMessage()) << endl;
|
cout << " next msg: " << message << endl;
|
||||||
cout << " next msg: " ;
|
|
||||||
cout << Util::parametersToString (message) << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -408,16 +392,16 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
const FgFacNode* dst = link->getFactor();
|
const FgFacNode* dst = link->getFactor();
|
||||||
Params msg;
|
Params msg;
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->nrStates(), Util::noEvidence());
|
msg.resize (src->nrStates(), LogAware::noEvidence());
|
||||||
msg[src->getEvidence()] = Util::withEvidence();
|
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << Util::parametersToString (msg);
|
cout << msg;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
msg.resize (src->nrStates(), Util::one());
|
msg.resize (src->nrStates(), LogAware::one());
|
||||||
}
|
}
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << Util::parametersToString (msg);
|
cout << msg;
|
||||||
}
|
}
|
||||||
const SpLinkSet& links = ninf (src)->getLinks();
|
const SpLinkSet& links = ninf (src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
@ -430,14 +414,14 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
Util::multiply (msg, links[i]->getMessage());
|
Util::multiply (msg, links[i]->getMessage());
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " x " << Util::parametersToString (links[i]->getMessage());
|
cout << " x " << links[i]->getMessage();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " = " << Util::parametersToString (msg);
|
cout << " = " << msg;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
}
|
}
|
||||||
@ -503,9 +487,9 @@ FgBpSolver::printLinkInformation (void) const
|
|||||||
SpLink* l = links_[i];
|
SpLink* l = links_[i];
|
||||||
cout << l->toString() << ":" << endl;
|
cout << l->toString() << ":" << endl;
|
||||||
cout << " curr msg = " ;
|
cout << " curr msg = " ;
|
||||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
cout << l->getMessage() << endl;
|
||||||
cout << " next msg = " ;
|
cout << " next msg = " ;
|
||||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
cout << l->getNextMessage() << endl;
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
cout << " residual = " << l->getResidual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SpLink
|
class SpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -21,15 +20,34 @@ class SpLink
|
|||||||
{
|
{
|
||||||
fac_ = fn;
|
fac_ = fn;
|
||||||
var_ = vn;
|
var_ = vn;
|
||||||
v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
v1_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
|
||||||
v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
v2_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
|
||||||
currMsg_ = &v1_;
|
currMsg_ = &v1_;
|
||||||
nextMsg_ = &v2_;
|
nextMsg_ = &v2_;
|
||||||
msgSended_ = false;
|
msgSended_ = false;
|
||||||
residual_ = 0.0;
|
residual_ = 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual ~SpLink (void) {};
|
virtual ~SpLink (void) { };
|
||||||
|
|
||||||
|
FgFacNode* getFactor (void) const { return fac_; }
|
||||||
|
|
||||||
|
FgVarNode* getVariable (void) const { return var_; }
|
||||||
|
|
||||||
|
const Params& getMessage (void) const { return *currMsg_; }
|
||||||
|
|
||||||
|
Params& getNextMessage (void) { return *nextMsg_; }
|
||||||
|
|
||||||
|
bool messageWasSended (void) const { return msgSended_; }
|
||||||
|
|
||||||
|
double getResidual (void) const { return residual_; }
|
||||||
|
|
||||||
|
void clearResidual (void) { residual_ = 0.0; }
|
||||||
|
|
||||||
|
void updateResidual (void)
|
||||||
|
{
|
||||||
|
residual_ = LogAware::getMaxNorm (v1_,v2_);
|
||||||
|
}
|
||||||
|
|
||||||
virtual void updateMessage (void)
|
virtual void updateMessage (void)
|
||||||
{
|
{
|
||||||
@ -37,11 +55,6 @@ class SpLink
|
|||||||
msgSended_ = true;
|
msgSended_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void updateResidual (void)
|
|
||||||
{
|
|
||||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
|
||||||
}
|
|
||||||
|
|
||||||
string toString (void) const
|
string toString (void) const
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
@ -51,14 +64,6 @@ class SpLink
|
|||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
FgFacNode* getFactor (void) const { return fac_; }
|
|
||||||
FgVarNode* getVariable (void) const { return var_; }
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
|
||||||
Params& getNextMessage (void) { return *nextMsg_; }
|
|
||||||
bool messageWasSended (void) const { return msgSended_; }
|
|
||||||
double getResidual (void) const { return residual_; }
|
|
||||||
void clearResidual (void) { residual_ = 0.0; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
FgFacNode* fac_;
|
FgFacNode* fac_;
|
||||||
FgVarNode* var_;
|
FgVarNode* var_;
|
||||||
@ -70,7 +75,6 @@ class SpLink
|
|||||||
double residual_;
|
double residual_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
typedef vector<SpLink*> SpLinkSet;
|
typedef vector<SpLink*> SpLinkSet;
|
||||||
|
|
||||||
|
|
||||||
@ -79,7 +83,6 @@ class SPNodeInfo
|
|||||||
public:
|
public:
|
||||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||||
const SpLinkSet& getLinks (void) { return links_; }
|
const SpLinkSet& getLinks (void) { return links_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SpLinkSet links_;
|
SpLinkSet links_;
|
||||||
};
|
};
|
||||||
@ -89,52 +92,30 @@ class FgBpSolver : public Solver
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgBpSolver (const FactorGraph&);
|
FgBpSolver (const FactorGraph&);
|
||||||
|
|
||||||
virtual ~FgBpSolver (void);
|
virtual ~FgBpSolver (void);
|
||||||
|
|
||||||
void runSolver (void);
|
void runSolver (void);
|
||||||
|
|
||||||
virtual Params getPosterioriOf (VarId);
|
virtual Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
virtual Params getJointDistributionOf (const VarIds&);
|
virtual Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual void initializeSolver (void);
|
virtual void initializeSolver (void);
|
||||||
|
|
||||||
virtual void createLinks (void);
|
virtual void createLinks (void);
|
||||||
|
|
||||||
virtual void maxResidualSchedule (void);
|
virtual void maxResidualSchedule (void);
|
||||||
|
|
||||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
||||||
|
|
||||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||||
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
virtual void printLinkInformation (void) const;
|
virtual void printLinkInformation (void) const;
|
||||||
|
|
||||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateFactor2VariableMsg (link);
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
}
|
|
||||||
|
|
||||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "calculating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateFactor2VariableMsg (link);
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateMessage (SpLink* link)
|
|
||||||
{
|
|
||||||
link->updateMessage();
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SPNodeInfo* ninf (const FgVarNode* var) const
|
SPNodeInfo* ninf (const FgVarNode* var) const
|
||||||
{
|
{
|
||||||
return varsI_[var->getIndex()];
|
return varsI_[var->getIndex()];
|
||||||
@ -145,7 +126,39 @@ class FgBpSolver : public Solver
|
|||||||
return facsI_[fac->getIndex()];
|
return facsI_[fac->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CompareResidual {
|
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||||
|
{
|
||||||
|
if (Constants::DEBUG >= 3) {
|
||||||
|
cout << "calculating & updating " << link->toString() << endl;
|
||||||
|
}
|
||||||
|
calculateFactor2VariableMsg (link);
|
||||||
|
if (calcResidual) {
|
||||||
|
link->updateResidual();
|
||||||
|
}
|
||||||
|
link->updateMessage();
|
||||||
|
}
|
||||||
|
|
||||||
|
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||||
|
{
|
||||||
|
if (Constants::DEBUG >= 3) {
|
||||||
|
cout << "calculating " << link->toString() << endl;
|
||||||
|
}
|
||||||
|
calculateFactor2VariableMsg (link);
|
||||||
|
if (calcResidual) {
|
||||||
|
link->updateResidual();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void updateMessage (SpLink* link)
|
||||||
|
{
|
||||||
|
link->updateMessage();
|
||||||
|
if (Constants::DEBUG >= 3) {
|
||||||
|
cout << "updating " << link->toString() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CompareResidual
|
||||||
|
{
|
||||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
||||||
{
|
{
|
||||||
return link1->getResidual() > link2->getResidual();
|
return link1->getResidual() > link2->getResidual();
|
||||||
@ -167,8 +180,6 @@ class FgBpSolver : public Solver
|
|||||||
private:
|
private:
|
||||||
void runLoopySolver (void);
|
void runLoopySolver (void);
|
||||||
bool converged (void);
|
bool converged (void);
|
||||||
|
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_FGBPSOLVER_H
|
#endif // HORUS_FGBPSOLVER_H
|
||||||
|
@ -8,7 +8,9 @@
|
|||||||
|
|
||||||
|
|
||||||
vector<LiftedOperator*>
|
vector<LiftedOperator*>
|
||||||
LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
LiftedOperator::getValidOps (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
vector<LiftedOperator*> validOps;
|
vector<LiftedOperator*> validOps;
|
||||||
vector<SumOutOperator*> sumOutOps;
|
vector<SumOutOperator*> sumOutOps;
|
||||||
@ -28,12 +30,15 @@ LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
LiftedOperator::printValidOps (ParfactorList& pfList, const Grounds& query)
|
LiftedOperator::printValidOps (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
vector<LiftedOperator*> validOps;
|
vector<LiftedOperator*> validOps;
|
||||||
validOps = LiftedOperator::getValidOps (pfList, query);
|
validOps = LiftedOperator::getValidOps (pfList, query);
|
||||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||||
cout << "-> " << validOps[i]->toString() << endl;
|
cout << "-> " << validOps[i]->toString() << endl;
|
||||||
|
delete validOps[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,7 +61,7 @@ SumOutOperator::getCost (void)
|
|||||||
pfIter = pfList_.begin();
|
pfIter = pfList_.begin();
|
||||||
while (pfIter != pfList_.end()) {
|
while (pfIter != pfList_.end()) {
|
||||||
if ((*pfIter)->containsGroup (groupSet[i])) {
|
if ((*pfIter)->containsGroup (groupSet[i])) {
|
||||||
int idx = (*pfIter)->indexOfFormulaWithGroup (groupSet[i]);
|
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
|
||||||
cost *= (*pfIter)->range (idx);
|
cost *= (*pfIter)->range (idx);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -77,14 +82,13 @@ SumOutOperator::apply (void)
|
|||||||
pfList_.remove (iters[0]);
|
pfList_.remove (iters[0]);
|
||||||
for (unsigned i = 1; i < iters.size(); i++) {
|
for (unsigned i = 1; i < iters.size(); i++) {
|
||||||
product->multiply (**(iters[i]));
|
product->multiply (**(iters[i]));
|
||||||
delete *(iters[i]);
|
pfList_.removeAndDelete (iters[i]);
|
||||||
pfList_.remove (iters[i]);
|
|
||||||
}
|
}
|
||||||
if (product->nrFormulas() == 1) {
|
if (product->nrArguments() == 1) {
|
||||||
delete product;
|
delete product;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
int fIdx = product->indexOfFormulaWithGroup (group_);
|
int fIdx = product->indexOfGroup (group_);
|
||||||
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
||||||
if (product->constr()->isCountNormalized (excl)) {
|
if (product->constr()->isCountNormalized (excl)) {
|
||||||
product->sumOut (fIdx);
|
product->sumOut (fIdx);
|
||||||
@ -96,21 +100,21 @@ SumOutOperator::apply (void)
|
|||||||
pfList_.add (pfs[i]);
|
pfList_.add (pfs[i]);
|
||||||
}
|
}
|
||||||
delete product;
|
delete product;
|
||||||
pfList_.shatter();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<SumOutOperator*>
|
vector<SumOutOperator*>
|
||||||
SumOutOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
SumOutOperator::getValidOps (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
vector<SumOutOperator*> validOps;
|
vector<SumOutOperator*> validOps;
|
||||||
set<unsigned> allGroups;
|
set<unsigned> allGroups;
|
||||||
ParfactorList::const_iterator it = pfList.begin();
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
while (it != pfList.end()) {
|
while (it != pfList.end()) {
|
||||||
assert (*it);
|
const ProbFormulas& formulas = (*it)->arguments();
|
||||||
const ProbFormulas& formulas = (*it)->formulas();
|
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||||
allGroups.insert (formulas[i].group());
|
allGroups.insert (formulas[i].group());
|
||||||
}
|
}
|
||||||
@ -134,8 +138,8 @@ SumOutOperator::toString (void)
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
vector<ParfactorList::iterator> pfIters;
|
vector<ParfactorList::iterator> pfIters;
|
||||||
pfIters = parfactorsWithGroup (pfList_, group_);
|
pfIters = parfactorsWithGroup (pfList_, group_);
|
||||||
int idx = (*pfIters[0])->indexOfFormulaWithGroup (group_);
|
int idx = (*pfIters[0])->indexOfGroup (group_);
|
||||||
ProbFormula f = (*pfIters[0])->formula (idx);
|
ProbFormula f = (*pfIters[0])->argument (idx);
|
||||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||||
@ -158,8 +162,8 @@ SumOutOperator::validOp (
|
|||||||
}
|
}
|
||||||
unordered_map<unsigned, unsigned> groupToRange;
|
unordered_map<unsigned, unsigned> groupToRange;
|
||||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
for (unsigned i = 0; i < pfIters.size(); i++) {
|
||||||
int fIdx = (*pfIters[i])->indexOfFormulaWithGroup (group);
|
int fIdx = (*pfIters[i])->indexOfGroup (group);
|
||||||
if ((*pfIters[i])->formulas()[fIdx].contains (
|
if ((*pfIters[i])->argument (fIdx).contains (
|
||||||
(*pfIters[i])->elimLogVars()) == false) {
|
(*pfIters[i])->elimLogVars()) == false) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -206,8 +210,8 @@ SumOutOperator::isToEliminate (
|
|||||||
unsigned group,
|
unsigned group,
|
||||||
const Grounds& query)
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
int fIdx = g->indexOfFormulaWithGroup (group);
|
int fIdx = g->indexOfGroup (group);
|
||||||
const ProbFormula& formula = g->formula (fIdx);
|
const ProbFormula& formula = g->argument (fIdx);
|
||||||
bool toElim = true;
|
bool toElim = true;
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
if (formula.functor() == query[i].functor() &&
|
if (formula.functor() == query[i].functor() &&
|
||||||
@ -228,7 +232,7 @@ unsigned
|
|||||||
CountingOperator::getCost (void)
|
CountingOperator::getCost (void)
|
||||||
{
|
{
|
||||||
unsigned cost = 0;
|
unsigned cost = 0;
|
||||||
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
|
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||||
unsigned range = (*pfIter_)->range (fIdx);
|
unsigned range = (*pfIter_)->range (fIdx);
|
||||||
unsigned size = (*pfIter_)->size() / range;
|
unsigned size = (*pfIter_)->size() / range;
|
||||||
TinySet<unsigned> counts;
|
TinySet<unsigned> counts;
|
||||||
@ -247,18 +251,19 @@ CountingOperator::apply (void)
|
|||||||
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
||||||
(*pfIter_)->countConvert (X_);
|
(*pfIter_)->countConvert (X_);
|
||||||
} else {
|
} else {
|
||||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
Parfactor* pf = *pfIter_;
|
||||||
|
pfList_.remove (pfIter_);
|
||||||
|
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||||
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
|
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
|
||||||
(*pfIter_)->countedLogVars() | X_);
|
pfs[i]->countedLogVars() | X_);
|
||||||
if (condCount > 1 && cartProduct) {
|
if (condCount > 1 && cartProduct) {
|
||||||
pfs[i]->countConvert (X_);
|
pfs[i]->countConvert (X_);
|
||||||
}
|
}
|
||||||
pfList_.add (pfs[i]);
|
pfList_.add (pfs[i]);
|
||||||
}
|
}
|
||||||
pfList_.deleteAndRemove (pfIter_);
|
delete pf;
|
||||||
pfList_.shatter();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,14 +294,17 @@ CountingOperator::toString (void)
|
|||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "count convert " << X_ << " in " ;
|
ss << "count convert " << X_ << " in " ;
|
||||||
ss << (*pfIter_)->getHeaderString();
|
ss << (*pfIter_)->getLabel();
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
ss << " [cost=" << getCost() << "]" << endl;
|
||||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||||
ss << " º " << pfs[i]->getHeaderString() << endl;
|
ss << " º " << pfs[i]->getLabel() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||||
|
delete pfs[i];
|
||||||
|
}
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -308,8 +316,8 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
|||||||
if (g->nrFormulas (X) != 1) {
|
if (g->nrFormulas (X) != 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
int fIdx = g->indexOfFormulaWithLogVar (X);
|
int fIdx = g->indexOfLogVar (X);
|
||||||
if (g->formulas()[fIdx].isCounting()) {
|
if (g->argument (fIdx).isCounting()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool countNormalized = g->constr()->isCountNormalized (X);
|
bool countNormalized = g->constr()->isCountNormalized (X);
|
||||||
@ -332,10 +340,10 @@ GroundOperator::getCost (void)
|
|||||||
unsigned cost = 0;
|
unsigned cost = 0;
|
||||||
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||||
if (isCountingLv) {
|
if (isCountingLv) {
|
||||||
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
|
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||||
unsigned currSize = (*pfIter_)->size();
|
unsigned currSize = (*pfIter_)->size();
|
||||||
unsigned nrHists = (*pfIter_)->range (fIdx);
|
unsigned nrHists = (*pfIter_)->range (fIdx);
|
||||||
unsigned range = (*pfIter_)->formula(fIdx).range();
|
unsigned range = (*pfIter_)->argument (fIdx).range();
|
||||||
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
||||||
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
||||||
} else {
|
} else {
|
||||||
@ -350,18 +358,17 @@ void
|
|||||||
GroundOperator::apply (void)
|
GroundOperator::apply (void)
|
||||||
{
|
{
|
||||||
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
|
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||||
|
Parfactor* pf = *pfIter_;
|
||||||
|
pfList_.remove (pfIter_);
|
||||||
if (countedLv) {
|
if (countedLv) {
|
||||||
(*pfIter_)->fullExpand (X_);
|
pf->fullExpand (X_);
|
||||||
(*pfIter_)->setNewGroups();
|
pfList_.add (pf);
|
||||||
pfList_.shatter();
|
|
||||||
} else {
|
} else {
|
||||||
ConstraintTrees cts = (*pfIter_)->constr()->ground (X_);
|
ConstraintTrees cts = pf->constr()->ground (X_);
|
||||||
for (unsigned i = 0; i < cts.size(); i++) {
|
for (unsigned i = 0; i < cts.size(); i++) {
|
||||||
Parfactor* newPf = new Parfactor (*pfIter_, cts[i]);
|
pfList_.add (new Parfactor (pf, cts[i]));
|
||||||
pfList_.add (newPf);
|
|
||||||
}
|
}
|
||||||
pfList_.deleteAndRemove (pfIter_);
|
delete pf;
|
||||||
pfList_.shatter();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -393,24 +400,13 @@ GroundOperator::toString (void)
|
|||||||
((*pfIter_)->countedLogVars().contains (X_))
|
((*pfIter_)->countedLogVars().contains (X_))
|
||||||
? ss << "full expanding "
|
? ss << "full expanding "
|
||||||
: ss << "grounding " ;
|
: ss << "grounding " ;
|
||||||
ss << X_ << " in " << (*pfIter_)->getHeaderString();
|
ss << X_ << " in " << (*pfIter_)->getLabel();
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
ss << " [cost=" << getCost() << "]" << endl;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FoveSolver::FoveSolver (const ParfactorList* pfList)
|
|
||||||
{
|
|
||||||
for (ParfactorList::const_iterator it = pfList->begin();
|
|
||||||
it != pfList->end();
|
|
||||||
it ++) {
|
|
||||||
pfList_.addShattered (new Parfactor (**it));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FoveSolver::getPosterioriOf (const Ground& query)
|
FoveSolver::getPosterioriOf (const Ground& query)
|
||||||
{
|
{
|
||||||
@ -422,14 +418,12 @@ FoveSolver::getPosterioriOf (const Ground& query)
|
|||||||
Params
|
Params
|
||||||
FoveSolver::getJointDistributionOf (const Grounds& query)
|
FoveSolver::getJointDistributionOf (const Grounds& query)
|
||||||
{
|
{
|
||||||
shatterAgainstQuery (query);
|
|
||||||
runSolver (query);
|
runSolver (query);
|
||||||
(*pfList_.begin())->normalize();
|
(*pfList_.begin())->normalize();
|
||||||
Params params = (*pfList_.begin())->params();
|
Params params = (*pfList_.begin())->params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
delete *pfList_.begin();
|
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -438,32 +432,38 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
|
|||||||
void
|
void
|
||||||
FoveSolver::absorveEvidence (
|
FoveSolver::absorveEvidence (
|
||||||
ParfactorList& pfList,
|
ParfactorList& pfList,
|
||||||
const ObservedFormulas& obsFormulas)
|
ObservedFormulas& obsFormulas)
|
||||||
{
|
{
|
||||||
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||||
|
Parfactors newPfs;
|
||||||
ParfactorList::iterator it = pfList.begin();
|
ParfactorList::iterator it = pfList.begin();
|
||||||
while (it != pfList.end()) {
|
while (it != pfList.end()) {
|
||||||
bool increment = true;
|
Parfactor* pf = *it;
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
it = pfList.remove (it);
|
||||||
if (absorved (pfList, it, obsFormulas[i])) {
|
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||||
it = pfList.deleteAndRemove (it);
|
if (absorvedPfs.empty() == false) {
|
||||||
increment = false;
|
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||||
break;
|
// just remove pf;
|
||||||
|
} else {
|
||||||
|
Util::addToVector (newPfs, absorvedPfs);
|
||||||
}
|
}
|
||||||
}
|
delete pf;
|
||||||
if (increment) {
|
} else {
|
||||||
|
it = pfList.insertShattered (it, pf);
|
||||||
++ it;
|
++ it;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pfList.shatter();
|
pfList.add (newPfs);
|
||||||
if (obsFormulas.empty() == false) {
|
}
|
||||||
cout << "*******************************************************" << endl;
|
if (Constants::DEBUG > 1 && obsFormulas.empty() == false) {
|
||||||
|
Util::printAsteriskLine();
|
||||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||||
cout << " -> " << *obsFormulas[i] << endl;
|
cout << " -> " << obsFormulas[i] << endl;
|
||||||
}
|
|
||||||
cout << "*******************************************************" << endl;
|
|
||||||
}
|
}
|
||||||
|
Util::printAsteriskLine();
|
||||||
pfList.print();
|
pfList.print();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -473,15 +473,15 @@ FoveSolver::countNormalize (
|
|||||||
Parfactor* g,
|
Parfactor* g,
|
||||||
const LogVarSet& set)
|
const LogVarSet& set)
|
||||||
{
|
{
|
||||||
if (set.empty()) {
|
|
||||||
assert (false); // TODO
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
Parfactors normPfs;
|
Parfactors normPfs;
|
||||||
|
if (set.empty()) {
|
||||||
|
normPfs.push_back (new Parfactor (*g));
|
||||||
|
} else {
|
||||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||||
for (unsigned i = 0; i < normCts.size(); i++) {
|
for (unsigned i = 0; i < normCts.size(); i++) {
|
||||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return normPfs;
|
return normPfs;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -490,17 +490,25 @@ FoveSolver::countNormalize (
|
|||||||
void
|
void
|
||||||
FoveSolver::runSolver (const Grounds& query)
|
FoveSolver::runSolver (const Grounds& query)
|
||||||
{
|
{
|
||||||
|
shatterAgainstQuery (query);
|
||||||
|
runWeakBayesBall (query);
|
||||||
while (true) {
|
while (true) {
|
||||||
cout << "---------------------------------------------------" << endl;
|
if (Constants::DEBUG > 1) {
|
||||||
|
Util::printDashedLine();
|
||||||
pfList_.print();
|
pfList_.print();
|
||||||
LiftedOperator::printValidOps (pfList_, query);
|
LiftedOperator::printValidOps (pfList_, query);
|
||||||
|
}
|
||||||
LiftedOperator* op = getBestOperation (query);
|
LiftedOperator* op = getBestOperation (query);
|
||||||
if (op == 0) {
|
if (op == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (Constants::DEBUG > 1) {
|
||||||
cout << "best operation: " << op->toString() << endl;
|
cout << "best operation: " << op->toString() << endl;
|
||||||
op->apply();
|
|
||||||
}
|
}
|
||||||
|
op->apply();
|
||||||
|
delete op;
|
||||||
|
}
|
||||||
|
assert (pfList_.size() > 0);
|
||||||
if (pfList_.size() > 1) {
|
if (pfList_.size() > 1) {
|
||||||
ParfactorList::iterator pfIter = pfList_.begin();
|
ParfactorList::iterator pfIter = pfList_.begin();
|
||||||
pfIter ++;
|
pfIter ++;
|
||||||
@ -514,26 +522,6 @@ FoveSolver::runSolver (const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FoveSolver::allEliminated (const Grounds&)
|
|
||||||
{
|
|
||||||
ParfactorList::iterator pfIter = pfList_.begin();
|
|
||||||
while (pfIter != pfList_.end()) {
|
|
||||||
const ProbFormulas formulas = (*pfIter)->formulas();
|
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
|
||||||
//bool toElim = false;
|
|
||||||
//for (unsigned j = 0; j < queries.size(); j++) {
|
|
||||||
// if ((*pfIter)->containsGround (queries[i]) == false) {
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
++ pfIter;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LiftedOperator*
|
LiftedOperator*
|
||||||
FoveSolver::getBestOperation (const Grounds& query)
|
FoveSolver::getBestOperation (const Grounds& query)
|
||||||
{
|
{
|
||||||
@ -548,156 +536,170 @@ FoveSolver::getBestOperation (const Grounds& query)
|
|||||||
bestCost = cost;
|
bestCost = cost;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||||
|
if (validOps[i] != bestOp) {
|
||||||
|
delete validOps[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
return bestOp;
|
return bestOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||||
|
{
|
||||||
|
queue<unsigned> todo; // groups to process
|
||||||
|
set<unsigned> done; // processed or in queue
|
||||||
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
int group = (*it)->findGroup (query[i]);
|
||||||
|
if (group != -1) {
|
||||||
|
todo.push (group);
|
||||||
|
done.insert (group);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
set<Parfactor*> requiredPfs;
|
||||||
|
while (todo.empty() == false) {
|
||||||
|
unsigned group = todo.front();
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if (Util::contains (requiredPfs, *it) == false &&
|
||||||
|
(*it)->containsGroup (group)) {
|
||||||
|
vector<unsigned> groups = (*it)->getAllGroups();
|
||||||
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
|
if (Util::contains (done, groups[i]) == false) {
|
||||||
|
todo.push (groups[i]);
|
||||||
|
done.insert (groups[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
requiredPfs.insert (*it);
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
todo.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if (Util::contains (requiredPfs, *it) == false) {
|
||||||
|
it = pfList_.removeAndDelete (it);
|
||||||
|
} else {
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Constants::DEBUG > 1) {
|
||||||
|
Util::printHeader ("REQUIRED PARFACTORS");
|
||||||
|
pfList_.print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||||
{
|
{
|
||||||
// return;
|
return ;
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
if (query[i].isAtom()) {
|
if (query[i].isAtom()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
ParfactorList pfListCopy = pfList_;
|
Parfactors newPfs;
|
||||||
pfList_.clear();
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
for (ParfactorList::iterator it = pfListCopy.begin();
|
while (it != pfList_.end()) {
|
||||||
it != pfListCopy.end(); ++ it) {
|
if ((*it)->containsGround (query[i])) {
|
||||||
Parfactor* pf = *it;
|
|
||||||
if (pf->containsGround (query[i])) {
|
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split =
|
std::pair<ConstraintTree*, ConstraintTree*> split =
|
||||||
pf->constr()->split (query[i].args(), query[i].arity());
|
(*it)->constr()->split (query[i].args(), query[i].arity());
|
||||||
ConstraintTree* commCt = split.first;
|
ConstraintTree* commCt = split.first;
|
||||||
ConstraintTree* exclCt = split.second;
|
ConstraintTree* exclCt = split.second;
|
||||||
pfList_.add (new Parfactor (pf, commCt));
|
newPfs.push_back (new Parfactor (*it, commCt));
|
||||||
if (exclCt->empty() == false) {
|
if (exclCt->empty() == false) {
|
||||||
pfList_.add (new Parfactor (pf, exclCt));
|
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||||
} else {
|
} else {
|
||||||
delete exclCt;
|
delete exclCt;
|
||||||
}
|
}
|
||||||
delete pf;
|
it = pfList_.removeAndDelete (it);
|
||||||
} else {
|
} else {
|
||||||
pfList_.add (pf);
|
++ it;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pfList_.shatter();
|
pfList_.add (newPfs);
|
||||||
}
|
}
|
||||||
|
if (Constants::DEBUG > 1) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
cout << "*******************************************************" << endl;
|
Util::printAsteriskLine();
|
||||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
cout << " -> " << query[i] << endl;
|
cout << " -> " << query[i] << endl;
|
||||||
}
|
}
|
||||||
cout << "*******************************************************" << endl;
|
Util::printAsteriskLine();
|
||||||
pfList_.print();
|
pfList_.print();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
Parfactors
|
||||||
FoveSolver::absorved (
|
FoveSolver::absorve (
|
||||||
ParfactorList& pfList,
|
ObservedFormula& obsFormula,
|
||||||
ParfactorList::iterator pfIter,
|
Parfactor* g)
|
||||||
const ObservedFormula* obsFormula)
|
|
||||||
{
|
{
|
||||||
Parfactors absorvedPfs;
|
Parfactors absorvedPfs;
|
||||||
Parfactor* g = *pfIter;
|
const ProbFormulas& formulas = g->arguments();
|
||||||
const ProbFormulas& formulas = g->formulas();
|
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||||
if (obsFormula->functor() == formulas[i].functor() &&
|
if (obsFormula.functor() == formulas[i].functor() &&
|
||||||
obsFormula->arity() == formulas[i].arity()) {
|
obsFormula.arity() == formulas[i].arity()) {
|
||||||
|
|
||||||
if (obsFormula->isAtom()) {
|
if (obsFormula.isAtom()) {
|
||||||
if (formulas.size() > 1) {
|
if (formulas.size() > 1) {
|
||||||
g->absorveEvidence (i, obsFormula->evidence());
|
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||||
} else {
|
} else {
|
||||||
return true;
|
// hack to erase parfactor g
|
||||||
|
absorvedPfs.push_back (0);
|
||||||
}
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
g->constr()->moveToTop (formulas[i].logVars());
|
g->constr()->moveToTop (formulas[i].logVars());
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> res
|
std::pair<ConstraintTree*, ConstraintTree*> res
|
||||||
= g->constr()->split (obsFormula->constr(), formulas[i].arity());
|
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
|
||||||
ConstraintTree* commCt = res.first;
|
ConstraintTree* commCt = res.first;
|
||||||
ConstraintTree* exclCt = res.second;
|
ConstraintTree* exclCt = res.second;
|
||||||
|
|
||||||
if (commCt->empty()) {
|
if (commCt->empty() == false) {
|
||||||
delete commCt;
|
|
||||||
delete exclCt;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (exclCt->empty() == false) {
|
|
||||||
pfList.add (new Parfactor (g, exclCt));
|
|
||||||
} else {
|
|
||||||
delete exclCt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (formulas.size() > 1) {
|
if (formulas.size() > 1) {
|
||||||
LogVarSet excl = g->exclusiveLogVars (i);
|
LogVarSet excl = g->exclusiveLogVars (i);
|
||||||
Parfactors countNormPfs = countNormalize (g, excl);
|
Parfactors countNormPfs = countNormalize (g, excl);
|
||||||
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
||||||
countNormPfs[j]->absorveEvidence (i, obsFormula->evidence());
|
countNormPfs[j]->absorveEvidence (
|
||||||
|
formulas[i], obsFormula.evidence());
|
||||||
absorvedPfs.push_back (countNormPfs[j]);
|
absorvedPfs.push_back (countNormPfs[j]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
delete commCt;
|
delete commCt;
|
||||||
}
|
}
|
||||||
return true;
|
if (exclCt->empty() == false) {
|
||||||
|
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||||
|
} else {
|
||||||
|
delete exclCt;
|
||||||
|
}
|
||||||
|
if (absorvedPfs.empty()) {
|
||||||
|
// hack to erase parfactor g
|
||||||
|
absorvedPfs.push_back (0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
delete commCt;
|
||||||
|
delete exclCt;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FoveSolver::proper (
|
|
||||||
const ProbFormula& f1,
|
|
||||||
ConstraintTree* c1,
|
|
||||||
const ProbFormula& f2,
|
|
||||||
ConstraintTree* c2)
|
|
||||||
{
|
|
||||||
return disjoint (f1, c1, f2, c2)
|
|
||||||
|| identical (f1, c1, f2, c2);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FoveSolver::identical (
|
|
||||||
const ProbFormula& f1,
|
|
||||||
ConstraintTree* c1,
|
|
||||||
const ProbFormula& f2,
|
|
||||||
ConstraintTree* c2)
|
|
||||||
{
|
|
||||||
if (f1.sameSkeletonAs (f2) == false) {
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
c1->moveToTop (f1.logVars());
|
return absorvedPfs;
|
||||||
c2->moveToTop (f2.logVars());
|
|
||||||
return ConstraintTree::identical (
|
|
||||||
c1, c2, f1.logVars().size());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FoveSolver::disjoint (
|
|
||||||
const ProbFormula& f1,
|
|
||||||
ConstraintTree* c1,
|
|
||||||
const ProbFormula& f2,
|
|
||||||
ConstraintTree* c2)
|
|
||||||
{
|
|
||||||
if (f1.sameSkeletonAs (f2) == false) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
c1->moveToTop (f1.logVars());
|
|
||||||
c2->moveToTop (f2.logVars());
|
|
||||||
return ConstraintTree::overlap (
|
|
||||||
c1, c2, f1.arity()) == false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,10 +9,14 @@ class LiftedOperator
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
virtual unsigned getCost (void) = 0;
|
virtual unsigned getCost (void) = 0;
|
||||||
|
|
||||||
virtual void apply (void) = 0;
|
virtual void apply (void) = 0;
|
||||||
|
|
||||||
virtual string toString (void) = 0;
|
virtual string toString (void) = 0;
|
||||||
|
|
||||||
static vector<LiftedOperator*> getValidOps (
|
static vector<LiftedOperator*> getValidOps (
|
||||||
ParfactorList&, const Grounds&);
|
ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
static void printValidOps (ParfactorList&, const Grounds&);
|
static void printValidOps (ParfactorList&, const Grounds&);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -23,16 +27,24 @@ class SumOutOperator : public LiftedOperator
|
|||||||
public:
|
public:
|
||||||
SumOutOperator (unsigned group, ParfactorList& pfList)
|
SumOutOperator (unsigned group, ParfactorList& pfList)
|
||||||
: group_(group), pfList_(pfList) { }
|
: group_(group), pfList_(pfList) { }
|
||||||
|
|
||||||
unsigned getCost (void);
|
unsigned getCost (void);
|
||||||
|
|
||||||
void apply (void);
|
void apply (void);
|
||||||
|
|
||||||
static vector<SumOutOperator*> getValidOps (
|
static vector<SumOutOperator*> getValidOps (
|
||||||
ParfactorList&, const Grounds&);
|
ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
string toString (void);
|
string toString (void);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static bool validOp (unsigned, ParfactorList&, const Grounds&);
|
static bool validOp (unsigned, ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
static vector<ParfactorList::iterator> parfactorsWithGroup (
|
static vector<ParfactorList::iterator> parfactorsWithGroup (
|
||||||
ParfactorList& pfList, unsigned group);
|
ParfactorList& pfList, unsigned group);
|
||||||
|
|
||||||
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
|
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
|
||||||
|
|
||||||
unsigned group_;
|
unsigned group_;
|
||||||
ParfactorList& pfList_;
|
ParfactorList& pfList_;
|
||||||
};
|
};
|
||||||
@ -47,12 +59,18 @@ class CountingOperator : public LiftedOperator
|
|||||||
LogVar X,
|
LogVar X,
|
||||||
ParfactorList& pfList)
|
ParfactorList& pfList)
|
||||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||||
|
|
||||||
unsigned getCost (void);
|
unsigned getCost (void);
|
||||||
|
|
||||||
void apply (void);
|
void apply (void);
|
||||||
|
|
||||||
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||||
|
|
||||||
string toString (void);
|
string toString (void);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static bool validOp (Parfactor*, LogVar);
|
static bool validOp (Parfactor*, LogVar);
|
||||||
|
|
||||||
ParfactorList::iterator pfIter_;
|
ParfactorList::iterator pfIter_;
|
||||||
LogVar X_;
|
LogVar X_;
|
||||||
ParfactorList& pfList_;
|
ParfactorList& pfList_;
|
||||||
@ -68,10 +86,15 @@ class GroundOperator : public LiftedOperator
|
|||||||
LogVar X,
|
LogVar X,
|
||||||
ParfactorList& pfList)
|
ParfactorList& pfList)
|
||||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||||
|
|
||||||
unsigned getCost (void);
|
unsigned getCost (void);
|
||||||
|
|
||||||
void apply (void);
|
void apply (void);
|
||||||
|
|
||||||
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||||
|
|
||||||
string toString (void);
|
string toString (void);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ParfactorList::iterator pfIter_;
|
ParfactorList::iterator pfIter_;
|
||||||
LogVar X_;
|
LogVar X_;
|
||||||
@ -83,47 +106,27 @@ class GroundOperator : public LiftedOperator
|
|||||||
class FoveSolver
|
class FoveSolver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FoveSolver (const ParfactorList*);
|
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||||
|
|
||||||
Params getPosterioriOf (const Ground&);
|
Params getPosterioriOf (const Ground&);
|
||||||
|
|
||||||
Params getJointDistributionOf (const Grounds&);
|
Params getJointDistributionOf (const Grounds&);
|
||||||
|
|
||||||
static void absorveEvidence (
|
static void absorveEvidence (
|
||||||
ParfactorList& pfList,
|
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||||
const ObservedFormulas& obsFormulas);
|
|
||||||
|
|
||||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runSolver (const Grounds&);
|
void runSolver (const Grounds&);
|
||||||
bool allEliminated (const Grounds&);
|
|
||||||
LiftedOperator* getBestOperation (const Grounds&);
|
LiftedOperator* getBestOperation (const Grounds&);
|
||||||
|
|
||||||
|
void runWeakBayesBall (const Grounds&);
|
||||||
|
|
||||||
void shatterAgainstQuery (const Grounds&);
|
void shatterAgainstQuery (const Grounds&);
|
||||||
|
|
||||||
static bool absorved (
|
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||||
ParfactorList& pfList,
|
|
||||||
ParfactorList::iterator pfIter,
|
|
||||||
const ObservedFormula*);
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
static bool proper (
|
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*,
|
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*);
|
|
||||||
|
|
||||||
static bool identical (
|
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*,
|
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*);
|
|
||||||
|
|
||||||
static bool disjoint (
|
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*,
|
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*);
|
|
||||||
|
|
||||||
ParfactorList pfList_;
|
ParfactorList pfList_;
|
||||||
};
|
};
|
||||||
|
@ -1,22 +1,22 @@
|
|||||||
#ifndef HORUS_GRAPHICALMODEL_H
|
#ifndef HORUS_GRAPHICALMODEL_H
|
||||||
#define HORUS_GRAPHICALMODEL_H
|
#define HORUS_GRAPHICALMODEL_H
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "VarNode.h"
|
#include "VarNode.h"
|
||||||
#include "Distribution.h"
|
#include "Util.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
struct VariableInfo
|
struct VarInfo
|
||||||
{
|
{
|
||||||
VariableInfo (string l, const States& sts)
|
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
||||||
{
|
|
||||||
label = l;
|
|
||||||
states = sts;
|
|
||||||
}
|
|
||||||
string label;
|
string label;
|
||||||
States states;
|
States states;
|
||||||
};
|
};
|
||||||
@ -25,42 +25,39 @@ struct VariableInfo
|
|||||||
class GraphicalModel
|
class GraphicalModel
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
virtual ~GraphicalModel (void) {};
|
virtual ~GraphicalModel (void) { };
|
||||||
|
|
||||||
virtual VarNode* getVariableNode (VarId) const = 0;
|
virtual VarNode* getVariableNode (VarId) const = 0;
|
||||||
|
|
||||||
virtual VarNodes getVariableNodes (void) const = 0;
|
virtual VarNodes getVariableNodes (void) const = 0;
|
||||||
|
|
||||||
virtual void printGraphicalModel (void) const = 0;
|
virtual void printGraphicalModel (void) const = 0;
|
||||||
|
|
||||||
static void addVariableInformation (VarId vid, string label,
|
static void addVariableInformation (
|
||||||
const States& states)
|
VarId vid, string label, const States& states)
|
||||||
{
|
{
|
||||||
assert (varsInfo_.find (vid) == varsInfo_.end());
|
assert (Util::contains (varsInfo_, vid) == false);
|
||||||
varsInfo_.insert (make_pair (vid, VariableInfo (label, states)));
|
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||||
}
|
}
|
||||||
static VariableInfo getVariableInformation (VarId vid)
|
|
||||||
|
static VarInfo getVarInformation (VarId vid)
|
||||||
{
|
{
|
||||||
assert (varsInfo_.find (vid) != varsInfo_.end());
|
assert (Util::contains (varsInfo_, vid));
|
||||||
return varsInfo_.find (vid)->second;
|
return varsInfo_.find (vid)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool variablesHaveInformation (void)
|
static bool variablesHaveInformation (void)
|
||||||
{
|
{
|
||||||
return varsInfo_.size() != 0;
|
return varsInfo_.size() != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void clearVariablesInformation (void)
|
static void clearVariablesInformation (void)
|
||||||
{
|
{
|
||||||
varsInfo_.clear();
|
varsInfo_.clear();
|
||||||
}
|
}
|
||||||
static void addDistribution (unsigned id, Distribution* dist)
|
|
||||||
{
|
|
||||||
distsInfo_[id] = dist;
|
|
||||||
}
|
|
||||||
static void updateDistribution (unsigned id, const Params& params)
|
|
||||||
{
|
|
||||||
distsInfo_[id]->updateParameters (params);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static unordered_map<VarId,VariableInfo> varsInfo_;
|
static unordered_map<VarId,VarInfo> varsInfo_;
|
||||||
static unordered_map<unsigned,Distribution*> distsInfo_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_GRAPHICALMODEL_H
|
#endif // HORUS_GRAPHICALMODEL_H
|
||||||
|
@ -84,16 +84,34 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
|
|||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
HistogramSet::findIndex (
|
HistogramSet::findIndex (
|
||||||
const Histogram& hist,
|
const Histogram& h,
|
||||||
const vector<Histogram>& histograms)
|
const vector<Histogram>& hists)
|
||||||
{
|
{
|
||||||
vector<Histogram>::const_iterator it = std::lower_bound (
|
vector<Histogram>::const_iterator it = std::lower_bound (
|
||||||
histograms.begin(),
|
hists.begin(), hists.end(), h, std::greater<Histogram>());
|
||||||
histograms.end(),
|
assert (it != hists.end() && *it == h);
|
||||||
hist,
|
return std::distance (hists.begin(), it);
|
||||||
std::greater<Histogram>());
|
}
|
||||||
assert (it != histograms.end() && *it == hist);
|
|
||||||
return std::distance (histograms.begin(), it);
|
|
||||||
|
|
||||||
|
vector<double>
|
||||||
|
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||||
|
{
|
||||||
|
HistogramSet hs (N, R);
|
||||||
|
unsigned N_factorial = Util::factorial (N);
|
||||||
|
unsigned H = hs.nrHistograms();
|
||||||
|
vector<double> numAssigns;
|
||||||
|
numAssigns.reserve (H);
|
||||||
|
for (unsigned h = 0; h < H; h++) {
|
||||||
|
unsigned prod = 1;
|
||||||
|
for (unsigned r = 0; r < R; r++) {
|
||||||
|
prod *= Util::factorial (hs[r]);
|
||||||
|
}
|
||||||
|
numAssigns.push_back (LogAware::tl (N_factorial / prod));
|
||||||
|
hs.nextHistogram();
|
||||||
|
}
|
||||||
|
return numAssigns;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,8 +26,9 @@ class HistogramSet
|
|||||||
static unsigned nrHistograms (unsigned, unsigned);
|
static unsigned nrHistograms (unsigned, unsigned);
|
||||||
|
|
||||||
static unsigned findIndex (
|
static unsigned findIndex (
|
||||||
const Histogram&,
|
const Histogram&, const vector<Histogram>&);
|
||||||
const vector<Histogram>&);
|
|
||||||
|
static vector<double> getNumAssigns (unsigned, unsigned);
|
||||||
|
|
||||||
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
||||||
|
|
||||||
|
@ -1,17 +1,9 @@
|
|||||||
#ifndef HORUS_HORUS_H
|
#ifndef HORUS_HORUS_H
|
||||||
#define HORUS_HORUS_H
|
#define HORUS_HORUS_H
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <cassert>
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <fstream>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||||
TypeName(const TypeName&); \
|
TypeName(const TypeName&); \
|
||||||
@ -37,37 +29,36 @@ typedef vector<string> States;
|
|||||||
typedef vector<unsigned> Ranges;
|
typedef vector<unsigned> Ranges;
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
enum InfAlgorithms
|
||||||
extern bool logDomain;
|
{
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
// level of debug information
|
|
||||||
static const unsigned DL = 1;
|
|
||||||
|
|
||||||
static const int NO_EVIDENCE = -1;
|
|
||||||
|
|
||||||
// number of digits to show when printing a parameter
|
|
||||||
static const unsigned PRECISION = 5;
|
|
||||||
|
|
||||||
static const bool COLLECT_STATISTICS = false;
|
|
||||||
|
|
||||||
static const bool EXPORT_TO_GRAPHVIZ = false;
|
|
||||||
static const unsigned EXPORT_MINIMAL_SIZE = 100;
|
|
||||||
|
|
||||||
static const double INF = -numeric_limits<double>::infinity();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace InfAlgorithms {
|
|
||||||
enum InfAlgs
|
|
||||||
{
|
|
||||||
VE, // variable elimination
|
VE, // variable elimination
|
||||||
BN_BP, // bayesian network belief propagation
|
BN_BP, // bayesian network belief propagation
|
||||||
FG_BP, // factor graph belief propagation
|
FG_BP, // factor graph belief propagation
|
||||||
CBP // counting bp solver
|
CBP // counting bp solver
|
||||||
};
|
};
|
||||||
extern InfAlgs infAlgorithm;
|
|
||||||
|
|
||||||
|
namespace Globals {
|
||||||
|
|
||||||
|
extern bool logDomain;
|
||||||
|
|
||||||
|
extern InfAlgorithms infAlgorithm;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace Constants {
|
||||||
|
|
||||||
|
// level of debug information
|
||||||
|
const unsigned DEBUG = 2;
|
||||||
|
|
||||||
|
const int NO_EVIDENCE = -1;
|
||||||
|
|
||||||
|
// number of digits to show when printing a parameter
|
||||||
|
const unsigned PRECISION = 5;
|
||||||
|
|
||||||
|
const bool COLLECT_STATS = false;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,10 +10,6 @@
|
|||||||
#include "FgBpSolver.h"
|
#include "FgBpSolver.h"
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
|
||||||
//#include "TinySet.h"
|
|
||||||
#include "LiftedUtils.h"
|
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
void processArguments (BayesNet&, int, const char* []);
|
void processArguments (BayesNet&, int, const char* []);
|
||||||
@ -24,38 +20,9 @@ const string USAGE = "usage: \
|
|||||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||||
|
|
||||||
|
|
||||||
class Cenas
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
Cenas (int cc)
|
|
||||||
{
|
|
||||||
c = cc;
|
|
||||||
}
|
|
||||||
//operator int (void) const
|
|
||||||
//{
|
|
||||||
// cout << "return int" << endl;
|
|
||||||
// return c;
|
|
||||||
//}
|
|
||||||
operator double (void) const
|
|
||||||
{
|
|
||||||
cout << "return double" << endl;
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
private:
|
|
||||||
int c;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
main (int argc, const char* argv[])
|
main (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
LogVar X = 3;
|
|
||||||
LogVarSet Xs = X;
|
|
||||||
cout << "set: " << X << endl;
|
|
||||||
Cenas c1 (1);
|
|
||||||
Cenas c2 (3);
|
|
||||||
cout << (c1 < c2) << endl;
|
|
||||||
return 0;
|
|
||||||
if (!argv[1]) {
|
if (!argv[1]) {
|
||||||
cerr << "error: no graphical model specified" << endl;
|
cerr << "error: no graphical model specified" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
@ -99,7 +66,6 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
|||||||
cerr << "error: there isn't a variable labeled of " ;
|
cerr << "error: there isn't a variable labeled of " ;
|
||||||
cerr << "`" << arg << "'" ;
|
cerr << "`" << arg << "'" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -109,13 +75,11 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
|||||||
if (label.empty()) {
|
if (label.empty()) {
|
||||||
cerr << "error: missing left argument" << endl;
|
cerr << "error: missing left argument" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
if (state.empty()) {
|
if (state.empty()) {
|
||||||
cerr << "error: missing right argument" << endl;
|
cerr << "error: missing right argument" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
BayesNode* node = bn.getBayesNode (label);
|
BayesNode* node = bn.getBayesNode (label);
|
||||||
@ -127,14 +91,12 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
|||||||
cerr << "is not a valid state for " ;
|
cerr << "is not a valid state for " ;
|
||||||
cerr << "`" << node->label() << "'" ;
|
cerr << "`" << node->label() << "'" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable labeled of " ;
|
cerr << "error: there isn't a variable labeled of " ;
|
||||||
cerr << "`" << label << "'" ;
|
cerr << "`" << label << "'" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -142,7 +104,7 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
|||||||
|
|
||||||
Solver* solver = 0;
|
Solver* solver = 0;
|
||||||
FactorGraph* fg = 0;
|
FactorGraph* fg = 0;
|
||||||
switch (InfAlgorithms::infAlgorithm) {
|
switch (Globals::infAlgorithm) {
|
||||||
case InfAlgorithms::VE:
|
case InfAlgorithms::VE:
|
||||||
fg = new FactorGraph (bn);
|
fg = new FactorGraph (bn);
|
||||||
solver = new VarElimSolver (*fg);
|
solver = new VarElimSolver (*fg);
|
||||||
@ -163,7 +125,6 @@ processArguments (BayesNet& bn, int argc, const char* argv[])
|
|||||||
}
|
}
|
||||||
runSolver (solver, queryVars);
|
runSolver (solver, queryVars);
|
||||||
delete fg;
|
delete fg;
|
||||||
bn.freeDistributions();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -179,7 +140,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
cerr << "error: `" << arg << "' " ;
|
cerr << "error: `" << arg << "' " ;
|
||||||
cerr << "is not a valid variable id" ;
|
cerr << "is not a valid variable id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
VarId vid;
|
VarId vid;
|
||||||
@ -193,7 +153,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << vid << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -201,20 +160,17 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
if (arg.substr (0, pos).empty()) {
|
if (arg.substr (0, pos).empty()) {
|
||||||
cerr << "error: missing left argument" << endl;
|
cerr << "error: missing left argument" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
if (arg.substr (pos + 1).empty()) {
|
if (arg.substr (pos + 1).empty()) {
|
||||||
cerr << "error: missing right argument" << endl;
|
cerr << "error: missing right argument" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
if (!Util::isInteger (arg.substr (0, pos))) {
|
if (!Util::isInteger (arg.substr (0, pos))) {
|
||||||
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
||||||
cerr << "is not a variable id" ;
|
cerr << "is not a variable id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
VarId vid;
|
VarId vid;
|
||||||
@ -227,7 +183,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||||
cerr << "is not a state index" ;
|
cerr << "is not a state index" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
int stateIndex;
|
int stateIndex;
|
||||||
@ -241,28 +196,23 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
cerr << "is not a valid state index for variable " ;
|
cerr << "is not a valid state index for variable " ;
|
||||||
cerr << "`" << var->varId() << "'" ;
|
cerr << "`" << var->varId() << "'" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << vid << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Solver* solver = 0;
|
Solver* solver = 0;
|
||||||
switch (InfAlgorithms::infAlgorithm) {
|
switch (Globals::infAlgorithm) {
|
||||||
case InfAlgorithms::VE:
|
case InfAlgorithms::VE:
|
||||||
solver = new VarElimSolver (fg);
|
solver = new VarElimSolver (fg);
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::BN_BP:
|
case InfAlgorithms::BN_BP:
|
||||||
case InfAlgorithms::FG_BP:
|
case InfAlgorithms::FG_BP:
|
||||||
//cout << "here!" << endl;
|
|
||||||
//fg.printGraphicalModel();
|
|
||||||
//fg.exportToLibDaiFormat ("net.fg");
|
|
||||||
solver = new FgBpSolver (fg);
|
solver = new FgBpSolver (fg);
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::CBP:
|
case InfAlgorithms::CBP:
|
||||||
@ -272,7 +222,6 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
runSolver (solver, queryVars);
|
runSolver (solver, queryVars);
|
||||||
fg.freeDistributions();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,22 +7,26 @@
|
|||||||
|
|
||||||
#include <YapInterface.h>
|
#include <YapInterface.h>
|
||||||
|
|
||||||
|
#include "ParfactorList.h"
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
|
#include "FoveSolver.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
#include "BnBpSolver.h"
|
#include "BnBpSolver.h"
|
||||||
#include "FgBpSolver.h"
|
#include "FgBpSolver.h"
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
#include "FoveSolver.h"
|
|
||||||
#include "ParfactorList.h"
|
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||||
|
|
||||||
|
|
||||||
Params readParams (YAP_Term);
|
Params readParams (YAP_Term);
|
||||||
|
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||||
|
Parfactor* readParfactor (YAP_Term);
|
||||||
|
|
||||||
|
|
||||||
int createLiftedNetwork (void)
|
int createLiftedNetwork (void)
|
||||||
@ -30,14 +34,48 @@ int createLiftedNetwork (void)
|
|||||||
Parfactors parfactors;
|
Parfactors parfactors;
|
||||||
YAP_Term parfactorList = YAP_ARG1;
|
YAP_Term parfactorList = YAP_ARG1;
|
||||||
while (parfactorList != YAP_TermNil()) {
|
while (parfactorList != YAP_TermNil()) {
|
||||||
YAP_Term parfactor = YAP_HeadOfTerm (parfactorList);
|
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
|
||||||
|
parfactors.push_back (readParfactor (pfTerm));
|
||||||
|
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||||
|
}
|
||||||
|
|
||||||
|
// LiftedUtils::printSymbolDictionary();
|
||||||
|
if (Constants::DEBUG > 1) {
|
||||||
|
// Util::printHeader ("INITIAL PARFACTORS");
|
||||||
|
// for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||||
|
// parfactors[i]->print();
|
||||||
|
// cout << endl;
|
||||||
|
// }
|
||||||
|
// parfactors[0]->countConvert (LogVar (0));
|
||||||
|
//parfactors[1]->fullExpand (LogVar (1));
|
||||||
|
Util::printHeader ("SHATTERED PARFACTORS");
|
||||||
|
}
|
||||||
|
|
||||||
|
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||||
|
|
||||||
|
if (Constants::DEBUG > 1) {
|
||||||
|
pfList->print();
|
||||||
|
}
|
||||||
|
|
||||||
|
// read evidence
|
||||||
|
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||||
|
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
|
||||||
|
|
||||||
|
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
|
||||||
|
YAP_Int p = (YAP_Int) (net);
|
||||||
|
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor* readParfactor (YAP_Term pfTerm)
|
||||||
|
{
|
||||||
// read dist id
|
// read dist id
|
||||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, parfactor));
|
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||||
|
|
||||||
// read the ranges
|
// read the ranges
|
||||||
Ranges ranges;
|
Ranges ranges;
|
||||||
YAP_Term rangeList = YAP_ArgOfTerm (3, parfactor);
|
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
|
||||||
while (rangeList != YAP_TermNil()) {
|
while (rangeList != YAP_TermNil()) {
|
||||||
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||||
ranges.push_back (range);
|
ranges.push_back (range);
|
||||||
@ -48,7 +86,7 @@ int createLiftedNetwork (void)
|
|||||||
ProbFormulas formulas;
|
ProbFormulas formulas;
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
unordered_map<YAP_Term, LogVar> lvMap;
|
unordered_map<YAP_Term, LogVar> lvMap;
|
||||||
YAP_Term pvList = YAP_ArgOfTerm (2, parfactor);
|
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||||
while (pvList != YAP_TermNil()) {
|
while (pvList != YAP_TermNil()) {
|
||||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||||
@ -79,12 +117,12 @@ int createLiftedNetwork (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// read the parameters
|
// read the parameters
|
||||||
const Params& params = readParams (YAP_ArgOfTerm (4, parfactor));
|
const Params& params = readParams (YAP_ArgOfTerm (4, pfTerm));
|
||||||
|
|
||||||
// read the constraint
|
// read the constraint
|
||||||
Tuples tuples;
|
Tuples tuples;
|
||||||
if (lvMap.size() >= 1) {
|
if (lvMap.size() >= 1) {
|
||||||
YAP_Term tupleList = YAP_ArgOfTerm (5, parfactor);
|
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
|
||||||
while (tupleList != YAP_TermNil()) {
|
while (tupleList != YAP_TermNil()) {
|
||||||
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||||
assert (YAP_IsApplTerm (term));
|
assert (YAP_IsApplTerm (term));
|
||||||
@ -95,7 +133,7 @@ int createLiftedNetwork (void)
|
|||||||
for (unsigned i = 1; i <= arity; i++) {
|
for (unsigned i = 1; i <= arity; i++) {
|
||||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||||
if (YAP_IsAtomTerm (ti) == false) {
|
if (YAP_IsAtomTerm (ti) == false) {
|
||||||
cerr << "error: bad formed constraint" << endl;
|
cerr << "error: constraint has free variables" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||||
@ -105,32 +143,15 @@ int createLiftedNetwork (void)
|
|||||||
tupleList = YAP_TailOfTerm (tupleList);
|
tupleList = YAP_TailOfTerm (tupleList);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
parfactors.push_back (new Parfactor (formulas, params, tuples, distId));
|
return new Parfactor (formulas, params, tuples, distId);
|
||||||
parfactorList = YAP_TailOfTerm (parfactorList);
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// LiftedUtils::printSymbolDictionary();
|
|
||||||
cout << "*******************************************************" << endl;
|
|
||||||
cout << "INITIAL PARFACTORS" << endl;
|
|
||||||
cout << "*******************************************************" << endl;
|
|
||||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
|
||||||
parfactors[i]->print();
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
ParfactorList* pfList = new ParfactorList();
|
|
||||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
|
||||||
pfList->add (parfactors[i]);
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
cout << "*******************************************************" << endl;
|
|
||||||
cout << "SHATTERED PARFACTORS" << endl;
|
|
||||||
cout << "*******************************************************" << endl;
|
|
||||||
pfList->shatter();
|
|
||||||
pfList->print();
|
|
||||||
|
|
||||||
// insert the evidence
|
|
||||||
ObservedFormulas obsFormulas;
|
void readLiftedEvidence (
|
||||||
YAP_Term observedList = YAP_ARG2;
|
YAP_Term observedList,
|
||||||
|
ObservedFormulas& obsFormulas)
|
||||||
|
{
|
||||||
while (observedList != YAP_TermNil()) {
|
while (observedList != YAP_TermNil()) {
|
||||||
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||||
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||||
@ -155,22 +176,18 @@ int createLiftedNetwork (void)
|
|||||||
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||||
bool found = false;
|
bool found = false;
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||||
if (obsFormulas[i]->functor() == functor &&
|
if (obsFormulas[i].functor() == functor &&
|
||||||
obsFormulas[i]->arity() == args.size() &&
|
obsFormulas[i].arity() == args.size() &&
|
||||||
obsFormulas[i]->evidence() == evidence) {
|
obsFormulas[i].evidence() == evidence) {
|
||||||
obsFormulas[i]->addTuple (args);
|
obsFormulas[i].addTuple (args);
|
||||||
found = true;
|
found = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (found == false) {
|
if (found == false) {
|
||||||
obsFormulas.push_back (new ObservedFormula (functor, evidence, args));
|
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||||
}
|
}
|
||||||
observedList = YAP_TailOfTerm (observedList);
|
observedList = YAP_TailOfTerm (observedList);
|
||||||
}
|
}
|
||||||
FoveSolver::absorveEvidence (*pfList, obsFormulas);
|
|
||||||
|
|
||||||
YAP_Int p = (YAP_Int) (pfList);
|
|
||||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -186,7 +203,6 @@ createGroundNetwork (void)
|
|||||||
// }
|
// }
|
||||||
BayesNet* bn = new BayesNet();
|
BayesNet* bn = new BayesNet();
|
||||||
YAP_Term varList = YAP_ARG1;
|
YAP_Term varList = YAP_ARG1;
|
||||||
BnNodeSet nodes;
|
|
||||||
vector<VarIds> parents;
|
vector<VarIds> parents;
|
||||||
while (varList != YAP_TermNil()) {
|
while (varList != YAP_TermNil()) {
|
||||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
YAP_Term var = YAP_HeadOfTerm (varList);
|
||||||
@ -201,15 +217,13 @@ createGroundNetwork (void)
|
|||||||
parents.back().push_back (parentId);
|
parents.back().push_back (parentId);
|
||||||
parentL = YAP_TailOfTerm (parentL);
|
parentL = YAP_TailOfTerm (parentL);
|
||||||
}
|
}
|
||||||
Distribution* dist = bn->getDistribution (distId);
|
|
||||||
if (!dist) {
|
|
||||||
dist = new Distribution (distId);
|
|
||||||
bn->addDistribution (dist);
|
|
||||||
}
|
|
||||||
assert (bn->getBayesNode (vid) == 0);
|
assert (bn->getBayesNode (vid) == 0);
|
||||||
nodes.push_back (bn->addNode (vid, dsize, evidence, dist));
|
BayesNode* newNode = new BayesNode (
|
||||||
|
vid, dsize, evidence, Params(), distId);
|
||||||
|
bn->addNode (newNode);
|
||||||
varList = YAP_TailOfTerm (varList);
|
varList = YAP_TailOfTerm (varList);
|
||||||
}
|
}
|
||||||
|
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
BnNodeSet ps;
|
BnNodeSet ps;
|
||||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
for (unsigned j = 0; j < parents[i].size(); j++) {
|
||||||
@ -225,41 +239,6 @@ createGroundNetwork (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
setBayesNetParams (void)
|
|
||||||
{
|
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
|
||||||
YAP_Term distList = YAP_ARG2;
|
|
||||||
while (distList != YAP_TermNil()) {
|
|
||||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
|
||||||
const Params params = readParams (YAP_ArgOfTerm (2, dist));
|
|
||||||
bn->getDistribution(distId)->updateParameters (params);
|
|
||||||
distList = YAP_TailOfTerm (distList);
|
|
||||||
}
|
|
||||||
return TRUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
setParfactorGraphParams (void)
|
|
||||||
{
|
|
||||||
// FIXME
|
|
||||||
// ParfactorGraph* pfg = (ParfactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
|
||||||
YAP_Term distList = YAP_ARG2;
|
|
||||||
while (distList != YAP_TermNil()) {
|
|
||||||
// YAP_Term dist = YAP_HeadOfTerm (distList);
|
|
||||||
// unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
|
||||||
// const Params params = readParams (YAP_ArgOfTerm (2, dist));
|
|
||||||
// pfg->getDistribution(distId)->setData (params);
|
|
||||||
distList = YAP_TailOfTerm (distList);
|
|
||||||
}
|
|
||||||
return TRUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
readParams (YAP_Term paramL)
|
readParams (YAP_Term paramL)
|
||||||
{
|
{
|
||||||
@ -279,15 +258,14 @@ readParams (YAP_Term paramL)
|
|||||||
int
|
int
|
||||||
runLiftedSolver (void)
|
runLiftedSolver (void)
|
||||||
{
|
{
|
||||||
ParfactorList* pfList = (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term taskList = YAP_ARG2;
|
YAP_Term taskList = YAP_ARG2;
|
||||||
vector<Params> results;
|
vector<Params> results;
|
||||||
|
ParfactorList pfListCopy (*network->first);
|
||||||
|
FoveSolver::absorveEvidence (pfListCopy, *network->second);
|
||||||
while (taskList != YAP_TermNil()) {
|
while (taskList != YAP_TermNil()) {
|
||||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
|
||||||
Grounds queryVars;
|
Grounds queryVars;
|
||||||
assert (YAP_IsPairTerm (taskList));
|
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||||
assert (YAP_IsPairTerm (jointList));
|
|
||||||
while (jointList != YAP_TermNil()) {
|
while (jointList != YAP_TermNil()) {
|
||||||
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
||||||
if (YAP_IsAtomTerm (ground)) {
|
if (YAP_IsAtomTerm (ground)) {
|
||||||
@ -310,11 +288,11 @@ runLiftedSolver (void)
|
|||||||
}
|
}
|
||||||
jointList = YAP_TailOfTerm (jointList);
|
jointList = YAP_TailOfTerm (jointList);
|
||||||
}
|
}
|
||||||
FoveSolver solver (pfList);
|
FoveSolver solver (pfListCopy);
|
||||||
if (queryVars.size() == 1) {
|
if (queryVars.size() == 1) {
|
||||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
||||||
} else {
|
} else {
|
||||||
assert (false); // TODO joint dist
|
results.push_back (solver.getJointDistributionOf (queryVars));
|
||||||
}
|
}
|
||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
@ -339,46 +317,40 @@ runLiftedSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
runOtherSolvers (void)
|
runGroundSolver (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term taskList = YAP_ARG2;
|
YAP_Term taskList = YAP_ARG2;
|
||||||
vector<VarIds> tasks;
|
vector<VarIds> tasks;
|
||||||
std::set<VarId> vids;
|
std::set<VarId> vids;
|
||||||
while (taskList != YAP_TermNil()) {
|
while (taskList != YAP_TermNil()) {
|
||||||
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
VarIds queryVars;
|
||||||
tasks.push_back (VarIds());
|
|
||||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||||
while (jointList != YAP_TermNil()) {
|
while (jointList != YAP_TermNil()) {
|
||||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||||
assert (bn->getBayesNode (vid));
|
assert (bn->getBayesNode (vid));
|
||||||
tasks.back().push_back (vid);
|
queryVars.push_back (vid);
|
||||||
vids.insert (vid);
|
vids.insert (vid);
|
||||||
jointList = YAP_TailOfTerm (jointList);
|
jointList = YAP_TailOfTerm (jointList);
|
||||||
}
|
}
|
||||||
} else {
|
tasks.push_back (queryVars);
|
||||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
|
||||||
assert (bn->getBayesNode (vid));
|
|
||||||
tasks.push_back (VarIds() = {vid});
|
|
||||||
vids.insert (vid);
|
|
||||||
}
|
|
||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
|
|
||||||
Solver* bpSolver = 0;
|
Solver* bpSolver = 0;
|
||||||
GraphicalModel* graphicalModel = 0;
|
GraphicalModel* graphicalModel = 0;
|
||||||
CFactorGraph::checkForIdenticalFactors = false;
|
CFactorGraph::checkForIdenticalFactors = false;
|
||||||
if (InfAlgorithms::infAlgorithm != InfAlgorithms::VE) {
|
if (Globals::infAlgorithm != InfAlgorithms::VE) {
|
||||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
|
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
|
||||||
VarIds (vids.begin(), vids.end()));
|
VarIds (vids.begin(), vids.end()));
|
||||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::BN_BP) {
|
if (Globals::infAlgorithm == InfAlgorithms::BN_BP) {
|
||||||
graphicalModel = mrn;
|
graphicalModel = mrn;
|
||||||
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
|
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
|
||||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::FG_BP) {
|
} else if (Globals::infAlgorithm == InfAlgorithms::FG_BP) {
|
||||||
graphicalModel = new FactorGraph (*mrn);
|
graphicalModel = new FactorGraph (*mrn);
|
||||||
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||||
delete mrn;
|
delete mrn;
|
||||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::CBP) {
|
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||||
graphicalModel = new FactorGraph (*mrn);
|
graphicalModel = new FactorGraph (*mrn);
|
||||||
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||||
delete mrn;
|
delete mrn;
|
||||||
@ -389,8 +361,7 @@ runOtherSolvers (void)
|
|||||||
vector<Params> results;
|
vector<Params> results;
|
||||||
results.reserve (tasks.size());
|
results.reserve (tasks.size());
|
||||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
//if (i == 1) exit (0);
|
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::VE) {
|
|
||||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
|
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
|
||||||
VarElimSolver* veSolver = new VarElimSolver (*mrn);
|
VarElimSolver* veSolver = new VarElimSolver (*mrn);
|
||||||
if (tasks[i].size() == 1) {
|
if (tasks[i].size() == 1) {
|
||||||
@ -430,10 +401,57 @@ runOtherSolvers (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
setParfactorsParams (void)
|
||||||
|
{
|
||||||
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
ParfactorList* pfList = network->first;
|
||||||
|
YAP_Term distList = YAP_ARG2;
|
||||||
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
|
while (distList != YAP_TermNil()) {
|
||||||
|
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||||
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||||
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
|
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
|
||||||
|
distList = YAP_TailOfTerm (distList);
|
||||||
|
}
|
||||||
|
ParfactorList::iterator it = pfList->begin();
|
||||||
|
while (it != pfList->end()) {
|
||||||
|
assert (Util::contains (paramsMap, (*it)->distId()));
|
||||||
|
// (*it)->setParams (paramsMap[(*it)->distId()]);
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
setBayesNetParams (void)
|
||||||
|
{
|
||||||
|
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
YAP_Term distList = YAP_ARG2;
|
||||||
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
|
while (distList != YAP_TermNil()) {
|
||||||
|
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||||
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||||
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
|
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
|
||||||
|
distList = YAP_TailOfTerm (distList);
|
||||||
|
}
|
||||||
|
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||||
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
|
assert (Util::contains (paramsMap, nodes[i]->distId()));
|
||||||
|
nodes[i]->setParams (paramsMap[nodes[i]->distId()]);
|
||||||
|
}
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
setExtraVarsInfo (void)
|
setExtraVarsInfo (void)
|
||||||
{
|
{
|
||||||
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
|
||||||
GraphicalModel::clearVariablesInformation();
|
GraphicalModel::clearVariablesInformation();
|
||||||
YAP_Term varsInfoL = YAP_ARG2;
|
YAP_Term varsInfoL = YAP_ARG2;
|
||||||
while (varsInfoL != YAP_TermNil()) {
|
while (varsInfoL != YAP_TermNil()) {
|
||||||
@ -463,13 +481,13 @@ setHorusFlag (void)
|
|||||||
if (key == "inf_alg") {
|
if (key == "inf_alg") {
|
||||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||||
if ( value == "ve") {
|
if ( value == "ve") {
|
||||||
InfAlgorithms::infAlgorithm = InfAlgorithms::VE;
|
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||||
} else if (value == "bn_bp") {
|
} else if (value == "bn_bp") {
|
||||||
InfAlgorithms::infAlgorithm = InfAlgorithms::BN_BP;
|
Globals::infAlgorithm = InfAlgorithms::BN_BP;
|
||||||
} else if (value == "fg_bp") {
|
} else if (value == "fg_bp") {
|
||||||
InfAlgorithms::infAlgorithm = InfAlgorithms::FG_BP;
|
Globals::infAlgorithm = InfAlgorithms::FG_BP;
|
||||||
} else if (value == "cbp") {
|
} else if (value == "cbp") {
|
||||||
InfAlgorithms::infAlgorithm = InfAlgorithms::CBP;
|
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||||
} else {
|
} else {
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
cerr << "for `" << key << "'" << endl;
|
cerr << "for `" << key << "'" << endl;
|
||||||
@ -543,19 +561,19 @@ setHorusFlag (void)
|
|||||||
int
|
int
|
||||||
freeBayesNetwork (void)
|
freeBayesNetwork (void)
|
||||||
{
|
{
|
||||||
//Statistics::writeStatisticsToFile ("stats.txt");
|
delete (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
|
||||||
bn->freeDistributions();
|
|
||||||
delete bn;
|
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
freeParfactorGraph (void)
|
freeParfactors (void)
|
||||||
{
|
{
|
||||||
delete (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
delete network->first;
|
||||||
|
delete network->second;
|
||||||
|
delete network;
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -566,13 +584,13 @@ init_predicates (void)
|
|||||||
{
|
{
|
||||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
|
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
|
||||||
YAP_UserCPredicate ("set_parfactor_graph_params", setParfactorGraphParams, 2);
|
|
||||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
|
||||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||||
YAP_UserCPredicate ("run_other_solvers", runOtherSolvers, 3);
|
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||||
|
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||||
|
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
||||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||||
|
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||||
YAP_UserCPredicate ("free_parfactor_graph", freeParfactorGraph, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,7 +12,9 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
class StatesIndexer {
|
|
||||||
|
class StatesIndexer
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
|
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
|
||||||
@ -134,11 +136,11 @@ class StatesIndexer {
|
|||||||
return size_ ;
|
return size_ ;
|
||||||
}
|
}
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &out, const StatesIndexer& idx)
|
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
|
||||||
{
|
{
|
||||||
out << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
||||||
out << idx.indices_;
|
os << idx.indices_;
|
||||||
return out;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -274,21 +276,14 @@ class MapIndexer
|
|||||||
index_ = 0;
|
index_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &out, const MapIndexer& idx)
|
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
|
||||||
{
|
{
|
||||||
out << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
||||||
out << idx.indices_;
|
os << idx.indices_;
|
||||||
return out;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MapIndexer (const Ranges& ranges) :
|
|
||||||
ranges_(ranges),
|
|
||||||
indices_(ranges.size(), 0),
|
|
||||||
offsets_(ranges.size())
|
|
||||||
{
|
|
||||||
index_ = 0;
|
|
||||||
}
|
|
||||||
unsigned index_;
|
unsigned index_;
|
||||||
bool valid_;
|
bool valid_;
|
||||||
vector<unsigned> ranges_;
|
vector<unsigned> ranges_;
|
||||||
|
@ -95,26 +95,37 @@ ostream& operator<< (ostream &os, const Ground& gr)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
LogVars
|
||||||
ObservedFormula::addTuple (const Tuple& t)
|
Substitution::getDiscardedLogVars (void) const
|
||||||
{
|
{
|
||||||
if (constr_ == 0) {
|
LogVars discardedLvs;
|
||||||
LogVars lvs (arity_);
|
set<LogVar> doneLvs;
|
||||||
for (unsigned i = 0; i < arity_; i++) {
|
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||||
lvs[i] = i;
|
it = subs_.begin();
|
||||||
|
while (it != subs_.end()) {
|
||||||
|
if (Util::contains (doneLvs, it->second)) {
|
||||||
|
discardedLvs.push_back (it->first);
|
||||||
|
} else {
|
||||||
|
doneLvs.insert (it->second);
|
||||||
}
|
}
|
||||||
constr_ = new ConstraintTree (lvs);
|
it ++;
|
||||||
}
|
}
|
||||||
constr_->addTuple (t);
|
return discardedLvs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream& operator<< (ostream &os, const ObservedFormula of)
|
ostream& operator<< (ostream &os, const Substitution& theta)
|
||||||
{
|
{
|
||||||
os << of.functor_ << "/" << of.arity_;
|
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||||
os << "|" << of.constr_->tupleSet();
|
os << "[" ;
|
||||||
os << " [evidence=" << of.evidence_ << "]";
|
it = theta.subs_.begin();
|
||||||
|
while (it != theta.subs_.end()) {
|
||||||
|
if (it != theta.subs_.begin()) os << ", " ;
|
||||||
|
os << it->first << "->" << it->second ;
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
os << "]" ;
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,11 +18,17 @@ class Symbol
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
|
Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||||
|
|
||||||
Symbol (unsigned id) : id_(id) { }
|
Symbol (unsigned id) : id_(id) { }
|
||||||
|
|
||||||
operator unsigned (void) const { return id_; }
|
operator unsigned (void) const { return id_; }
|
||||||
|
|
||||||
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
|
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
|
||||||
|
|
||||||
static Symbol invalid (void) { return Symbol(); }
|
static Symbol invalid (void) { return Symbol(); }
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const Symbol& s);
|
friend ostream& operator<< (ostream &os, const Symbol& s);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned id_;
|
unsigned id_;
|
||||||
};
|
};
|
||||||
@ -32,7 +38,9 @@ class LogVar
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
|
LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||||
|
|
||||||
LogVar (unsigned id) : id_(id) { }
|
LogVar (unsigned id) : id_(id) { }
|
||||||
|
|
||||||
operator unsigned (void) const { return id_; }
|
operator unsigned (void) const { return id_; }
|
||||||
|
|
||||||
LogVar& operator++ (void)
|
LogVar& operator++ (void)
|
||||||
@ -48,6 +56,7 @@ class LogVar
|
|||||||
}
|
}
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const LogVar& X);
|
friend ostream& operator<< (ostream &os, const LogVar& X);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned id_;
|
unsigned id_;
|
||||||
};
|
};
|
||||||
@ -89,71 +98,56 @@ class Ground
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Ground (Symbol f) : functor_(f) { }
|
Ground (Symbol f) : functor_(f) { }
|
||||||
|
|
||||||
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
||||||
|
|
||||||
Symbol functor (void) const { return functor_; }
|
Symbol functor (void) const { return functor_; }
|
||||||
|
|
||||||
Symbols args (void) const { return args_; }
|
Symbols args (void) const { return args_; }
|
||||||
|
|
||||||
unsigned arity (void) const { return args_.size(); }
|
unsigned arity (void) const { return args_.size(); }
|
||||||
|
|
||||||
bool isAtom (void) const { return args_.size() == 0; }
|
bool isAtom (void) const { return args_.size() == 0; }
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const Ground& gr);
|
friend ostream& operator<< (ostream &os, const Ground& gr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Symbol functor_;
|
Symbol functor_;
|
||||||
Symbols args_;
|
Symbols args_;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef vector<Ground> Grounds;
|
typedef vector<Ground> Grounds;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConstraintTree;
|
|
||||||
class ObservedFormula
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
|
||||||
: functor_(f), arity_(a), evidence_(ev), constr_(0) { }
|
|
||||||
|
|
||||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
|
||||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(0)
|
|
||||||
{
|
|
||||||
addTuple (tuple);
|
|
||||||
}
|
|
||||||
|
|
||||||
Symbol functor (void) const { return functor_; }
|
|
||||||
unsigned arity (void) const { return arity_; }
|
|
||||||
unsigned evidence (void) const { return evidence_; }
|
|
||||||
ConstraintTree* constr (void) const { return constr_; }
|
|
||||||
bool isAtom (void) const { return arity_ == 0; }
|
|
||||||
|
|
||||||
void addTuple (const Tuple& t);
|
|
||||||
friend ostream& operator<< (ostream &os, const ObservedFormula opv);
|
|
||||||
private:
|
|
||||||
Symbol functor_;
|
|
||||||
unsigned arity_;
|
|
||||||
unsigned evidence_;
|
|
||||||
ConstraintTree* constr_;
|
|
||||||
};
|
|
||||||
typedef vector<ObservedFormula*> ObservedFormulas;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Substitution
|
class Substitution
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
void add (LogVar X_old, LogVar X_new)
|
void add (LogVar X_old, LogVar X_new)
|
||||||
{
|
{
|
||||||
|
assert (Util::contains (subs_, X_old) == false);
|
||||||
subs_.insert (make_pair (X_old, X_new));
|
subs_.insert (make_pair (X_old, X_new));
|
||||||
}
|
}
|
||||||
|
|
||||||
void rename (LogVar X_old, LogVar X_new)
|
void rename (LogVar X_old, LogVar X_new)
|
||||||
{
|
{
|
||||||
assert (subs_.find (X_old) != subs_.end());
|
assert (Util::contains (subs_, X_old));
|
||||||
subs_.find (X_old)->second = X_new;
|
subs_.find (X_old)->second = X_new;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogVar newNameFor (LogVar X) const
|
LogVar newNameFor (LogVar X) const
|
||||||
{
|
{
|
||||||
assert (subs_.find (X) != subs_.end());
|
assert (Util::contains (subs_, X));
|
||||||
return subs_.find (X)->second;
|
return subs_.find (X)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogVars getDiscardedLogVars (void) const;
|
||||||
|
|
||||||
|
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unordered_map<LogVar, LogVar> subs_;
|
unordered_map<LogVar, LogVar> subs_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,7 +60,6 @@ HEADERS = \
|
|||||||
$(srcdir)/CbpSolver.h \
|
$(srcdir)/CbpSolver.h \
|
||||||
$(srcdir)/FoveSolver.h \
|
$(srcdir)/FoveSolver.h \
|
||||||
$(srcdir)/VarNode.h \
|
$(srcdir)/VarNode.h \
|
||||||
$(srcdir)/Distribution.h \
|
|
||||||
$(srcdir)/Indexer.h \
|
$(srcdir)/Indexer.h \
|
||||||
$(srcdir)/Parfactor.h \
|
$(srcdir)/Parfactor.h \
|
||||||
$(srcdir)/ProbFormula.h \
|
$(srcdir)/ProbFormula.h \
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#include "Parfactor.h"
|
#include "Parfactor.h"
|
||||||
#include "Histogram.h"
|
#include "Histogram.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
|
#include "Util.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
@ -11,55 +12,58 @@ Parfactor::Parfactor (
|
|||||||
const Tuples& tuples,
|
const Tuples& tuples,
|
||||||
unsigned distId)
|
unsigned distId)
|
||||||
{
|
{
|
||||||
formulas_ = formulas;
|
args_ = formulas;
|
||||||
params_ = params;
|
params_ = params;
|
||||||
distId_ = distId;
|
distId_ = distId;
|
||||||
|
|
||||||
LogVars logVars;
|
LogVars logVars;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
ranges_.push_back (formulas_[i].range());
|
ranges_.push_back (args_[i].range());
|
||||||
const LogVars& lvs = formulas_[i].logVars();
|
const LogVars& lvs = args_[i].logVars();
|
||||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||||
if (std::find (logVars.begin(), logVars.end(), lvs[j]) ==
|
if (Util::contains (logVars, lvs[j]) == false) {
|
||||||
logVars.end()) {
|
|
||||||
logVars.push_back (lvs[j]);
|
logVars.push_back (lvs[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
constr_ = new ConstraintTree (logVars, tuples);
|
constr_ = new ConstraintTree (logVars, tuples);
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
||||||
{
|
{
|
||||||
formulas_ = g->formulas();
|
args_ = g->arguments();
|
||||||
params_ = g->params();
|
params_ = g->params();
|
||||||
ranges_ = g->ranges();
|
ranges_ = g->ranges();
|
||||||
distId_ = g->distId();
|
distId_ = g->distId();
|
||||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
||||||
{
|
{
|
||||||
formulas_ = g->formulas();
|
args_ = g->arguments();
|
||||||
params_ = g->params();
|
params_ = g->params();
|
||||||
ranges_ = g->ranges();
|
ranges_ = g->ranges();
|
||||||
distId_ = g->distId();
|
distId_ = g->distId();
|
||||||
constr_ = constr;
|
constr_ = constr;
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (const Parfactor& g)
|
Parfactor::Parfactor (const Parfactor& g)
|
||||||
{
|
{
|
||||||
formulas_ = g.formulas();
|
args_ = g.arguments();
|
||||||
params_ = g.params();
|
params_ = g.params();
|
||||||
ranges_ = g.ranges();
|
ranges_ = g.ranges();
|
||||||
distId_ = g.distId();
|
distId_ = g.distId();
|
||||||
constr_ = new ConstraintTree (*g.constr());
|
constr_ = new ConstraintTree (*g.constr());
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -75,9 +79,9 @@ LogVarSet
|
|||||||
Parfactor::countedLogVars (void) const
|
Parfactor::countedLogVars (void) const
|
||||||
{
|
{
|
||||||
LogVarSet set;
|
LogVarSet set;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].isCounting()) {
|
if (args_[i].isCounting()) {
|
||||||
set.insert (formulas_[i].countedLogVar());
|
set.insert (args_[i].countedLogVar());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return set;
|
return set;
|
||||||
@ -107,14 +111,14 @@ Parfactor::elimLogVars (void) const
|
|||||||
LogVarSet
|
LogVarSet
|
||||||
Parfactor::exclusiveLogVars (unsigned fIdx) const
|
Parfactor::exclusiveLogVars (unsigned fIdx) const
|
||||||
{
|
{
|
||||||
assert (fIdx < formulas_.size());
|
assert (fIdx < args_.size());
|
||||||
LogVarSet remaining;
|
LogVarSet remaining;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != fIdx) {
|
if (i != fIdx) {
|
||||||
remaining |= formulas_[i].logVarSet();
|
remaining |= args_[i].logVarSet();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return formulas_[fIdx].logVarSet() - remaining;
|
return args_[fIdx].logVarSet() - remaining;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -131,44 +135,51 @@ Parfactor::setConstraintTree (ConstraintTree* newTree)
|
|||||||
void
|
void
|
||||||
Parfactor::sumOut (unsigned fIdx)
|
Parfactor::sumOut (unsigned fIdx)
|
||||||
{
|
{
|
||||||
assert (fIdx < formulas_.size());
|
assert (fIdx < args_.size());
|
||||||
assert (formulas_[fIdx].contains (elimLogVars()));
|
assert (args_[fIdx].contains (elimLogVars()));
|
||||||
|
|
||||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||||
unsigned condCount = constr_->getConditionalCount (excl);
|
if (args_[fIdx].isCounting()) {
|
||||||
Util::pow (params_, condCount);
|
LogAware::pow (params_, constr_->getConditionalCount (
|
||||||
|
excl - args_[fIdx].countedLogVar()));
|
||||||
|
} else {
|
||||||
|
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||||
|
}
|
||||||
|
|
||||||
vector<unsigned> numAssigns (ranges_[fIdx], 1);
|
if (args_[fIdx].isCounting()) {
|
||||||
if (formulas_[fIdx].isCounting()) {
|
|
||||||
unsigned N = constr_->getConditionalCount (
|
unsigned N = constr_->getConditionalCount (
|
||||||
formulas_[fIdx].countedLogVar());
|
args_[fIdx].countedLogVar());
|
||||||
unsigned R = formulas_[fIdx].range();
|
unsigned R = args_[fIdx].range();
|
||||||
unsigned H = ranges_[fIdx];
|
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||||
HistogramSet hs (N, R);
|
StatesIndexer sindexer (ranges_, fIdx);
|
||||||
unsigned N_factorial = Util::factorial (N);
|
while (sindexer.valid()) {
|
||||||
for (unsigned h = 0; h < H; h++) {
|
unsigned h = sindexer[fIdx];
|
||||||
unsigned prod = 1;
|
if (Globals::logDomain) {
|
||||||
for (unsigned r = 0; r < R; r++) {
|
params_[sindexer] += numAssigns[h];
|
||||||
prod *= Util::factorial (hs[r]);
|
} else {
|
||||||
|
params_[sindexer] *= numAssigns[h];
|
||||||
}
|
}
|
||||||
numAssigns[h] = N_factorial / prod;
|
++ sindexer;
|
||||||
hs.nextHistogram();
|
|
||||||
}
|
}
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Params copy = params_;
|
Params copy = params_;
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.resize (copy.size() / ranges_[fIdx], 0.0);
|
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
|
||||||
|
|
||||||
MapIndexer indexer (ranges_, fIdx);
|
MapIndexer indexer (ranges_, fIdx);
|
||||||
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < copy.size(); i++) {
|
for (unsigned i = 0; i < copy.size(); i++) {
|
||||||
unsigned h = indexer[fIdx];
|
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
|
||||||
// TODO NOT LOG DOMAIN AWARE :(
|
|
||||||
params_[indexer] += numAssigns[h] * copy[i];
|
|
||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
formulas_.erase (formulas_.begin() + fIdx);
|
} else {
|
||||||
|
for (unsigned i = 0; i < copy.size(); i++) {
|
||||||
|
params_[indexer] += copy[i];
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
args_.erase (args_.begin() + fIdx);
|
||||||
ranges_.erase (ranges_.begin() + fIdx);
|
ranges_.erase (ranges_.begin() + fIdx);
|
||||||
constr_->remove (excl);
|
constr_->remove (excl);
|
||||||
}
|
}
|
||||||
@ -179,55 +190,7 @@ void
|
|||||||
Parfactor::multiply (Parfactor& g)
|
Parfactor::multiply (Parfactor& g)
|
||||||
{
|
{
|
||||||
alignAndExponentiate (this, &g);
|
alignAndExponentiate (this, &g);
|
||||||
bool sharedVars = false;
|
TFactor<ProbFormula>::multiply (g);
|
||||||
vector<unsigned> g_varpos;
|
|
||||||
const ProbFormulas& g_formulas = g.formulas();
|
|
||||||
const Params& g_params = g.params();
|
|
||||||
const Ranges& g_ranges = g.ranges();
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < g_formulas.size(); i++) {
|
|
||||||
int group = g_formulas[i].group();
|
|
||||||
if (indexOfFormulaWithGroup (group) == -1) {
|
|
||||||
insertDimension (g.ranges()[i]);
|
|
||||||
formulas_.push_back (g_formulas[i]);
|
|
||||||
g_varpos.push_back (formulas_.size() - 1);
|
|
||||||
} else {
|
|
||||||
sharedVars = true;
|
|
||||||
g_varpos.push_back (indexOfFormulaWithGroup (group));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sharedVars == false) {
|
|
||||||
unsigned count = 0;
|
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
params_[i] += g_params[count];
|
|
||||||
} else {
|
|
||||||
params_[i] *= g_params[count];
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
if (count >= g_params.size()) {
|
|
||||||
count = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
StatesIndexer indexer (ranges_, false);
|
|
||||||
while (indexer.valid()) {
|
|
||||||
unsigned g_li = 0;
|
|
||||||
unsigned prod = 1;
|
|
||||||
for (int j = g_varpos.size() - 1; j >= 0; j--) {
|
|
||||||
g_li += indexer[g_varpos[j]] * prod;
|
|
||||||
prod *= g_ranges[j];
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
params_[indexer] += g_params[g_li];
|
|
||||||
} else {
|
|
||||||
params_[indexer] *= g_params[g_li];
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
constr_->join (g.constr(), true);
|
constr_->join (g.constr(), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,7 +199,7 @@ Parfactor::multiply (Parfactor& g)
|
|||||||
void
|
void
|
||||||
Parfactor::countConvert (LogVar X)
|
Parfactor::countConvert (LogVar X)
|
||||||
{
|
{
|
||||||
int fIdx = indexOfFormulaWithLogVar (X);
|
int fIdx = indexOfLogVar (X);
|
||||||
assert (fIdx != -1);
|
assert (fIdx != -1);
|
||||||
assert (constr_->isCountNormalized (X));
|
assert (constr_->isCountNormalized (X));
|
||||||
assert (constr_->getConditionalCount (X) > 1);
|
assert (constr_->getConditionalCount (X) > 1);
|
||||||
@ -248,12 +211,12 @@ Parfactor::countConvert (LogVar X)
|
|||||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||||
|
|
||||||
StatesIndexer indexer (ranges_);
|
StatesIndexer indexer (ranges_);
|
||||||
vector<Params> summout (params_.size() / R);
|
vector<Params> sumout (params_.size() / R);
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
summout[count].reserve (R);
|
sumout[count].reserve (R);
|
||||||
for (unsigned r = 0; r < R; r++) {
|
for (unsigned r = 0; r < R; r++) {
|
||||||
summout[count].push_back (params_[indexer]);
|
sumout[count].push_back (params_[indexer]);
|
||||||
indexer.increment (fIdx);
|
indexer.increment (fIdx);
|
||||||
}
|
}
|
||||||
count ++;
|
count ++;
|
||||||
@ -262,45 +225,42 @@ Parfactor::countConvert (LogVar X)
|
|||||||
}
|
}
|
||||||
|
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.reserve (summout.size() * H);
|
params_.reserve (sumout.size() * H);
|
||||||
|
|
||||||
vector<bool> mapDims (ranges_.size(), true);
|
|
||||||
ranges_[fIdx] = H;
|
ranges_[fIdx] = H;
|
||||||
mapDims[fIdx] = false;
|
MapIndexer mapIndexer (ranges_, fIdx);
|
||||||
MapIndexer mapIndexer (ranges_, mapDims);
|
|
||||||
while (mapIndexer.valid()) {
|
while (mapIndexer.valid()) {
|
||||||
double prod = 1.0;
|
double prod = LogAware::multIdenty();
|
||||||
unsigned i = mapIndexer.mappedIndex();
|
unsigned i = mapIndexer.mappedIndex();
|
||||||
unsigned h = mapIndexer[fIdx];
|
unsigned h = mapIndexer[fIdx];
|
||||||
for (unsigned r = 0; r < R; r++) {
|
for (unsigned r = 0; r < R; r++) {
|
||||||
// TODO not log domain aware
|
if (Globals::logDomain) {
|
||||||
prod *= Util::pow (summout[i][r], histograms[h][r]);
|
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||||
|
} else {
|
||||||
|
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
params_.push_back (prod);
|
params_.push_back (prod);
|
||||||
++ mapIndexer;
|
++ mapIndexer;
|
||||||
}
|
}
|
||||||
formulas_[fIdx].setCountedLogVar (X);
|
args_[fIdx].setCountedLogVar (X);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::expandPotential (
|
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
|
||||||
LogVar X,
|
|
||||||
LogVar X_new1,
|
|
||||||
LogVar X_new2)
|
|
||||||
{
|
{
|
||||||
int fIdx = indexOfFormulaWithLogVar (X);
|
int fIdx = indexOfLogVar (X);
|
||||||
assert (fIdx != -1);
|
assert (fIdx != -1);
|
||||||
assert (formulas_[fIdx].isCounting());
|
assert (args_[fIdx].isCounting());
|
||||||
|
|
||||||
unsigned N1 = constr_->getConditionalCount (X_new1);
|
unsigned N1 = constr_->getConditionalCount (X_new1);
|
||||||
unsigned N2 = constr_->getConditionalCount (X_new2);
|
unsigned N2 = constr_->getConditionalCount (X_new2);
|
||||||
unsigned N = N1 + N2;
|
unsigned N = N1 + N2;
|
||||||
unsigned R = formulas_[fIdx].range();
|
unsigned R = args_[fIdx].range();
|
||||||
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
||||||
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
||||||
unsigned H = ranges_[fIdx];
|
|
||||||
|
|
||||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||||
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||||
@ -320,48 +280,11 @@ Parfactor::expandPotential (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned size = (params_.size() / H) * H1 * H2;
|
expandPotential (fIdx, H1 * H2, sumIndexes);
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (size);
|
|
||||||
|
|
||||||
unsigned prod = 1;
|
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
|
||||||
vector<unsigned> offsets_ (ranges_.size());
|
args_[fIdx].rename (X, X_new1);
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
args_[fIdx + 1].rename (X, X_new2);
|
||||||
offsets_[i] = prod;
|
|
||||||
prod *= ranges_[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned index = 0;
|
|
||||||
ranges_[fIdx] = H1 * H2;
|
|
||||||
vector<unsigned> indices (ranges_.size(), 0);
|
|
||||||
for (unsigned k = 0; k < size; k++) {
|
|
||||||
params_.push_back (copy[index]);
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
indices[i] ++;
|
|
||||||
if (i == fIdx) {
|
|
||||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index += offsets_[i];
|
|
||||||
}
|
|
||||||
if (indices[i] != ranges_[i]) {
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
if (i == fIdx) {
|
|
||||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index -= offsets_[i] * ranges_[i];
|
|
||||||
}
|
|
||||||
indices[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
formulas_.insert (formulas_.begin() + fIdx + 1, formulas_[fIdx]);
|
|
||||||
formulas_[fIdx].rename (X, X_new1);
|
|
||||||
formulas_[fIdx + 1].rename (X, X_new2);
|
|
||||||
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
|
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
|
||||||
ranges_[fIdx] = H1;
|
ranges_[fIdx] = H1;
|
||||||
}
|
}
|
||||||
@ -371,13 +294,12 @@ Parfactor::expandPotential (
|
|||||||
void
|
void
|
||||||
Parfactor::fullExpand (LogVar X)
|
Parfactor::fullExpand (LogVar X)
|
||||||
{
|
{
|
||||||
int fIdx = indexOfFormulaWithLogVar (X);
|
int fIdx = indexOfLogVar (X);
|
||||||
assert (fIdx != -1);
|
assert (fIdx != -1);
|
||||||
assert (formulas_[fIdx].isCounting());
|
assert (args_[fIdx].isCounting());
|
||||||
|
|
||||||
unsigned N = constr_->getConditionalCount (X);
|
unsigned N = constr_->getConditionalCount (X);
|
||||||
unsigned R = formulas_[fIdx].range();
|
unsigned R = args_[fIdx].range();
|
||||||
unsigned H = ranges_[fIdx];
|
|
||||||
|
|
||||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||||
@ -400,54 +322,17 @@ Parfactor::fullExpand (LogVar X)
|
|||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned size = (params_.size() / H) * std::pow (R, N);
|
expandPotential (fIdx, std::pow (R, N), sumIndexes);
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (size);
|
|
||||||
|
|
||||||
unsigned prod = 1;
|
ProbFormula f = args_[fIdx];
|
||||||
vector<unsigned> offsets_ (ranges_.size());
|
args_.erase (args_.begin() + fIdx);
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
offsets_[i] = prod;
|
|
||||||
prod *= ranges_[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned index = 0;
|
|
||||||
ranges_[fIdx] = std::pow (R, N);
|
|
||||||
vector<unsigned> indices (ranges_.size(), 0);
|
|
||||||
for (unsigned k = 0; k < size; k++) {
|
|
||||||
params_.push_back (copy[index]);
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
indices[i] ++;
|
|
||||||
if (i == fIdx) {
|
|
||||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index += offsets_[i];
|
|
||||||
}
|
|
||||||
if (indices[i] != ranges_[i]) {
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
if (i == fIdx) {
|
|
||||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index -= offsets_[i] * ranges_[i];
|
|
||||||
}
|
|
||||||
indices[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ProbFormula f = formulas_[fIdx];
|
|
||||||
formulas_.erase (formulas_.begin() + fIdx);
|
|
||||||
ranges_.erase (ranges_.begin() + fIdx);
|
ranges_.erase (ranges_.begin() + fIdx);
|
||||||
LogVars newLvs = constr_->expand (X);
|
LogVars newLvs = constr_->expand (X);
|
||||||
assert (newLvs.size() == N);
|
assert (newLvs.size() == N);
|
||||||
for (unsigned i = 0 ; i < N; i++) {
|
for (unsigned i = 0 ; i < N; i++) {
|
||||||
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
|
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
|
||||||
newFormula.rename (X, newLvs[i]);
|
newFormula.rename (X, newLvs[i]);
|
||||||
formulas_.insert (formulas_.begin() + fIdx + i, newFormula);
|
args_.insert (args_.begin() + fIdx + i, newFormula);
|
||||||
ranges_.insert (ranges_.begin() + fIdx + i, R);
|
ranges_.insert (ranges_.begin() + fIdx + i, R);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -459,117 +344,43 @@ Parfactor::reorderAccordingGrounds (const Grounds& grounds)
|
|||||||
{
|
{
|
||||||
ProbFormulas newFormulas;
|
ProbFormulas newFormulas;
|
||||||
for (unsigned i = 0; i < grounds.size(); i++) {
|
for (unsigned i = 0; i < grounds.size(); i++) {
|
||||||
for (unsigned j = 0; j < formulas_.size(); j++) {
|
for (unsigned j = 0; j < args_.size(); j++) {
|
||||||
if (grounds[i].functor() == formulas_[j].functor() &&
|
if (grounds[i].functor() == args_[j].functor() &&
|
||||||
grounds[i].arity() == formulas_[j].arity()) {
|
grounds[i].arity() == args_[j].arity()) {
|
||||||
constr_->moveToTop (formulas_[j].logVars());
|
constr_->moveToTop (args_[j].logVars());
|
||||||
if (constr_->containsTuple (grounds[i].args())) {
|
if (constr_->containsTuple (grounds[i].args())) {
|
||||||
newFormulas.push_back (formulas_[j]);
|
newFormulas.push_back (args_[j]);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert (newFormulas.size() == i + 1);
|
assert (newFormulas.size() == i + 1);
|
||||||
}
|
}
|
||||||
reorderFormulas (newFormulas);
|
reorderArguments (newFormulas);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::reorderFormulas (const ProbFormulas& newFormulas)
|
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
||||||
{
|
|
||||||
assert (newFormulas.size() == formulas_.size());
|
|
||||||
if (newFormulas == formulas_) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ranges newRanges;
|
|
||||||
vector<unsigned> positions;
|
|
||||||
for (unsigned i = 0; i < newFormulas.size(); i++) {
|
|
||||||
unsigned idx = indexOf (newFormulas[i]);
|
|
||||||
newRanges.push_back (ranges_[idx]);
|
|
||||||
positions.push_back (idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned N = ranges_.size();
|
|
||||||
Params newParams (params_.size());
|
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
|
||||||
unsigned li = i;
|
|
||||||
// calculate vector index corresponding to linear index
|
|
||||||
vector<unsigned> vi (N);
|
|
||||||
for (int k = N-1; k >= 0; k--) {
|
|
||||||
vi[k] = li % ranges_[k];
|
|
||||||
li /= ranges_[k];
|
|
||||||
}
|
|
||||||
// convert permuted vector index to corresponding linear index
|
|
||||||
unsigned prod = 1;
|
|
||||||
unsigned new_li = 0;
|
|
||||||
for (int k = N - 1; k >= 0; k--) {
|
|
||||||
new_li += vi[positions[k]] * prod;
|
|
||||||
prod *= ranges_[positions[k]];
|
|
||||||
}
|
|
||||||
newParams[new_li] = params_[i];
|
|
||||||
}
|
|
||||||
formulas_ = newFormulas;
|
|
||||||
ranges_ = newRanges;
|
|
||||||
params_ = newParams;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::absorveEvidence (unsigned fIdx, unsigned evidence)
|
|
||||||
{
|
{
|
||||||
|
int fIdx = indexOf (formula);
|
||||||
|
assert (fIdx != -1);
|
||||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||||
assert (fIdx < formulas_.size());
|
assert (args_[fIdx].isCounting() == false);
|
||||||
assert (evidence < formulas_[fIdx].range());
|
|
||||||
assert (formulas_[fIdx].isCounting() == false);
|
|
||||||
assert (constr_->isCountNormalized (excl));
|
assert (constr_->isCountNormalized (excl));
|
||||||
|
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||||
Util::pow (params_, constr_->getConditionalCount (excl));
|
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||||
|
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (copy.size() / formulas_[fIdx].range());
|
|
||||||
|
|
||||||
StatesIndexer indexer (ranges_);
|
|
||||||
for (unsigned i = 0; i < evidence; i++) {
|
|
||||||
indexer.increment (fIdx);
|
|
||||||
}
|
|
||||||
while (indexer.valid()) {
|
|
||||||
params_.push_back (copy[indexer]);
|
|
||||||
indexer.incrementExcluding (fIdx);
|
|
||||||
}
|
|
||||||
formulas_.erase (formulas_.begin() + fIdx);
|
|
||||||
ranges_.erase (ranges_.begin() + fIdx);
|
|
||||||
constr_->remove (excl);
|
constr_->remove (excl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::normalize (void)
|
|
||||||
{
|
|
||||||
Util::normalize (params_);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
|
|
||||||
{
|
|
||||||
assert (indexOf (f) != -1);
|
|
||||||
formulas_[indexOf (f)].setGroup (group);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::setNewGroups (void)
|
Parfactor::setNewGroups (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
formulas_[i].setGroup (ProbFormula::getNewGroup());
|
args_[i].setGroup (ProbFormula::getNewGroup());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -578,14 +389,14 @@ Parfactor::setNewGroups (void)
|
|||||||
void
|
void
|
||||||
Parfactor::applySubstitution (const Substitution& theta)
|
Parfactor::applySubstitution (const Substitution& theta)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
LogVars& lvs = formulas_[i].logVars();
|
LogVars& lvs = args_[i].logVars();
|
||||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||||
lvs[j] = theta.newNameFor (lvs[j]);
|
lvs[j] = theta.newNameFor (lvs[j]);
|
||||||
}
|
}
|
||||||
if (formulas_[i].isCounting()) {
|
if (args_[i].isCounting()) {
|
||||||
LogVar clv = formulas_[i].countedLogVar();
|
LogVar clv = args_[i].countedLogVar();
|
||||||
formulas_[i].setCountedLogVar (theta.newNameFor (clv));
|
args_[i].setCountedLogVar (theta.newNameFor (clv));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
constr_->applySubstitution (theta);
|
constr_->applySubstitution (theta);
|
||||||
@ -593,19 +404,29 @@ Parfactor::applySubstitution (const Substitution& theta)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
Parfactor::findGroup (const Ground& ground) const
|
||||||
|
{
|
||||||
|
int group = -1;
|
||||||
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i].functor() == ground.functor() &&
|
||||||
|
args_[i].arity() == ground.arity()) {
|
||||||
|
constr_->moveToTop (args_[i].logVars());
|
||||||
|
if (constr_->containsTuple (ground.args())) {
|
||||||
|
group = args_[i].group();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return group;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
Parfactor::containsGround (const Ground& ground) const
|
Parfactor::containsGround (const Ground& ground) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
return findGroup (ground) != -1;
|
||||||
if (formulas_[i].functor() == ground.functor() &&
|
|
||||||
formulas_[i].arity() == ground.arity()) {
|
|
||||||
constr_->moveToTop (formulas_[i].logVars());
|
|
||||||
if (constr_->containsTuple (ground.args())) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -613,8 +434,8 @@ Parfactor::containsGround (const Ground& ground) const
|
|||||||
bool
|
bool
|
||||||
Parfactor::containsGroup (unsigned group) const
|
Parfactor::containsGroup (unsigned group) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].group() == group) {
|
if (args_[i].group() == group) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -623,30 +444,12 @@ Parfactor::containsGroup (unsigned group) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
const ProbFormula&
|
|
||||||
Parfactor::formula (unsigned fIdx) const
|
|
||||||
{
|
|
||||||
assert (fIdx < formulas_.size());
|
|
||||||
return formulas_[fIdx];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
Parfactor::range (unsigned fIdx) const
|
|
||||||
{
|
|
||||||
assert (fIdx < ranges_.size());
|
|
||||||
return ranges_[fIdx];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
Parfactor::nrFormulas (LogVar X) const
|
Parfactor::nrFormulas (LogVar X) const
|
||||||
{
|
{
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].contains (X)) {
|
if (args_[i].contains (X)) {
|
||||||
count ++;
|
count ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -656,27 +459,12 @@ Parfactor::nrFormulas (LogVar X) const
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
Parfactor::indexOf (const ProbFormula& f) const
|
Parfactor::indexOfLogVar (LogVar X) const
|
||||||
{
|
|
||||||
int idx = -1;
|
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
|
||||||
if (f == formulas_[i]) {
|
|
||||||
idx = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
Parfactor::indexOfFormulaWithLogVar (LogVar X) const
|
|
||||||
{
|
{
|
||||||
int idx = -1;
|
int idx = -1;
|
||||||
assert (nrFormulas (X) == 1);
|
assert (nrFormulas (X) == 1);
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].contains (X)) {
|
if (args_[i].contains (X)) {
|
||||||
idx = i;
|
idx = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -687,11 +475,11 @@ Parfactor::indexOfFormulaWithLogVar (LogVar X) const
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
Parfactor::indexOfFormulaWithGroup (unsigned group) const
|
Parfactor::indexOfGroup (unsigned group) const
|
||||||
{
|
{
|
||||||
int pos = -1;
|
int pos = -1;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].group() == group) {
|
if (args_[i].group() == group) {
|
||||||
pos = i;
|
pos = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -704,9 +492,9 @@ Parfactor::indexOfFormulaWithGroup (unsigned group) const
|
|||||||
vector<unsigned>
|
vector<unsigned>
|
||||||
Parfactor::getAllGroups (void) const
|
Parfactor::getAllGroups (void) const
|
||||||
{
|
{
|
||||||
vector<unsigned> groups (formulas_.size());
|
vector<unsigned> groups (args_.size());
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
groups[i] = formulas_[i].group();
|
groups[i] = args_[i].group();
|
||||||
}
|
}
|
||||||
return groups;
|
return groups;
|
||||||
}
|
}
|
||||||
@ -714,13 +502,13 @@ Parfactor::getAllGroups (void) const
|
|||||||
|
|
||||||
|
|
||||||
string
|
string
|
||||||
Parfactor::getHeaderString (void) const
|
Parfactor::getLabel (void) const
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "phi(" ;
|
ss << "phi(" ;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) ss << "," ;
|
if (i != 0) ss << "," ;
|
||||||
ss << formulas_[i];
|
ss << args_[i];
|
||||||
}
|
}
|
||||||
ss << ")" ;
|
ss << ")" ;
|
||||||
ConstraintTree copy (*constr_);
|
ConstraintTree copy (*constr_);
|
||||||
@ -735,32 +523,35 @@ void
|
|||||||
Parfactor::print (bool printParams) const
|
Parfactor::print (bool printParams) const
|
||||||
{
|
{
|
||||||
cout << "Formulas: " ;
|
cout << "Formulas: " ;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) cout << ", " ;
|
if (i != 0) cout << ", " ;
|
||||||
cout << formulas_[i];
|
cout << args_[i];
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
vector<string> groups;
|
vector<string> groups;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
groups.push_back (string ("g") + Util::toString (formulas_[i].group()));
|
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
||||||
}
|
}
|
||||||
cout << "Groups: " << groups << endl;
|
cout << "Groups: " << groups << endl;
|
||||||
cout << "LogVars: " << constr_->logVars() << endl;
|
cout << "LogVars: " << constr_->logVarSet() << endl;
|
||||||
cout << "Ranges: " << ranges_ << endl;
|
cout << "Ranges: " << ranges_ << endl;
|
||||||
if (printParams == false) {
|
if (printParams == false) {
|
||||||
cout << "Params: " << params_ << endl;
|
cout << "Params: " << params_ << endl;
|
||||||
}
|
}
|
||||||
cout << "Tuples: " << constr_->tupleSet() << endl;
|
ConstraintTree copy (*constr_);
|
||||||
|
copy.moveToTop (copy.logVarSet().elements());
|
||||||
|
cout << "Tuples: " << copy.tupleSet() << endl;
|
||||||
if (printParams) {
|
if (printParams) {
|
||||||
vector<string> jointStrings;
|
vector<string> jointStrings;
|
||||||
StatesIndexer indexer (ranges_);
|
StatesIndexer indexer (ranges_);
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) ss << ", " ;
|
if (i != 0) ss << ", " ;
|
||||||
if (formulas_[i].isCounting()) {
|
if (args_[i].isCounting()) {
|
||||||
unsigned N = constr_->getConditionalCount (formulas_[i].countedLogVar());
|
unsigned N = constr_->getConditionalCount (
|
||||||
HistogramSet hs (N, formulas_[i].range());
|
args_[i].countedLogVar());
|
||||||
|
HistogramSet hs (N, args_[i].range());
|
||||||
unsigned c = 0;
|
unsigned c = 0;
|
||||||
while (c < indexer[i]) {
|
while (c < indexer[i]) {
|
||||||
hs.nextHistogram();
|
hs.nextHistogram();
|
||||||
@ -784,17 +575,50 @@ Parfactor::print (bool printParams) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::insertDimension (unsigned range)
|
Parfactor::expandPotential (
|
||||||
|
int fIdx,
|
||||||
|
unsigned newRange,
|
||||||
|
const vector<unsigned>& sumIndexes)
|
||||||
{
|
{
|
||||||
|
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
|
||||||
Params copy = params_;
|
Params copy = params_;
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.reserve (copy.size() * range);
|
params_.reserve (size);
|
||||||
for (unsigned i = 0; i < copy.size(); i++) {
|
|
||||||
for (unsigned reps = 0; reps < range; reps++) {
|
unsigned prod = 1;
|
||||||
params_.push_back (copy[i]);
|
vector<unsigned> offsets_ (ranges_.size());
|
||||||
|
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||||
|
offsets_[i] = prod;
|
||||||
|
prod *= ranges_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned index = 0;
|
||||||
|
ranges_[fIdx] = newRange;
|
||||||
|
vector<unsigned> indices (ranges_.size(), 0);
|
||||||
|
for (unsigned k = 0; k < size; k++) {
|
||||||
|
params_.push_back (copy[index]);
|
||||||
|
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||||
|
indices[i] ++;
|
||||||
|
if (i == fIdx) {
|
||||||
|
assert (indices[i] - 1 < sumIndexes.size());
|
||||||
|
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||||
|
index += diff * offsets_[i];
|
||||||
|
} else {
|
||||||
|
index += offsets_[i];
|
||||||
|
}
|
||||||
|
if (indices[i] != ranges_[i]) {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
if (i == fIdx) {
|
||||||
|
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
||||||
|
index += diff * offsets_[i];
|
||||||
|
} else {
|
||||||
|
index -= offsets_[i] * ranges_[i];
|
||||||
|
}
|
||||||
|
indices[i] = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ranges_.push_back (range);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -803,29 +627,27 @@ void
|
|||||||
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||||
{
|
{
|
||||||
LogVars X_1, X_2;
|
LogVars X_1, X_2;
|
||||||
const ProbFormulas& formulas1 = g1->formulas();
|
const ProbFormulas& formulas1 = g1->arguments();
|
||||||
const ProbFormulas& formulas2 = g2->formulas();
|
const ProbFormulas& formulas2 = g2->arguments();
|
||||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||||
if (formulas1[i].group() == formulas2[j].group()) {
|
if (formulas1[i].group() == formulas2[j].group()) {
|
||||||
X_1.insert (X_1.end(),
|
Util::addToVector (X_1, formulas1[i].logVars());
|
||||||
formulas1[i].logVars().begin(),
|
Util::addToVector (X_2, formulas2[j].logVars());
|
||||||
formulas1[i].logVars().end());
|
|
||||||
X_2.insert (X_2.end(),
|
|
||||||
formulas2[j].logVars().begin(),
|
|
||||||
formulas2[j].logVars().end());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
align (g1, X_1, g2, X_2);
|
|
||||||
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
||||||
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
||||||
assert (g1->constr()->isCountNormalized (Y_1));
|
assert (g1->constr()->isCountNormalized (Y_1));
|
||||||
assert (g2->constr()->isCountNormalized (Y_2));
|
assert (g2->constr()->isCountNormalized (Y_2));
|
||||||
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
||||||
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
||||||
Util::pow (g1->params(), 1.0 / condCount2);
|
LogAware::pow (g1->params(), 1.0 / condCount2);
|
||||||
Util::pow (g2->params(), 1.0 / condCount1);
|
LogAware::pow (g2->params(), 1.0 / condCount1);
|
||||||
|
// this must be done in the end or else X_1 and X_2
|
||||||
|
// will refer the old log var names in the code above
|
||||||
|
align (g1, X_1, g2, X_2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -838,7 +660,6 @@ Parfactor::align (
|
|||||||
LogVar freeLogVar = 0;
|
LogVar freeLogVar = 0;
|
||||||
Substitution theta1;
|
Substitution theta1;
|
||||||
Substitution theta2;
|
Substitution theta2;
|
||||||
|
|
||||||
const LogVarSet& allLvs1 = g1->logVarSet();
|
const LogVarSet& allLvs1 = g1->logVarSet();
|
||||||
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
||||||
theta1.add (allLvs1[i], freeLogVar);
|
theta1.add (allLvs1[i], freeLogVar);
|
||||||
|
@ -9,8 +9,9 @@
|
|||||||
#include "LiftedUtils.h"
|
#include "LiftedUtils.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
#include "Factor.h"
|
||||||
|
|
||||||
class Parfactor
|
class Parfactor : public TFactor<ProbFormula>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Parfactor (
|
Parfactor (
|
||||||
@ -18,27 +19,15 @@ class Parfactor
|
|||||||
const Params&,
|
const Params&,
|
||||||
const Tuples&,
|
const Tuples&,
|
||||||
unsigned);
|
unsigned);
|
||||||
|
|
||||||
Parfactor (const Parfactor*, const Tuple&);
|
Parfactor (const Parfactor*, const Tuple&);
|
||||||
|
|
||||||
Parfactor (const Parfactor*, ConstraintTree*);
|
Parfactor (const Parfactor*, ConstraintTree*);
|
||||||
|
|
||||||
Parfactor (const Parfactor&);
|
Parfactor (const Parfactor&);
|
||||||
|
|
||||||
~Parfactor (void);
|
~Parfactor (void);
|
||||||
|
|
||||||
ProbFormulas& formulas (void) { return formulas_; }
|
|
||||||
|
|
||||||
const ProbFormulas& formulas (void) const { return formulas_; }
|
|
||||||
|
|
||||||
unsigned nrFormulas (void) const { return formulas_.size(); }
|
|
||||||
|
|
||||||
Params& params (void) { return params_; }
|
|
||||||
|
|
||||||
const Params& params (void) const { return params_; }
|
|
||||||
|
|
||||||
unsigned size (void) const { return params_.size(); }
|
|
||||||
|
|
||||||
const Ranges& ranges (void) const { return ranges_; }
|
|
||||||
|
|
||||||
unsigned distId (void) const { return distId_; }
|
|
||||||
|
|
||||||
ConstraintTree* constr (void) { return constr_; }
|
ConstraintTree* constr (void) { return constr_; }
|
||||||
|
|
||||||
const ConstraintTree* constr (void) const { return constr_; }
|
const ConstraintTree* constr (void) const { return constr_; }
|
||||||
@ -57,63 +46,51 @@ class Parfactor
|
|||||||
|
|
||||||
void setConstraintTree (ConstraintTree*);
|
void setConstraintTree (ConstraintTree*);
|
||||||
|
|
||||||
void sumOut (unsigned);
|
void sumOut (unsigned fIdx);
|
||||||
|
|
||||||
void multiply (Parfactor&);
|
void multiply (Parfactor&);
|
||||||
|
|
||||||
void countConvert (LogVar);
|
void countConvert (LogVar);
|
||||||
|
|
||||||
void expandPotential (LogVar, LogVar, LogVar);
|
void expand (LogVar, LogVar, LogVar);
|
||||||
|
|
||||||
void fullExpand (LogVar);
|
void fullExpand (LogVar);
|
||||||
|
|
||||||
void reorderAccordingGrounds (const Grounds&);
|
void reorderAccordingGrounds (const Grounds&);
|
||||||
|
|
||||||
void reorderFormulas (const ProbFormulas&);
|
void absorveEvidence (const ProbFormula&, unsigned);
|
||||||
|
|
||||||
void absorveEvidence (unsigned, unsigned);
|
|
||||||
|
|
||||||
void normalize (void);
|
|
||||||
|
|
||||||
void setFormulaGroup (const ProbFormula&, int);
|
|
||||||
|
|
||||||
void setNewGroups (void);
|
void setNewGroups (void);
|
||||||
|
|
||||||
void applySubstitution (const Substitution&);
|
void applySubstitution (const Substitution&);
|
||||||
|
|
||||||
|
int findGroup (const Ground&) const;
|
||||||
|
|
||||||
bool containsGround (const Ground&) const;
|
bool containsGround (const Ground&) const;
|
||||||
|
|
||||||
bool containsGroup (unsigned) const;
|
bool containsGroup (unsigned) const;
|
||||||
|
|
||||||
const ProbFormula& formula (unsigned) const;
|
|
||||||
|
|
||||||
unsigned range (unsigned) const;
|
|
||||||
|
|
||||||
unsigned nrFormulas (LogVar) const;
|
unsigned nrFormulas (LogVar) const;
|
||||||
|
|
||||||
int indexOf (const ProbFormula&) const;
|
int indexOfLogVar (LogVar) const;
|
||||||
|
|
||||||
int indexOfFormulaWithLogVar (LogVar) const;
|
int indexOfGroup (unsigned) const;
|
||||||
|
|
||||||
int indexOfFormulaWithGroup (unsigned) const;
|
|
||||||
|
|
||||||
vector<unsigned> getAllGroups (void) const;
|
vector<unsigned> getAllGroups (void) const;
|
||||||
|
|
||||||
void print (bool = false) const;
|
void print (bool = false) const;
|
||||||
|
|
||||||
string getHeaderString (void) const;
|
string getLabel (void) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void expandPotential (int fIdx, unsigned newRange,
|
||||||
|
const vector<unsigned>& sumIndexes);
|
||||||
|
|
||||||
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||||
|
|
||||||
static void align (
|
static void align (
|
||||||
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
|
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
|
||||||
|
|
||||||
void insertDimension (unsigned);
|
|
||||||
|
|
||||||
ProbFormulas formulas_;
|
|
||||||
Ranges ranges_;
|
|
||||||
Params params_;
|
|
||||||
unsigned distId_;
|
|
||||||
ConstraintTree* constr_;
|
ConstraintTree* constr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3,9 +3,32 @@
|
|||||||
#include "ParfactorList.h"
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
|
||||||
ParfactorList::ParfactorList (Parfactors& pfs)
|
ParfactorList::ParfactorList (const ParfactorList& pfList)
|
||||||
{
|
{
|
||||||
pfList_.insert (pfList_.end(), pfs.begin(), pfs.end());
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
addShattered (new Parfactor (**it));
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParfactorList::ParfactorList (const Parfactors& pfs)
|
||||||
|
{
|
||||||
|
add (pfs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParfactorList::~ParfactorList (void)
|
||||||
|
{
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
delete *it;
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -14,17 +37,17 @@ void
|
|||||||
ParfactorList::add (Parfactor* pf)
|
ParfactorList::add (Parfactor* pf)
|
||||||
{
|
{
|
||||||
pf->setNewGroups();
|
pf->setNewGroups();
|
||||||
pfList_.push_back (pf);
|
addToShatteredList (pf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
ParfactorList::add (Parfactors& pfs)
|
ParfactorList::add (const Parfactors& pfs)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||||
pfs[i]->setNewGroups();
|
pfs[i]->setNewGroups();
|
||||||
pfList_.push_back (pfs[i]);
|
addToShatteredList (pfs[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -33,7 +56,20 @@ ParfactorList::add (Parfactors& pfs)
|
|||||||
void
|
void
|
||||||
ParfactorList::addShattered (Parfactor* pf)
|
ParfactorList::addShattered (Parfactor* pf)
|
||||||
{
|
{
|
||||||
|
assert (isAllShattered());
|
||||||
pfList_.push_back (pf);
|
pfList_.push_back (pf);
|
||||||
|
assert (isAllShattered());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator
|
||||||
|
ParfactorList::insertShattered (
|
||||||
|
list<Parfactor*>::iterator it,
|
||||||
|
Parfactor* pf)
|
||||||
|
{
|
||||||
|
return pfList_.insert (it, pf);
|
||||||
|
assert (isAllShattered());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -47,7 +83,7 @@ ParfactorList::remove (list<Parfactor*>::iterator it)
|
|||||||
|
|
||||||
|
|
||||||
list<Parfactor*>::iterator
|
list<Parfactor*>::iterator
|
||||||
ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
|
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||||
{
|
{
|
||||||
delete *it;
|
delete *it;
|
||||||
return pfList_.erase (it);
|
return pfList_.erase (it);
|
||||||
@ -55,58 +91,21 @@ ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
bool
|
||||||
ParfactorList::shatter (void)
|
ParfactorList::isAllShattered (void) const
|
||||||
{
|
{
|
||||||
list<Parfactor*> tempList;
|
if (pfList_.size() <= 1) {
|
||||||
Parfactors newPfs;
|
return true;
|
||||||
newPfs.insert (newPfs.end(), pfList_.begin(), pfList_.end());
|
|
||||||
while (newPfs.empty() == false) {
|
|
||||||
tempList.insert (tempList.end(), newPfs.begin(), newPfs.end());
|
|
||||||
newPfs.clear();
|
|
||||||
list<Parfactor*>::iterator iter1 = tempList.begin();
|
|
||||||
while (tempList.size() > 1 && iter1 != -- tempList.end()) {
|
|
||||||
list<Parfactor*>::iterator iter2 = iter1;
|
|
||||||
++ iter2;
|
|
||||||
bool incIter1 = true;
|
|
||||||
while (iter2 != tempList.end()) {
|
|
||||||
assert (iter1 != iter2);
|
|
||||||
std::pair<Parfactors, Parfactors> res = shatter (
|
|
||||||
(*iter1)->formulas(), *iter1, (*iter2)->formulas(), *iter2);
|
|
||||||
bool incIter2 = true;
|
|
||||||
if (res.second.empty() == false) {
|
|
||||||
// cout << "second unshattered" << endl;
|
|
||||||
delete *iter2;
|
|
||||||
iter2 = tempList.erase (iter2);
|
|
||||||
incIter2 = false;
|
|
||||||
newPfs.insert (
|
|
||||||
newPfs.begin(), res.second.begin(), res.second.end());
|
|
||||||
}
|
}
|
||||||
if (res.first.empty() == false) {
|
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
|
||||||
// cout << "first unshattered" << endl;
|
for (unsigned i = 0; i < pfs.size() - 1; i++) {
|
||||||
delete *iter1;
|
for (unsigned j = i + 1; j < pfs.size(); j++) {
|
||||||
iter1 = tempList.erase (iter1);
|
if (isShattered (pfs[i], pfs[j]) == false) {
|
||||||
newPfs.insert (
|
return false;
|
||||||
newPfs.begin(), res.first.begin(), res.first.end());
|
|
||||||
incIter1 = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (incIter2) {
|
|
||||||
++ iter2;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (incIter1) {
|
|
||||||
++ iter1;
|
|
||||||
}
|
}
|
||||||
}
|
return true;
|
||||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
|
||||||
// cout << "||||||||||||| SHATTERING ITERATION ||||||||||||||" << endl;
|
|
||||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
|
||||||
// printParfactors (newPfs);
|
|
||||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
|
||||||
}
|
|
||||||
pfList_.clear();
|
|
||||||
pfList_.insert (pfList_.end(), tempList.begin(), tempList.end());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -123,19 +122,83 @@ ParfactorList::print (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
std::pair<Parfactors, Parfactors>
|
bool
|
||||||
ParfactorList::shatter (
|
ParfactorList::isShattered (
|
||||||
ProbFormulas& formulas1,
|
const Parfactor* g1,
|
||||||
Parfactor* g1,
|
const Parfactor* g2) const
|
||||||
ProbFormulas& formulas2,
|
|
||||||
Parfactor* g2)
|
|
||||||
{
|
{
|
||||||
|
assert (g1 != g2);
|
||||||
|
const ProbFormulas& fms1 = g1->arguments();
|
||||||
|
const ProbFormulas& fms2 = g2->arguments();
|
||||||
|
for (unsigned i = 0; i < fms1.size(); i++) {
|
||||||
|
for (unsigned j = 0; j < fms2.size(); j++) {
|
||||||
|
if (fms1[i].group() == fms2[j].group()) {
|
||||||
|
if (identical (
|
||||||
|
fms1[i], *(g1->constr()),
|
||||||
|
fms2[j], *(g2->constr())) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (disjoint (
|
||||||
|
fms1[i], *(g1->constr()),
|
||||||
|
fms2[j], *(g2->constr())) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ParfactorList::addToShatteredList (Parfactor* g)
|
||||||
|
{
|
||||||
|
queue<Parfactor*> residuals;
|
||||||
|
residuals.push (g);
|
||||||
|
while (residuals.empty() == false) {
|
||||||
|
Parfactor* pf = residuals.front();
|
||||||
|
bool pfSplitted = false;
|
||||||
|
list<Parfactor*>::iterator pfIter;
|
||||||
|
pfIter = pfList_.begin();
|
||||||
|
while (pfIter != pfList_.end()) {
|
||||||
|
std::pair<Parfactors, Parfactors> shattRes;
|
||||||
|
shattRes = shatter (*pfIter, pf);
|
||||||
|
if (shattRes.first.empty() == false) {
|
||||||
|
pfIter = removeAndDelete (pfIter);
|
||||||
|
Util::addToQueue (residuals, shattRes.first);
|
||||||
|
} else {
|
||||||
|
++ pfIter;
|
||||||
|
}
|
||||||
|
if (shattRes.second.empty() == false) {
|
||||||
|
delete pf;
|
||||||
|
Util::addToQueue (residuals, shattRes.second);
|
||||||
|
pfSplitted = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
residuals.pop();
|
||||||
|
if (pfSplitted == false) {
|
||||||
|
addShattered (pf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (isAllShattered());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::pair<Parfactors, Parfactors>
|
||||||
|
ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
ProbFormulas& formulas1 = g1->arguments();
|
||||||
|
ProbFormulas& formulas2 = g2->arguments();
|
||||||
assert (g1 != 0 && g2 != 0 && g1 != g2);
|
assert (g1 != 0 && g2 != 0 && g1 != g2);
|
||||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||||
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
|
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
|
||||||
std::pair<Parfactors, Parfactors> res
|
std::pair<Parfactors, Parfactors> res;
|
||||||
= shatter (formulas1[i], g1, formulas2[j], g2);
|
res = shatter (i, g1, j, g2);
|
||||||
if (res.first.empty() == false ||
|
if (res.first.empty() == false ||
|
||||||
res.second.empty() == false) {
|
res.second.empty() == false) {
|
||||||
return res;
|
return res;
|
||||||
@ -150,21 +213,22 @@ ParfactorList::shatter (
|
|||||||
|
|
||||||
std::pair<Parfactors, Parfactors>
|
std::pair<Parfactors, Parfactors>
|
||||||
ParfactorList::shatter (
|
ParfactorList::shatter (
|
||||||
ProbFormula& f1,
|
unsigned fIdx1, Parfactor* g1,
|
||||||
Parfactor* g1,
|
unsigned fIdx2, Parfactor* g2)
|
||||||
ProbFormula& f2,
|
|
||||||
Parfactor* g2)
|
|
||||||
{
|
{
|
||||||
|
ProbFormula& f1 = g1->argument (fIdx1);
|
||||||
|
ProbFormula& f2 = g2->argument (fIdx2);
|
||||||
// cout << endl;
|
// cout << endl;
|
||||||
// cout << "-------------------------------------------------" << endl;
|
// Util::printDashLine();
|
||||||
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
||||||
// g1->print();
|
// g1->print();
|
||||||
// cout << "-> WITH" << endl;
|
// cout << "-> WITH" << endl;
|
||||||
// g2->print();
|
// g2->print();
|
||||||
// cout << "-> ON: " << f1.toString (g1->constr()) << endl;
|
// cout << "-> ON: " << f1 << "|" ;
|
||||||
// cout << "-> ON: " << f2.toString (g2->constr()) << endl;
|
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
||||||
// cout << "-------------------------------------------------" << endl;
|
// cout << "-> ON: " << f2 << "|" ;
|
||||||
|
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
||||||
|
// Util::printDashLine();
|
||||||
if (f1.isAtom()) {
|
if (f1.isAtom()) {
|
||||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
||||||
f1.setGroup (group);
|
f1.setGroup (group);
|
||||||
@ -174,7 +238,7 @@ ParfactorList::shatter (
|
|||||||
assert (g1->constr()->empty() == false);
|
assert (g1->constr()->empty() == false);
|
||||||
assert (g2->constr()->empty() == false);
|
assert (g2->constr()->empty() == false);
|
||||||
if (f1.group() == f2.group()) {
|
if (f1.group() == f2.group()) {
|
||||||
// assert (identical (f1, g1->constr(), f2, g2->constr()));
|
assert (identical (f1, *(g1->constr()), f2, *(g2->constr())));
|
||||||
return { };
|
return { };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +279,9 @@ ParfactorList::shatter (
|
|||||||
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
||||||
|
|
||||||
if (exclCt1->empty() && exclCt2->empty()) {
|
if (exclCt1->empty() && exclCt2->empty()) {
|
||||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
unsigned group = (f1.group() < f2.group())
|
||||||
|
? f1.group()
|
||||||
|
: f2.group();
|
||||||
// identical
|
// identical
|
||||||
f1.setGroup (group);
|
f1.setGroup (group);
|
||||||
f2.setGroup (group);
|
f2.setGroup (group);
|
||||||
@ -235,8 +301,8 @@ ParfactorList::shatter (
|
|||||||
} else {
|
} else {
|
||||||
group = ProbFormula::getNewGroup();
|
group = ProbFormula::getNewGroup();
|
||||||
}
|
}
|
||||||
Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group);
|
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
|
||||||
Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group);
|
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
|
||||||
return make_pair (res1, res2);
|
return make_pair (res1, res2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,11 +311,19 @@ ParfactorList::shatter (
|
|||||||
Parfactors
|
Parfactors
|
||||||
ParfactorList::shatter (
|
ParfactorList::shatter (
|
||||||
Parfactor* g,
|
Parfactor* g,
|
||||||
const ProbFormula& f,
|
unsigned fIdx,
|
||||||
ConstraintTree* commCt,
|
ConstraintTree* commCt,
|
||||||
ConstraintTree* exclCt,
|
ConstraintTree* exclCt,
|
||||||
unsigned commGroup)
|
unsigned commGroup)
|
||||||
{
|
{
|
||||||
|
ProbFormula& f = g->argument (fIdx);
|
||||||
|
if (exclCt->empty()) {
|
||||||
|
delete commCt;
|
||||||
|
delete exclCt;
|
||||||
|
f.setGroup (commGroup);
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
|
||||||
Parfactors result;
|
Parfactors result;
|
||||||
if (f.isCounting()) {
|
if (f.isCounting()) {
|
||||||
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
|
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
|
||||||
@ -259,7 +333,7 @@ ParfactorList::shatter (
|
|||||||
for (unsigned i = 0; i < cts.size(); i++) {
|
for (unsigned i = 0; i < cts.size(); i++) {
|
||||||
Parfactor* newPf = new Parfactor (g, cts[i]);
|
Parfactor* newPf = new Parfactor (g, cts[i]);
|
||||||
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
|
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
|
||||||
newPf->expandPotential (f.countedLogVar(), X_new1, X_new2);
|
newPf->expand (f.countedLogVar(), X_new1, X_new2);
|
||||||
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
||||||
cts[i]->getConditionalCount (X_new1) +
|
cts[i]->getConditionalCount (X_new1) +
|
||||||
cts[i]->getConditionalCount (X_new2));
|
cts[i]->getConditionalCount (X_new2));
|
||||||
@ -270,21 +344,17 @@ ParfactorList::shatter (
|
|||||||
newPf->setNewGroups();
|
newPf->setNewGroups();
|
||||||
result.push_back (newPf);
|
result.push_back (newPf);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if (exclCt->empty()) {
|
|
||||||
delete commCt;
|
delete commCt;
|
||||||
delete exclCt;
|
delete exclCt;
|
||||||
g->setFormulaGroup (f, commGroup);
|
|
||||||
} else {
|
} else {
|
||||||
Parfactor* newPf = new Parfactor (g, commCt);
|
Parfactor* newPf = new Parfactor (g, commCt);
|
||||||
newPf->setNewGroups();
|
newPf->setNewGroups();
|
||||||
newPf->setFormulaGroup (f, commGroup);
|
newPf->argument (fIdx).setGroup (commGroup);
|
||||||
result.push_back (newPf);
|
result.push_back (newPf);
|
||||||
newPf = new Parfactor (g, exclCt);
|
newPf = new Parfactor (g, exclCt);
|
||||||
newPf->setNewGroups();
|
newPf->setNewGroups();
|
||||||
result.push_back (newPf);
|
result.push_back (newPf);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -296,7 +366,7 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
|
|||||||
unsigned newGroup = ProbFormula::getNewGroup();
|
unsigned newGroup = ProbFormula::getNewGroup();
|
||||||
for (ParfactorList::iterator it = pfList_.begin();
|
for (ParfactorList::iterator it = pfList_.begin();
|
||||||
it != pfList_.end(); it++) {
|
it != pfList_.end(); it++) {
|
||||||
ProbFormulas& formulas = (*it)->formulas();
|
ProbFormulas& formulas = (*it)->arguments();
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||||
if (formulas[i].group() == group1 ||
|
if (formulas[i].group() == group1 ||
|
||||||
formulas[i].group() == group2) {
|
formulas[i].group() == group2) {
|
||||||
@ -306,3 +376,52 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::proper (
|
||||||
|
const ProbFormula& f1, ConstraintTree c1,
|
||||||
|
const ProbFormula& f2, ConstraintTree c2) const
|
||||||
|
{
|
||||||
|
return disjoint (f1, c1, f2, c2)
|
||||||
|
|| identical (f1, c1, f2, c2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::identical (
|
||||||
|
const ProbFormula& f1, ConstraintTree c1,
|
||||||
|
const ProbFormula& f2, ConstraintTree c2) const
|
||||||
|
{
|
||||||
|
if (f1.sameSkeletonAs (f2) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (f1.isAtom()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
c1.moveToTop (f1.logVars());
|
||||||
|
c2.moveToTop (f2.logVars());
|
||||||
|
return ConstraintTree::identical (
|
||||||
|
&c1, &c2, f1.logVars().size());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::disjoint (
|
||||||
|
const ProbFormula& f1, ConstraintTree c1,
|
||||||
|
const ProbFormula& f2, ConstraintTree c2) const
|
||||||
|
{
|
||||||
|
if (f1.sameSkeletonAs (f2) == false) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (f1.isAtom()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
c1.moveToTop (f1.logVars());
|
||||||
|
c2.moveToTop (f2.logVars());
|
||||||
|
return ConstraintTree::overlap (
|
||||||
|
&c1, &c2, f1.arity()) == false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#define HORUS_PARFACTORLIST_H
|
#define HORUS_PARFACTORLIST_H
|
||||||
|
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
#include "Parfactor.h"
|
#include "Parfactor.h"
|
||||||
#include "ProbFormula.h"
|
#include "ProbFormula.h"
|
||||||
@ -14,55 +15,81 @@ class ParfactorList
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ParfactorList (void) { }
|
ParfactorList (void) { }
|
||||||
ParfactorList (Parfactors&);
|
|
||||||
list<Parfactor*>& getParfactors (void) { return pfList_; }
|
|
||||||
const list<Parfactor*>& getParfactors (void) const { return pfList_; }
|
|
||||||
|
|
||||||
void add (Parfactor* pf);
|
ParfactorList (const ParfactorList&);
|
||||||
void add (Parfactors& pfs);
|
|
||||||
void addShattered (Parfactor* pf);
|
ParfactorList (const Parfactors&);
|
||||||
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
|
||||||
list<Parfactor*>::iterator deleteAndRemove (list<Parfactor*>::iterator);
|
~ParfactorList (void);
|
||||||
|
|
||||||
|
const list<Parfactor*>& parfactors (void) const { return pfList_; }
|
||||||
|
|
||||||
void clear (void) { pfList_.clear(); }
|
void clear (void) { pfList_.clear(); }
|
||||||
|
|
||||||
unsigned size (void) const { return pfList_.size(); }
|
unsigned size (void) const { return pfList_.size(); }
|
||||||
|
|
||||||
|
|
||||||
void shatter (void);
|
|
||||||
|
|
||||||
typedef std::list<Parfactor*>::iterator iterator;
|
typedef std::list<Parfactor*>::iterator iterator;
|
||||||
|
|
||||||
iterator begin (void) { return pfList_.begin(); }
|
iterator begin (void) { return pfList_.begin(); }
|
||||||
|
|
||||||
iterator end (void) { return pfList_.end(); }
|
iterator end (void) { return pfList_.end(); }
|
||||||
|
|
||||||
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
||||||
|
|
||||||
const_iterator begin (void) const { return pfList_.begin(); }
|
const_iterator begin (void) const { return pfList_.begin(); }
|
||||||
|
|
||||||
const_iterator end (void) const { return pfList_.end(); }
|
const_iterator end (void) const { return pfList_.end(); }
|
||||||
|
|
||||||
|
void add (Parfactor* pf);
|
||||||
|
|
||||||
|
void add (const Parfactors& pfs);
|
||||||
|
|
||||||
|
void addShattered (Parfactor* pf);
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator insertShattered (
|
||||||
|
list<Parfactor*>::iterator, Parfactor*);
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
|
||||||
|
|
||||||
|
bool isAllShattered (void) const;
|
||||||
|
|
||||||
void print (void) const;
|
void print (void) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
static std::pair<Parfactors, Parfactors> shatter (
|
bool isShattered (const Parfactor*, const Parfactor*) const;
|
||||||
ProbFormulas&,
|
|
||||||
Parfactor*,
|
|
||||||
ProbFormulas&,
|
|
||||||
Parfactor*);
|
|
||||||
|
|
||||||
static std::pair<Parfactors, Parfactors> shatter (
|
void addToShatteredList (Parfactor*);
|
||||||
ProbFormula&,
|
|
||||||
Parfactor*,
|
|
||||||
ProbFormula&,
|
|
||||||
Parfactor*);
|
|
||||||
|
|
||||||
static Parfactors shatter (
|
std::pair<Parfactors, Parfactors> shatter (
|
||||||
|
Parfactor*, Parfactor*);
|
||||||
|
|
||||||
|
std::pair<Parfactors, Parfactors> shatter (
|
||||||
|
unsigned, Parfactor*, unsigned, Parfactor*);
|
||||||
|
|
||||||
|
Parfactors shatter (
|
||||||
Parfactor*,
|
Parfactor*,
|
||||||
const ProbFormula&,
|
unsigned,
|
||||||
ConstraintTree*,
|
ConstraintTree*,
|
||||||
ConstraintTree*,
|
ConstraintTree*,
|
||||||
unsigned);
|
unsigned);
|
||||||
|
|
||||||
void unifyGroups (unsigned group1, unsigned group2);
|
void unifyGroups (unsigned group1, unsigned group2);
|
||||||
|
|
||||||
|
bool proper (
|
||||||
|
const ProbFormula&, ConstraintTree,
|
||||||
|
const ProbFormula&, ConstraintTree) const;
|
||||||
|
|
||||||
|
bool identical (
|
||||||
|
const ProbFormula&, ConstraintTree,
|
||||||
|
const ProbFormula&, ConstraintTree) const;
|
||||||
|
|
||||||
|
bool disjoint (
|
||||||
|
const ProbFormula&, ConstraintTree,
|
||||||
|
const ProbFormula&, ConstraintTree) const;
|
||||||
|
|
||||||
list<Parfactor*> pfList_;
|
list<Parfactor*> pfList_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -16,8 +16,7 @@ ProbFormula::sameSkeletonAs (const ProbFormula& f) const
|
|||||||
bool
|
bool
|
||||||
ProbFormula::contains (LogVar lv) const
|
ProbFormula::contains (LogVar lv) const
|
||||||
{
|
{
|
||||||
return std::find (logVars_.begin(), logVars_.end(), lv) !=
|
return Util::contains (logVars_, lv);
|
||||||
logVars_.end();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -77,16 +76,15 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||||
bool
|
|
||||||
ProbFormula::operator== (const ProbFormula& f) const
|
|
||||||
{
|
{
|
||||||
return functor_ == f.functor_ && logVars_ == f.logVars_ ;
|
return f1.group_ == f2.group_;
|
||||||
|
//return functor_ == f.functor_ && logVars_ == f.logVars_ ;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream& operator<< (ostream &os, const ProbFormula& f)
|
std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||||
{
|
{
|
||||||
os << f.functor_;
|
os << f.functor_;
|
||||||
if (f.isAtom() == false) {
|
if (f.isAtom() == false) {
|
||||||
@ -113,3 +111,13 @@ ProbFormula::getNewGroup (void)
|
|||||||
return freeGroup_;
|
return freeGroup_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||||
|
{
|
||||||
|
os << of.functor_ << "/" << of.arity_;
|
||||||
|
os << "|" << of.constr_.tupleSet();
|
||||||
|
os << " [evidence=" << of.evidence_ << "]";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -8,14 +8,16 @@
|
|||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ProbFormula
|
class ProbFormula
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
||||||
: functor_(f), logVars_(lvs), range_(range),
|
: functor_(f), logVars_(lvs), range_(range),
|
||||||
countedLogVar_() { }
|
countedLogVar_(), group_(Util::maxUnsigned()) { }
|
||||||
|
|
||||||
ProbFormula (Symbol f, unsigned r) : functor_(f), range_(r) { }
|
ProbFormula (Symbol f, unsigned r)
|
||||||
|
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
|
||||||
|
|
||||||
Symbol functor (void) const { return functor_; }
|
Symbol functor (void) const { return functor_; }
|
||||||
|
|
||||||
@ -29,9 +31,9 @@ class ProbFormula
|
|||||||
|
|
||||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||||
|
|
||||||
unsigned group (void) const { return groupId_; }
|
unsigned group (void) const { return group_; }
|
||||||
|
|
||||||
void setGroup (unsigned g) { groupId_ = g; }
|
void setGroup (unsigned g) { group_ = g; }
|
||||||
|
|
||||||
bool sameSkeletonAs (const ProbFormula&) const;
|
bool sameSkeletonAs (const ProbFormula&) const;
|
||||||
|
|
||||||
@ -49,23 +51,58 @@ class ProbFormula
|
|||||||
|
|
||||||
void rename (LogVar, LogVar);
|
void rename (LogVar, LogVar);
|
||||||
|
|
||||||
bool operator== (const ProbFormula& f) const;
|
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &out, const ProbFormula& f);
|
|
||||||
|
|
||||||
static unsigned getNewGroup (void);
|
static unsigned getNewGroup (void);
|
||||||
|
|
||||||
|
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
|
||||||
|
|
||||||
|
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Symbol functor_;
|
Symbol functor_;
|
||||||
LogVars logVars_;
|
LogVars logVars_;
|
||||||
unsigned range_;
|
unsigned range_;
|
||||||
LogVar countedLogVar_;
|
LogVar countedLogVar_;
|
||||||
unsigned groupId_;
|
unsigned group_;
|
||||||
static int freeGroup_;
|
static int freeGroup_;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef vector<ProbFormula> ProbFormulas;
|
typedef vector<ProbFormula> ProbFormulas;
|
||||||
|
|
||||||
|
|
||||||
|
class ObservedFormula
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||||
|
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
|
||||||
|
|
||||||
|
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||||
|
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||||
|
{
|
||||||
|
constr_.addTuple (tuple);
|
||||||
|
}
|
||||||
|
|
||||||
|
Symbol functor (void) const { return functor_; }
|
||||||
|
|
||||||
|
unsigned arity (void) const { return arity_; }
|
||||||
|
|
||||||
|
unsigned evidence (void) const { return evidence_; }
|
||||||
|
|
||||||
|
ConstraintTree& constr (void) { return constr_; }
|
||||||
|
|
||||||
|
bool isAtom (void) const { return arity_ == 0; }
|
||||||
|
|
||||||
|
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
|
||||||
|
|
||||||
|
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Symbol functor_;
|
||||||
|
unsigned arity_;
|
||||||
|
unsigned evidence_;
|
||||||
|
ConstraintTree constr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<ObservedFormula> ObservedFormulas;
|
||||||
|
|
||||||
#endif // HORUS_PROBFORMULA_H
|
#endif // HORUS_PROBFORMULA_H
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ Solver::printPosterioriOf (VarId vid)
|
|||||||
const States& states = var->states();
|
const States& states = var->states();
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
for (unsigned i = 0; i < states.size(); i++) {
|
||||||
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
|
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
|
||||||
cout << setprecision (PRECISION) << posterioriDist[i];
|
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
@ -45,7 +45,7 @@ Solver::printJointDistributionOf (const VarIds& vids)
|
|||||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
for (unsigned i = 0; i < jointDist.size(); i++) {
|
||||||
cout << "P(" << jointStrings[i] << ") = " ;
|
cout << "P(" << jointStrings[i] << ") = " ;
|
||||||
cout << setprecision (PRECISION) << jointDist[i];
|
cout << setprecision (Constants::PRECISION) << jointDist[i];
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
@ -11,17 +11,20 @@ using namespace std;
|
|||||||
class Solver
|
class Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Solver (const GraphicalModel* gm)
|
Solver (const GraphicalModel* gm) : gm_(gm) { }
|
||||||
{
|
|
||||||
gm_ = gm;
|
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||||
}
|
|
||||||
virtual ~Solver() {} // to ensure that subclass destructor is called
|
|
||||||
virtual void runSolver (void) = 0;
|
virtual void runSolver (void) = 0;
|
||||||
|
|
||||||
virtual Params getPosterioriOf (VarId) = 0;
|
virtual Params getPosterioriOf (VarId) = 0;
|
||||||
|
|
||||||
virtual Params getJointDistributionOf (const VarIds&) = 0;
|
virtual Params getJointDistributionOf (const VarIds&) = 0;
|
||||||
|
|
||||||
void printAllPosterioris (void);
|
void printAllPosterioris (void);
|
||||||
|
|
||||||
void printPosterioriOf (VarId vid);
|
void printPosterioriOf (VarId vid);
|
||||||
|
|
||||||
void printJointDistributionOf (const VarIds& vids);
|
void printJointDistributionOf (const VarIds& vids);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
5
packages/CLPBN/clpbn/bp/TODO
Normal file
5
packages/CLPBN/clpbn/bp/TODO
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
TODO
|
||||||
|
|
||||||
|
- add way to calculate combinations and factorials with large numbers
|
||||||
|
- refactor sumOut in parfactor -> is really ugly code
|
||||||
|
- Indexer: start receiving ranges as constant reference
|
@ -1,4 +1,7 @@
|
|||||||
|
#include <limits>
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
@ -6,16 +9,15 @@
|
|||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
namespace Globals {
|
||||||
bool logDomain = false;
|
bool logDomain = false;
|
||||||
|
|
||||||
|
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
||||||
|
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
||||||
|
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
||||||
|
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
namespace InfAlgorithms {
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
|
||||||
InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::CBP;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
namespace BpOptions {
|
namespace BpOptions {
|
||||||
@ -28,8 +30,7 @@ unsigned maxIter = 1000;
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
unordered_map<VarId,VariableInfo> GraphicalModel::varsInfo_;
|
unordered_map<VarId, VarInfo> GraphicalModel::varsInfo_;
|
||||||
unordered_map<unsigned,Distribution*> GraphicalModel::distsInfo_;
|
|
||||||
|
|
||||||
vector<NetInfo> Statistics::netInfo_;
|
vector<NetInfo> Statistics::netInfo_;
|
||||||
vector<CompressInfo> Statistics::compressInfo_;
|
vector<CompressInfo> Statistics::compressInfo_;
|
||||||
@ -58,76 +59,6 @@ fromLog (Params& v)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
normalize (Params& v)
|
|
||||||
{
|
|
||||||
double sum;
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
sum = addIdenty();
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
logSum (sum, v[i]);
|
|
||||||
}
|
|
||||||
assert (sum != -numeric_limits<double>::infinity());
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] -= sum;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sum = 0.0;
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
sum += v[i];
|
|
||||||
}
|
|
||||||
assert (sum != 0.0);
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] /= sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
pow (Params& v, double expoent)
|
|
||||||
{
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] *= expoent;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = std::pow (v[i], expoent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
pow (Params& v, unsigned expoent)
|
|
||||||
{
|
|
||||||
if (expoent == 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] *= expoent;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = std::pow (v[i], expoent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
|
||||||
pow (double p, unsigned expoent)
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
double
|
||||||
factorial (double num)
|
factorial (double num)
|
||||||
{
|
{
|
||||||
@ -153,52 +84,21 @@ nrCombinations (unsigned n, unsigned r)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
unsigned
|
||||||
getL1Distance (const Params& v1, const Params& v2)
|
expectedSize (const Ranges& ranges)
|
||||||
{
|
{
|
||||||
assert (v1.size() == v2.size());
|
unsigned prod = 1;
|
||||||
double dist = 0.0;
|
for (unsigned i = 0; i < ranges.size(); i++) {
|
||||||
if (Globals::logDomain) {
|
prod *= ranges[i];
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
dist += abs (exp(v1[i]) - exp(v2[i]));
|
|
||||||
}
|
}
|
||||||
} else {
|
return prod;
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
dist += abs (v1[i] - v2[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
|
||||||
getMaxNorm (const Params& v1, const Params& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
double max = 0.0;
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
double diff = abs (exp(v1[i]) - exp(v2[i]));
|
|
||||||
if (diff > max) {
|
|
||||||
max = diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
double diff = abs (v1[i] - v2[i]);
|
|
||||||
if (diff > max) {
|
|
||||||
max = diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return max;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
getNumberOfDigits (int number) {
|
getNumberOfDigits (int number)
|
||||||
|
{
|
||||||
unsigned count = 1;
|
unsigned count = 1;
|
||||||
while (number >= 10) {
|
while (number >= 10) {
|
||||||
number /= 10;
|
number /= 10;
|
||||||
@ -257,6 +157,168 @@ getJointStateStrings (const VarNodes& vars)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void printHeader (string header, std::ostream& os)
|
||||||
|
{
|
||||||
|
printAsteriskLine (os);
|
||||||
|
os << header << endl;
|
||||||
|
printAsteriskLine (os);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void printSubHeader (string header, std::ostream& os)
|
||||||
|
{
|
||||||
|
printDashedLine (os);
|
||||||
|
os << header << endl;
|
||||||
|
printDashedLine (os);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void printAsteriskLine (std::ostream& os)
|
||||||
|
{
|
||||||
|
os << "********************************" ;
|
||||||
|
os << "********************************" ;
|
||||||
|
os << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void printDashedLine (std::ostream& os)
|
||||||
|
{
|
||||||
|
os << "--------------------------------" ;
|
||||||
|
os << "--------------------------------" ;
|
||||||
|
os << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace LogAware {
|
||||||
|
|
||||||
|
void
|
||||||
|
normalize (Params& v)
|
||||||
|
{
|
||||||
|
double sum;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
sum = LogAware::addIdenty();
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
sum = Util::logSum (sum, v[i]);
|
||||||
|
}
|
||||||
|
assert (sum != -numeric_limits<double>::infinity());
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] -= sum;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sum = 0.0;
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
sum += v[i];
|
||||||
|
}
|
||||||
|
assert (sum != 0.0);
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
getL1Distance (const Params& v1, const Params& v2)
|
||||||
|
{
|
||||||
|
assert (v1.size() == v2.size());
|
||||||
|
double dist = 0.0;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (unsigned i = 0; i < v1.size(); i++) {
|
||||||
|
dist += abs (exp(v1[i]) - exp(v2[i]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < v1.size(); i++) {
|
||||||
|
dist += abs (v1[i] - v2[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dist;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
getMaxNorm (const Params& v1, const Params& v2)
|
||||||
|
{
|
||||||
|
assert (v1.size() == v2.size());
|
||||||
|
double max = 0.0;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (unsigned i = 0; i < v1.size(); i++) {
|
||||||
|
double diff = abs (exp(v1[i]) - exp(v2[i]));
|
||||||
|
if (diff > max) {
|
||||||
|
max = diff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < v1.size(); i++) {
|
||||||
|
double diff = abs (v1[i] - v2[i]);
|
||||||
|
if (diff > max) {
|
||||||
|
max = diff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
pow (double p, unsigned expoent)
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
pow (double p, double expoent)
|
||||||
|
{
|
||||||
|
// assumes that `expoent' is never in log domain
|
||||||
|
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
pow (Params& v, unsigned expoent)
|
||||||
|
{
|
||||||
|
if (expoent == 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] *= expoent;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] = std::pow (v[i], expoent);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
pow (Params& v, double expoent)
|
||||||
|
{
|
||||||
|
// assumes that `expoent' is never in log domain
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] *= expoent;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] = std::pow (v[i], expoent);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -286,8 +348,11 @@ Statistics::getPrimaryNetworksCounting (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Statistics::updateStatistics (unsigned size, bool loopy,
|
Statistics::updateStatistics (
|
||||||
unsigned nIters, double time)
|
unsigned size,
|
||||||
|
bool loopy,
|
||||||
|
unsigned nIters,
|
||||||
|
double time)
|
||||||
{
|
{
|
||||||
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
|
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
|
||||||
}
|
}
|
||||||
@ -318,7 +383,8 @@ Statistics::writeStatisticsToFile (const char* fileName)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Statistics::updateCompressingStatistics (unsigned nGroundVars,
|
Statistics::updateCompressingStatistics (
|
||||||
|
unsigned nGroundVars,
|
||||||
unsigned nGroundFactors,
|
unsigned nGroundFactors,
|
||||||
unsigned nClusterVars,
|
unsigned nClusterVars,
|
||||||
unsigned nClusterFactors,
|
unsigned nClusterFactors,
|
||||||
@ -334,7 +400,7 @@ Statistics::getStatisticString (void)
|
|||||||
{
|
{
|
||||||
stringstream ss2, ss3, ss4, ss1;
|
stringstream ss2, ss3, ss4, ss1;
|
||||||
ss1 << "running mode: " ;
|
ss1 << "running mode: " ;
|
||||||
switch (InfAlgorithms::infAlgorithm) {
|
switch (Globals::infAlgorithm) {
|
||||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||||
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
|
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
|
||||||
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
|
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
|
||||||
@ -342,18 +408,23 @@ Statistics::getStatisticString (void)
|
|||||||
}
|
}
|
||||||
ss1 << "message schedule: " ;
|
ss1 << "message schedule: " ;
|
||||||
switch (BpOptions::schedule) {
|
switch (BpOptions::schedule) {
|
||||||
case BpOptions::Schedule::SEQ_FIXED: ss1 << "sequential fixed" << endl; break;
|
case BpOptions::Schedule::SEQ_FIXED:
|
||||||
case BpOptions::Schedule::SEQ_RANDOM: ss1 << "sequential random" << endl; break;
|
ss1 << "sequential fixed" << endl;
|
||||||
case BpOptions::Schedule::PARALLEL: ss1 << "parallel" << endl; break;
|
break;
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL: ss1 << "max residual" << endl; break;
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
|
ss1 << "sequential random" << endl;
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::PARALLEL:
|
||||||
|
ss1 << "parallel" << endl;
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||||
|
ss1 << "max residual" << endl;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
ss1 << "max iterations: " << BpOptions::maxIter << endl;
|
ss1 << "max iterations: " << BpOptions::maxIter << endl;
|
||||||
ss1 << "accuracy " << BpOptions::accuracy << endl;
|
ss1 << "accuracy " << BpOptions::accuracy << endl;
|
||||||
ss1 << endl << endl;
|
ss1 << endl << endl;
|
||||||
|
Util::printSubHeader ("Network information", ss2);
|
||||||
ss2 << "---------------------------------------------------" << endl;
|
|
||||||
ss2 << " Network information" << endl;
|
|
||||||
ss2 << "---------------------------------------------------" << endl;
|
|
||||||
ss2 << left;
|
ss2 << left;
|
||||||
ss2 << setw (15) << "Network Size" ;
|
ss2 << setw (15) << "Network Size" ;
|
||||||
ss2 << setw (9) << "Loopy" ;
|
ss2 << setw (9) << "Loopy" ;
|
||||||
@ -387,9 +458,7 @@ Statistics::getStatisticString (void)
|
|||||||
|
|
||||||
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
|
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
|
||||||
if (compressInfo_.size() > 0) {
|
if (compressInfo_.size() > 0) {
|
||||||
ss3 << "---------------------------------------------------" << endl;
|
Util::printSubHeader ("Compress information", ss3);
|
||||||
ss3 << " Compression information" << endl;
|
|
||||||
ss3 << "---------------------------------------------------" << endl;
|
|
||||||
ss3 << left;
|
ss3 << left;
|
||||||
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
||||||
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
||||||
|
@ -1,53 +1,131 @@
|
|||||||
#ifndef HORUS_UTIL_H
|
#ifndef HORUS_UTIL_H
|
||||||
#define HORUS_UTIL_H
|
#define HORUS_UTIL_H
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cassert>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
#include <queue>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
namespace Util {
|
namespace Util {
|
||||||
|
|
||||||
|
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||||
|
|
||||||
|
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||||
|
|
||||||
|
template <typename T> bool contains (const vector<T>&, const T&);
|
||||||
|
|
||||||
|
template <typename T> bool contains (const set<T>&, const T&);
|
||||||
|
|
||||||
|
template <typename K, typename V> bool contains (
|
||||||
|
const unordered_map<K, V>&, const K&);
|
||||||
|
|
||||||
|
template <typename T> std::string toString (const T&);
|
||||||
|
|
||||||
void toLog (Params&);
|
void toLog (Params&);
|
||||||
|
|
||||||
void fromLog (Params&);
|
void fromLog (Params&);
|
||||||
void normalize (Params&);
|
|
||||||
void logSum (double&, double);
|
double logSum (double, double);
|
||||||
|
|
||||||
void multiply (Params&, const Params&);
|
void multiply (Params&, const Params&);
|
||||||
|
|
||||||
void multiply (Params&, const Params&, unsigned);
|
void multiply (Params&, const Params&, unsigned);
|
||||||
|
|
||||||
void add (Params&, const Params&);
|
void add (Params&, const Params&);
|
||||||
|
|
||||||
void add (Params&, const Params&, unsigned);
|
void add (Params&, const Params&, unsigned);
|
||||||
void pow (Params&, double);
|
|
||||||
void pow (Params&, unsigned);
|
|
||||||
double pow (double, unsigned);
|
|
||||||
double factorial (double);
|
double factorial (double);
|
||||||
|
|
||||||
unsigned nrCombinations (unsigned, unsigned);
|
unsigned nrCombinations (unsigned, unsigned);
|
||||||
double getL1Distance (const Params&, const Params&);
|
|
||||||
double getMaxNorm (const Params&, const Params&);
|
unsigned expectedSize (const Ranges&);
|
||||||
|
|
||||||
unsigned getNumberOfDigits (int);
|
unsigned getNumberOfDigits (int);
|
||||||
|
|
||||||
bool isInteger (const string&);
|
bool isInteger (const string&);
|
||||||
string parametersToString (const Params&, unsigned = PRECISION);
|
|
||||||
|
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||||
|
|
||||||
vector<string> getJointStateStrings (const VarNodes&);
|
vector<string> getJointStateStrings (const VarNodes&);
|
||||||
double tl (double);
|
|
||||||
double fl (double);
|
void printHeader (string, std::ostream& os = std::cout);
|
||||||
double multIdenty();
|
|
||||||
double addIdenty();
|
void printSubHeader (string, std::ostream& os = std::cout);
|
||||||
double withEvidence();
|
|
||||||
double noEvidence();
|
void printAsteriskLine (std::ostream& os = std::cout);
|
||||||
double one();
|
|
||||||
double zero();
|
void printDashedLine (std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
unsigned maxUnsigned (void);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
std::string toString (const T& t)
|
template <typename T> void
|
||||||
|
Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||||
|
{
|
||||||
|
v.insert (v.end(), elements.begin(), elements.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void
|
||||||
|
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < elements.size(); i++) {
|
||||||
|
q.push (elements[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> bool
|
||||||
|
Util::contains (const vector<T>& v, const T& e)
|
||||||
|
{
|
||||||
|
return std::find (v.begin(), v.end(), e) != v.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> bool
|
||||||
|
Util::contains (const set<T>& s, const T& e)
|
||||||
|
{
|
||||||
|
return s.find (e) != s.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename K, typename V> bool
|
||||||
|
Util::contains (
|
||||||
|
const unordered_map<K, V>& m, const K& k)
|
||||||
|
{
|
||||||
|
return m.find (k) != m.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> std::string
|
||||||
|
Util::toString (const T& t)
|
||||||
{
|
{
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << t;
|
ss << t;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -62,28 +140,31 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
const double INF = -numeric_limits<double>::infinity();
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
inline void
|
inline double
|
||||||
Util::logSum (double& x, double y)
|
Util::logSum (double x, double y)
|
||||||
{
|
{
|
||||||
x = log (exp (x) + exp (y)); return;
|
return log (exp (x) + exp (y));
|
||||||
assert (isfinite (x) && isfinite (y));
|
assert (isfinite (x) && isfinite (y));
|
||||||
// If one value is much smaller than the other, keep the larger value.
|
// If one value is much smaller than the other, keep the larger value.
|
||||||
if (x < (y - log (1e200))) {
|
if (x < (y - log (1e200))) {
|
||||||
x = y;
|
return y;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
if (y < (x - log (1e200))) {
|
if (y < (x - log (1e200))) {
|
||||||
return;
|
return x;
|
||||||
}
|
}
|
||||||
double diff = x - y;
|
double diff = x - y;
|
||||||
assert (isfinite (diff) && isfinite (x) && isfinite (y));
|
assert (isfinite (diff) && isfinite (x) && isfinite (y));
|
||||||
if (!isfinite (exp (diff))) { // difference is too large
|
if (!isfinite (exp (diff))) {
|
||||||
x = x > y ? x : y;
|
// difference is too large
|
||||||
} else { // otherwise return the sum.
|
return x > y ? x : y;
|
||||||
x = y + log (static_cast<double>(1.0) + exp (diff));
|
|
||||||
}
|
}
|
||||||
|
// otherwise return the sum.
|
||||||
|
return y + log (static_cast<double>(1.0) + exp (diff));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -140,52 +221,87 @@ Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
inline unsigned
|
||||||
Util::tl (double v)
|
Util::maxUnsigned (void)
|
||||||
{
|
{
|
||||||
return Globals::logDomain ? log(v) : v;
|
return numeric_limits<unsigned>::max();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::fl (double v)
|
|
||||||
{
|
namespace LogAware {
|
||||||
return Globals::logDomain ? exp(v) : v;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
inline double
|
||||||
Util::multIdenty() {
|
one()
|
||||||
return Globals::logDomain ? 0.0 : 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::addIdenty()
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? INF : 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::withEvidence()
|
|
||||||
{
|
{
|
||||||
return Globals::logDomain ? 0.0 : 1.0;
|
return Globals::logDomain ? 0.0 : 1.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::noEvidence() {
|
|
||||||
return Globals::logDomain ? INF : 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
inline double
|
||||||
Util::one()
|
zero() {
|
||||||
{
|
|
||||||
return Globals::logDomain ? 0.0 : 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::zero() {
|
|
||||||
return Globals::logDomain ? INF : 0.0 ;
|
return Globals::logDomain ? INF : 0.0 ;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
addIdenty()
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? INF : 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
multIdenty()
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? 0.0 : 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
withEvidence()
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? 0.0 : 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
noEvidence() {
|
||||||
|
return Globals::logDomain ? INF : 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
tl (double v)
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? log (v) : v;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
fl (double v)
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? exp (v) : v;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void normalize (Params&);
|
||||||
|
|
||||||
|
double getL1Distance (const Params&, const Params&);
|
||||||
|
|
||||||
|
double getMaxNorm (const Params&, const Params&);
|
||||||
|
|
||||||
|
double pow (double, unsigned);
|
||||||
|
|
||||||
|
double pow (double, double);
|
||||||
|
|
||||||
|
void pow (Params&, unsigned);
|
||||||
|
|
||||||
|
void pow (Params&, double);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
struct NetInfo
|
struct NetInfo
|
||||||
{
|
{
|
||||||
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
|
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
|
||||||
@ -224,11 +340,17 @@ class Statistics
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
static unsigned getSolvedNetworksCounting (void);
|
static unsigned getSolvedNetworksCounting (void);
|
||||||
|
|
||||||
static void incrementPrimaryNetworksCounting (void);
|
static void incrementPrimaryNetworksCounting (void);
|
||||||
|
|
||||||
static unsigned getPrimaryNetworksCounting (void);
|
static unsigned getPrimaryNetworksCounting (void);
|
||||||
|
|
||||||
static void updateStatistics (unsigned, bool, unsigned, double);
|
static void updateStatistics (unsigned, bool, unsigned, double);
|
||||||
|
|
||||||
static void printStatistics (void);
|
static void printStatistics (void);
|
||||||
|
|
||||||
static void writeStatisticsToFile (const char*);
|
static void writeStatisticsToFile (const char*);
|
||||||
|
|
||||||
static void updateCompressingStatistics (
|
static void updateCompressingStatistics (
|
||||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
unsigned, unsigned, unsigned, unsigned, unsigned);
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
|||||||
introduceEvidence();
|
introduceEvidence();
|
||||||
chooseEliminationOrder (vids);
|
chooseEliminationOrder (vids);
|
||||||
processFactorList (vids);
|
processFactorList (vids);
|
||||||
Params params = factorList_.back()->getParameters();
|
Params params = factorList_.back()->params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
@ -98,7 +98,7 @@ VarElimSolver::introduceEvidence (void)
|
|||||||
varFactors_.find (varNodes[i]->varId())->second;
|
varFactors_.find (varNodes[i]->varId())->second;
|
||||||
for (unsigned j = 0; j < idxs.size(); j++) {
|
for (unsigned j = 0; j < idxs.size(); j++) {
|
||||||
Factor* factor = factorList_[idxs[j]];
|
Factor* factor = factorList_[idxs[j]];
|
||||||
if (factor->nrVariables() == 1) {
|
if (factor->nrArguments() == 1) {
|
||||||
factorList_[idxs[j]] = 0;
|
factorList_[idxs[j]] = 0;
|
||||||
} else {
|
} else {
|
||||||
factorList_[idxs[j]]->absorveEvidence (
|
factorList_[idxs[j]]->absorveEvidence (
|
||||||
@ -121,8 +121,8 @@ VarElimSolver::chooseEliminationOrder (const VarIds& vids)
|
|||||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
VarId vid = varNodes[i]->varId();
|
VarId vid = varNodes[i]->varId();
|
||||||
if (std::find (vids.begin(), vids.end(), vid) == vids.end()
|
if (Util::contains (vids, vid) == false &&
|
||||||
&& !varNodes[i]->hasEvidence()) {
|
varNodes[i]->hasEvidence() == false) {
|
||||||
elimOrder_.push_back (vid);
|
elimOrder_.push_back (vid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,7 +154,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
finalFactor->reorderVariables (unobservedVids);
|
finalFactor->reorderArguments (unobservedVids);
|
||||||
finalFactor->normalize();
|
finalFactor->normalize();
|
||||||
factorList_.push_back (finalFactor);
|
factorList_.push_back (finalFactor);
|
||||||
}
|
}
|
||||||
@ -179,10 +179,10 @@ VarElimSolver::eliminate (VarId elimVar)
|
|||||||
factorList_[idx] = 0;
|
factorList_[idx] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (result != 0 && result->nrVariables() != 1) {
|
if (result != 0 && result->nrArguments() != 1) {
|
||||||
result->sumOut (vn->varId());
|
result->sumOut (vn->varId());
|
||||||
factorList_.push_back (result);
|
factorList_.push_back (result);
|
||||||
const VarIds& resultVarIds = result->getVarIds();
|
const VarIds& resultVarIds = result->arguments();
|
||||||
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
||||||
vector<unsigned>& idxs =
|
vector<unsigned>& idxs =
|
||||||
varFactors_.find (resultVarIds[i])->second;
|
varFactors_.find (resultVarIds[i])->second;
|
||||||
|
@ -16,18 +16,28 @@ class VarElimSolver : public Solver
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarElimSolver (const BayesNet&);
|
VarElimSolver (const BayesNet&);
|
||||||
|
|
||||||
VarElimSolver (const FactorGraph&);
|
VarElimSolver (const FactorGraph&);
|
||||||
|
|
||||||
~VarElimSolver (void);
|
~VarElimSolver (void);
|
||||||
|
|
||||||
void runSolver (void) { }
|
void runSolver (void) { }
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void createFactorList (void);
|
void createFactorList (void);
|
||||||
|
|
||||||
void introduceEvidence (void);
|
void introduceEvidence (void);
|
||||||
|
|
||||||
void chooseEliminationOrder (const VarIds&);
|
void chooseEliminationOrder (const VarIds&);
|
||||||
|
|
||||||
void processFactorList (const VarIds&);
|
void processFactorList (const VarIds&);
|
||||||
|
|
||||||
void eliminate (VarId);
|
void eliminate (VarId);
|
||||||
|
|
||||||
void printActiveFactors (void);
|
void printActiveFactors (void);
|
||||||
|
|
||||||
const BayesNet* bayesNet_;
|
const BayesNet* bayesNet_;
|
||||||
|
@ -40,8 +40,8 @@ VarNode::isValidState (int stateIndex)
|
|||||||
bool
|
bool
|
||||||
VarNode::isValidState (const string& stateName)
|
VarNode::isValidState (const string& stateName)
|
||||||
{
|
{
|
||||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
States states = GraphicalModel::getVarInformation (varId_).states;
|
||||||
return find (states.begin(), states.end(), stateName) != states.end();
|
return Util::contains (states, stateName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ VarNode::setEvidence (int ev)
|
|||||||
void
|
void
|
||||||
VarNode::setEvidence (const string& ev)
|
VarNode::setEvidence (const string& ev)
|
||||||
{
|
{
|
||||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
States states = GraphicalModel::getVarInformation (varId_).states;
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
for (unsigned i = 0; i < states.size(); i++) {
|
||||||
if (states[i] == ev) {
|
if (states[i] == ev) {
|
||||||
evidence_ = i;
|
evidence_ = i;
|
||||||
@ -74,7 +74,7 @@ string
|
|||||||
VarNode::label (void) const
|
VarNode::label (void) const
|
||||||
{
|
{
|
||||||
if (GraphicalModel::variablesHaveInformation()) {
|
if (GraphicalModel::variablesHaveInformation()) {
|
||||||
return GraphicalModel::getVariableInformation (varId_).label;
|
return GraphicalModel::getVarInformation (varId_).label;
|
||||||
}
|
}
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "x" << varId_;
|
ss << "x" << varId_;
|
||||||
@ -87,7 +87,7 @@ States
|
|||||||
VarNode::states (void) const
|
VarNode::states (void) const
|
||||||
{
|
{
|
||||||
if (GraphicalModel::variablesHaveInformation()) {
|
if (GraphicalModel::variablesHaveInformation()) {
|
||||||
return GraphicalModel::getVariableInformation (varId_).states;
|
return GraphicalModel::getVarInformation (varId_).states;
|
||||||
}
|
}
|
||||||
States states;
|
States states;
|
||||||
for (unsigned i = 0; i < nrStates_; i++) {
|
for (unsigned i = 0; i < nrStates_; i++) {
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
#ifndef HORUS_VARNODE_H
|
#ifndef HORUS_VARNODE_H
|
||||||
#define HORUS_VARNODE_H
|
#define HORUS_VARNODE_H
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -9,25 +13,28 @@ class VarNode
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarNode (const VarNode*);
|
VarNode (const VarNode*);
|
||||||
VarNode (VarId, unsigned, int = NO_EVIDENCE);
|
|
||||||
virtual ~VarNode (void) {};
|
|
||||||
|
|
||||||
bool isValidState (int);
|
VarNode (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||||
bool isValidState (const string&);
|
|
||||||
void setEvidence (int);
|
virtual ~VarNode (void) { };
|
||||||
void setEvidence (const string&);
|
|
||||||
string label (void) const;
|
|
||||||
States states (void) const;
|
|
||||||
|
|
||||||
unsigned varId (void) const { return varId_; }
|
unsigned varId (void) const { return varId_; }
|
||||||
|
|
||||||
unsigned nrStates (void) const { return nrStates_; }
|
unsigned nrStates (void) const { return nrStates_; }
|
||||||
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
|
|
||||||
int getEvidence (void) const { return evidence_; }
|
int getEvidence (void) const { return evidence_; }
|
||||||
|
|
||||||
unsigned getIndex (void) const { return index_; }
|
unsigned getIndex (void) const { return index_; }
|
||||||
|
|
||||||
void setIndex (unsigned idx) { index_ = idx; }
|
void setIndex (unsigned idx) { index_ = idx; }
|
||||||
|
|
||||||
operator unsigned () const { return index_; }
|
operator unsigned () const { return index_; }
|
||||||
|
|
||||||
|
bool hasEvidence (void) const
|
||||||
|
{
|
||||||
|
return evidence_ != Constants::NO_EVIDENCE;
|
||||||
|
}
|
||||||
|
|
||||||
bool operator== (const VarNode& var) const
|
bool operator== (const VarNode& var) const
|
||||||
{
|
{
|
||||||
cout << "equal operator called" << endl;
|
cout << "equal operator called" << endl;
|
||||||
@ -42,6 +49,18 @@ class VarNode
|
|||||||
return varId_ != var.varId();
|
return varId_ != var.varId();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isValidState (int);
|
||||||
|
|
||||||
|
bool isValidState (const string&);
|
||||||
|
|
||||||
|
void setEvidence (int);
|
||||||
|
|
||||||
|
void setEvidence (const string&);
|
||||||
|
|
||||||
|
string label (void) const;
|
||||||
|
|
||||||
|
States states (void) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
VarId varId_;
|
VarId varId_;
|
||||||
unsigned nrStates_;
|
unsigned nrStates_;
|
||||||
|
@ -13,8 +13,8 @@ function run_solver
|
|||||||
{
|
{
|
||||||
if [ $2 = bp ]
|
if [ $2 = bp ]
|
||||||
then
|
then
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||||
else
|
else
|
||||||
extra_flag1=true
|
extra_flag1=true
|
||||||
extra_flag2=true
|
extra_flag2=true
|
||||||
@ -22,7 +22,7 @@ fi
|
|||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||||
[$1].
|
[$1].
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
clpbn:set_clpbn_flag(solver,$2),
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
clpbn_horus:set_horus_flag(use_logarithms, true),
|
||||||
$extra_flag1, $extra_flag2,
|
$extra_flag1, $extra_flag2,
|
||||||
run_query(_R),
|
run_query(_R),
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
open("$OUT_FILE_NAME", 'append',S),
|
||||||
|
@ -13,8 +13,8 @@ function run_solver
|
|||||||
{
|
{
|
||||||
if [ $2 = bp ]
|
if [ $2 = bp ]
|
||||||
then
|
then
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||||
else
|
else
|
||||||
extra_flag1=true
|
extra_flag1=true
|
||||||
extra_flag2=true
|
extra_flag2=true
|
||||||
@ -22,7 +22,7 @@ fi
|
|||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||||
[$1].
|
[$1].
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
clpbn:set_clpbn_flag(solver,$2),
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
clpbn_horus:set_horus_flag(use_logarithms, true),
|
||||||
$extra_flag1, $extra_flag2,
|
$extra_flag1, $extra_flag2,
|
||||||
run_query(_R),
|
run_query(_R),
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
open("$OUT_FILE_NAME", 'append',S),
|
||||||
@ -37,6 +37,8 @@ function run_all_graphs
|
|||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||||
|
run_solver town_3 $1 town_3 $3 $4 $5
|
||||||
|
return
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||||
|
@ -13,8 +13,8 @@ function run_solver
|
|||||||
{
|
{
|
||||||
if [ $2 = bp ]
|
if [ $2 = bp ]
|
||||||
then
|
then
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||||
else
|
else
|
||||||
extra_flag1=true
|
extra_flag1=true
|
||||||
extra_flag2=true
|
extra_flag2=true
|
||||||
@ -22,7 +22,7 @@ fi
|
|||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||||
[$1].
|
[$1].
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
clpbn:set_clpbn_flag(solver,$2),
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
clpbn_horus:set_horus_flag(use_logarithms, true),
|
||||||
$extra_flag1, $extra_flag2,
|
$extra_flag1, $extra_flag2,
|
||||||
run_query(_R),
|
run_query(_R),
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
open("$OUT_FILE_NAME", 'append',S),
|
||||||
|
29
packages/CLPBN/clpbn/bp/benchmarks/city/town_3.yap
Normal file
29
packages/CLPBN/clpbn/bp/benchmarks/city/town_3.yap
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
:- source.
|
||||||
|
:- style_check(all).
|
||||||
|
:- yap_flag(unknown,error).
|
||||||
|
:- yap_flag(write_strings,on).
|
||||||
|
:- use_module(library(clpbn)).
|
||||||
|
:- set_clpbn_flag(solver, bp).
|
||||||
|
:- [-schema].
|
||||||
|
|
||||||
|
lives(_joe, nyc).
|
||||||
|
|
||||||
|
run_query(Guilty) :-
|
||||||
|
guilty(joe, Guilty),
|
||||||
|
witness(nyc, t),
|
||||||
|
runall(X, ev(X)).
|
||||||
|
|
||||||
|
|
||||||
|
runall(G, Wrapper) :-
|
||||||
|
findall(G, Wrapper, L),
|
||||||
|
execute_all(L).
|
||||||
|
|
||||||
|
|
||||||
|
execute_all([]).
|
||||||
|
execute_all(G.L) :-
|
||||||
|
call(G),
|
||||||
|
execute_all(L).
|
||||||
|
|
||||||
|
|
||||||
|
ev(descn(p2, t)).
|
||||||
|
ev(descn(p3, t)).
|
@ -9,7 +9,7 @@ OUT_FILE_NAME=results.log
|
|||||||
rm -f $OUT_FILE_NAME
|
rm -f $OUT_FILE_NAME
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
rm -f ignore.$OUT_FILE_NAME
|
||||||
|
|
||||||
# yap -g "['../../../../examples/School/school_32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
||||||
|
|
||||||
function run_solver
|
function run_solver
|
||||||
{
|
{
|
||||||
@ -17,11 +17,11 @@ if [ $2 = bp ]
|
|||||||
then
|
then
|
||||||
if [ $4 = ve ]
|
if [ $4 = ve ]
|
||||||
then
|
then
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(elim_heuristic,$5\)
|
extra_flag2=clpbn_horus:set_horus_flag\(elim_heuristic,$5\)
|
||||||
else
|
else
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
extra_flag1=true
|
extra_flag1=true
|
||||||
@ -29,7 +29,7 @@ else
|
|||||||
fi
|
fi
|
||||||
/usr/bin/time -o "$OUT_FILE_NAME" -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF &>> "ignore.$OUT_FILE_NAME"
|
/usr/bin/time -o "$OUT_FILE_NAME" -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF &>> "ignore.$OUT_FILE_NAME"
|
||||||
:- [pos:train].
|
:- [pos:train].
|
||||||
:- ['../../../../examples/School/school_32'].
|
:- ['../../../../examples/School/sch32'].
|
||||||
:- use_module(library(clpbn/learning/em)).
|
:- use_module(library(clpbn/learning/em)).
|
||||||
:- use_module(library(clpbn/bp)).
|
:- use_module(library(clpbn/bp)).
|
||||||
[$1].
|
[$1].
|
||||||
@ -57,11 +57,10 @@ function run_all_graphs
|
|||||||
#run_solver missing50 $1 missing50 $3 $4 $5
|
#run_solver missing50 $1 missing50 $3 $4 $5
|
||||||
}
|
}
|
||||||
|
|
||||||
run_solver missing5 ve missing5 $3 $4 $5
|
|
||||||
exit
|
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||||
run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
|
||||||
#run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
#run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
||||||
#run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
||||||
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||||
exit
|
exit
|
||||||
|
|
||||||
|
18
packages/CLPBN/clpbn/bp/examples/allopstest.yap
Normal file
18
packages/CLPBN/clpbn/bp/examples/allopstest.yap
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
|
||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
:- set_clpbn_flag(solver,fove).
|
||||||
|
|
||||||
|
|
||||||
|
c(x1,y1,z1).
|
||||||
|
c(x1,y1,z2).
|
||||||
|
c(x2,y2,z1).
|
||||||
|
c(x3,y2,z1).
|
||||||
|
|
||||||
|
bayes p(X)::[t,f] ; [0.2, 0.4] ; [c(X,_,_)].
|
||||||
|
|
||||||
|
bayes q(Y)::[t,f] ; [0.5, 0.6] ; [c(_,Y,_)].
|
||||||
|
|
||||||
|
bayes s(Z)::[t,f] , p(X) , q(Y) ; [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] ; [c(X,Y,Z)].
|
||||||
|
|
||||||
|
% bayes series::[t,f] , attends(X) ; [0.5, 0.6, 0.7, 0.8] ; [c(X,_)].
|
78022
packages/CLPBN/clpbn/bp/examples/city_test.fg
Normal file
78022
packages/CLPBN/clpbn/bp/examples/city_test.fg
Normal file
File diff suppressed because it is too large
Load Diff
@ -25,6 +25,8 @@ markov attends(P)::[t,f] , hot(W)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P,W)].
|
|||||||
|
|
||||||
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
|
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
|
||||||
|
|
||||||
|
:- clpbn_horus:set_horus_flag(use_logarithms,true).
|
||||||
|
|
||||||
?- series(X).
|
?- series(X).
|
||||||
|
|
||||||
|
|
||||||
|
20
packages/CLPBN/clpbn/bp/examples/fail.yap
Normal file
20
packages/CLPBN/clpbn/bp/examples/fail.yap
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
|
||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
:- set_clpbn_flag(solver,ve).
|
||||||
|
%:- set_clpbn_flag(solver,fove).
|
||||||
|
|
||||||
|
|
||||||
|
t(ann).
|
||||||
|
t(dave).
|
||||||
|
|
||||||
|
% p(ann,t).
|
||||||
|
|
||||||
|
bayes p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
|
||||||
|
|
||||||
|
% use standard Prolog queries: provide evidence first.
|
||||||
|
|
||||||
|
?- p(ann,t), p(ann,X).
|
||||||
|
|
||||||
|
% ?- p(ann,X).
|
||||||
|
|
32
packages/CLPBN/clpbn/bp/examples/workshop_attrs.yap
Normal file
32
packages/CLPBN/clpbn/bp/examples/workshop_attrs.yap
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
|
||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
%:- set_clpbn_flag(solver,ve).
|
||||||
|
%:- set_clpbn_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
|
||||||
|
:- set_clpbn_flag(solver,fove).
|
||||||
|
|
||||||
|
c(p1).
|
||||||
|
c(p2).
|
||||||
|
c(p3).
|
||||||
|
c(p4).
|
||||||
|
c(p5).
|
||||||
|
|
||||||
|
|
||||||
|
markov attends(P)::[t,f] , attr1::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f] , attr2::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f] , attr3::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f] , attr4::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f] , attr5::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f] , attr6::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P)].
|
||||||
|
|
||||||
|
%:- clpbn_horus:set_horus_flag(use_logarithms,true).
|
||||||
|
|
||||||
|
?- series(X).
|
||||||
|
|
@ -1,9 +1,9 @@
|
|||||||
|
|
||||||
/************************************************
|
/*******************************************************
|
||||||
|
|
||||||
(GC) First Order Variable Elimination Interface
|
First Order Variable Elimination Interface
|
||||||
|
|
||||||
**************************************************/
|
********************************************************/
|
||||||
|
|
||||||
:- module(clpbn_fove,
|
:- module(clpbn_fove,
|
||||||
[fove/3,
|
[fove/3,
|
||||||
@ -23,37 +23,40 @@
|
|||||||
|
|
||||||
|
|
||||||
:- use_module(library(pfl),
|
:- use_module(library(pfl),
|
||||||
[factor/6,
|
[factor/5,
|
||||||
skolem/2
|
skolem/2,
|
||||||
|
get_pfl_parameters/2
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(horus).
|
:- use_module(horus,
|
||||||
|
[create_lifted_network/3,
|
||||||
|
set_parfactors_params/2,
|
||||||
|
run_lifted_solver/3,
|
||||||
|
free_parfactors/1
|
||||||
|
]).
|
||||||
|
|
||||||
:- set_horus_flag(use_logarithms, false).
|
|
||||||
%:- set_horus_flag(use_logarithms, true).
|
|
||||||
|
|
||||||
fove([[]], _, _) :- !.
|
fove([[]], _, _) :- !.
|
||||||
fove([QueryVars], AllVars, Output) :-
|
fove([QueryVars], AllVars, Output) :-
|
||||||
writeln(queryVars:QueryVars),
|
init_fove_solver(_, AllVars, _, ParfactorList),
|
||||||
writeln(allVars:AllVars),
|
run_fove_solver([QueryVars], LPs, ParfactorList),
|
||||||
init_fove_solver(_, AllVars, _, ParfactorGraph),
|
finalize_fove_solver(ParfactorList),
|
||||||
run_fove_solver([QueryVars], LPs, ParfactorGraph),
|
|
||||||
finalize_fove_solver(ParfactorGraph),
|
|
||||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||||
|
|
||||||
init_fove_solver(_, AllAttVars, _, fove(ParfactorGraph, DistIds)) :-
|
|
||||||
writeln(allattvars:AllAttVars), writeln(''),
|
init_fove_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
|
||||||
get_parfactors(Parfactors),
|
get_parfactors(Parfactors),
|
||||||
get_dist_ids(Parfactors, DistIds0),
|
get_dist_ids(Parfactors, DistIds0),
|
||||||
sort(DistIds0, DistIds),
|
sort(DistIds0, DistIds),
|
||||||
get_observed_vars(AllAttVars, ObservedVars),
|
get_observed_vars(AllAttVars, ObservedVars),
|
||||||
writeln(factors:Parfactors:'\n'),
|
writeln(factors:Parfactors:'\n'),
|
||||||
writeln(evidence:ObservedVars:'\n'),
|
writeln(evidence:ObservedVars:'\n'),
|
||||||
create_lifted_network(Parfactors,ObservedVars,ParfactorGraph).
|
create_lifted_network(Parfactors,ObservedVars,ParfactorList).
|
||||||
|
|
||||||
|
|
||||||
:- table get_parfactors/1.
|
:- table get_parfactors/1.
|
||||||
|
|
||||||
%
|
%
|
||||||
% enumerate all parfactors and enumerate their domain as tuples.
|
% enumerate all parfactors and enumerate their domain as tuples.
|
||||||
%
|
%
|
||||||
@ -62,17 +65,26 @@ init_fove_solver(_, AllAttVars, _, fove(ParfactorGraph, DistIds)) :-
|
|||||||
% Ks: a list of keys, also known as the pf formula [a(X),b(Y),c(X,Y)]
|
% Ks: a list of keys, also known as the pf formula [a(X),b(Y),c(X,Y)]
|
||||||
% Vs: the list of free variables [X,Y]
|
% Vs: the list of free variables [X,Y]
|
||||||
% Phi: the table following usual CLP(BN) convention
|
% Phi: the table following usual CLP(BN) convention
|
||||||
% Tuples: tuples with all ground bindings for variables in Vs, of the form [fv(x,y)]
|
% Tuples: ground bindings for variables in Vs, of the form [fv(x,y)]
|
||||||
%
|
%
|
||||||
get_parfactors(Factors) :-
|
get_parfactors(Factors) :-
|
||||||
findall(F, is_factor(F), Factors).
|
findall(F, is_factor(F), Factors).
|
||||||
|
|
||||||
|
|
||||||
is_factor(pf(Id, Ks, Rs, Phi, Tuples)) :-
|
is_factor(pf(Id, Ks, Rs, Phi, Tuples)) :-
|
||||||
|
<<<<<<< HEAD
|
||||||
factor(_Type, Id, Ks, Vs, Table, Constraints),
|
factor(_Type, Id, Ks, Vs, Table, Constraints),
|
||||||
get_ranges(Ks,Rs),
|
get_ranges(Ks,Rs),
|
||||||
Table \= avg,
|
Table \= avg,
|
||||||
gen_table(Table, Phi),
|
gen_table(Table, Phi),
|
||||||
all_tuples(Constraints, Vs, Tuples).
|
all_tuples(Constraints, Vs, Tuples).
|
||||||
|
=======
|
||||||
|
factor(Id, Ks, Vs, Table, Constraints),
|
||||||
|
get_ranges(Ks,Rs),
|
||||||
|
Table \= avg,
|
||||||
|
gen_table(Table, Phi),
|
||||||
|
all_tuples(Constraints, Vs, Tuples).
|
||||||
|
>>>>>>> 911b241ad663a911af52babcf5d702c5239b4350
|
||||||
|
|
||||||
|
|
||||||
get_ranges([],[]).
|
get_ranges([],[]).
|
||||||
@ -90,6 +102,7 @@ gen_table(Table, Phi) :-
|
|||||||
call(user:Table, Phi)
|
call(user:Table, Phi)
|
||||||
).
|
).
|
||||||
|
|
||||||
|
|
||||||
all_tuples(Constraints, Tuple, Tuples) :-
|
all_tuples(Constraints, Tuple, Tuples) :-
|
||||||
setof(Tuple, Constraints^run(Constraints), Tuples).
|
setof(Tuple, Constraints^run(Constraints), Tuples).
|
||||||
|
|
||||||
@ -107,14 +120,11 @@ get_dist_ids(pf(Id, _, _, _, _).Parfactors, Id.DistIds) :-
|
|||||||
|
|
||||||
get_observed_vars([], []).
|
get_observed_vars([], []).
|
||||||
get_observed_vars(V.AllAttVars, [K:E|ObservedVars]) :-
|
get_observed_vars(V.AllAttVars, [K:E|ObservedVars]) :-
|
||||||
writeln('checking ev for':V),
|
|
||||||
clpbn:get_atts(V,[key(K)]),
|
clpbn:get_atts(V,[key(K)]),
|
||||||
( clpbn:get_atts(V,[evidence(E)]) ; pfl:evidence(K,E) ), !,
|
( clpbn:get_atts(V,[evidence(E)]) ; pfl:evidence(K,E) ), !,
|
||||||
writeln('evidence!!!':K:E),
|
|
||||||
get_observed_vars(AllAttVars, ObservedVars).
|
get_observed_vars(AllAttVars, ObservedVars).
|
||||||
get_observed_vars(V.AllAttVars, ObservedVars) :-
|
get_observed_vars(V.AllAttVars, ObservedVars) :-
|
||||||
clpbn:get_atts(V,[key(K)]), !,
|
clpbn:get_atts(V,[key(K)]), !,
|
||||||
writeln('no evidence for':V:K),
|
|
||||||
get_observed_vars(AllAttVars, ObservedVars).
|
get_observed_vars(AllAttVars, ObservedVars).
|
||||||
|
|
||||||
|
|
||||||
@ -136,16 +146,15 @@ get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
|||||||
get_dists_parameters(Ids, DistsInfo).
|
get_dists_parameters(Ids, DistsInfo).
|
||||||
|
|
||||||
|
|
||||||
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorGraph, DistIds)) :-
|
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
|
||||||
% TODO set_parfactor_graph_params
|
get_dists_parameters(DistIds, DistsParams),
|
||||||
writeln(distIds:DistIds),
|
writeln(distParams:DistsParams),
|
||||||
%get_dists_parameters(DistIds, DistParams),
|
set_parfactors_params(ParfactorList, DistsParams),
|
||||||
%writeln(distParams:DistParams),
|
|
||||||
get_query_vars(QueryVarsAtts, QueryVars),
|
get_query_vars(QueryVarsAtts, QueryVars),
|
||||||
writeln(queryVars:QueryVars),
|
writeln(queryVars:QueryVars), writeln(''),
|
||||||
run_lifted_solver(ParfactorGraph, QueryVars, Solutions).
|
run_lifted_solver(ParfactorList, QueryVars, Solutions).
|
||||||
|
|
||||||
|
|
||||||
finalize_fove_solver(fove(ParfactorGraph, _)) :-
|
finalize_fove_solver(fove(ParfactorList, _)) :-
|
||||||
free_parfactor_graph(ParfactorGraph).
|
free_parfactors(ParfactorList).
|
||||||
|
|
||||||
|
@ -1,18 +1,24 @@
|
|||||||
|
|
||||||
|
/*******************************************************
|
||||||
|
|
||||||
|
Interface with C++
|
||||||
|
|
||||||
|
********************************************************/
|
||||||
|
|
||||||
:- module(clpbn_horus,
|
:- module(clpbn_horus,
|
||||||
[
|
[create_lifted_network/3,
|
||||||
create_lifted_network/3,
|
|
||||||
create_ground_network/2,
|
create_ground_network/2,
|
||||||
set_parfactor_graph_params/2,
|
set_parfactors_params/2,
|
||||||
set_bayes_net_params/2,
|
set_bayes_net_params/2,
|
||||||
run_lifted_solver/3,
|
run_lifted_solver/3,
|
||||||
run_other_solvers/3,
|
run_ground_solver/3,
|
||||||
set_extra_vars_info/2,
|
set_extra_vars_info/2,
|
||||||
set_horus_flag/2,
|
set_horus_flag/2,
|
||||||
free_bayesian_network/1,
|
free_parfactors/1,
|
||||||
free_parfactor_graph/1
|
free_bayesian_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
|
||||||
patch_things_up :-
|
patch_things_up :-
|
||||||
assert_static(clpbn_horus:set_horus_flag(_,_)).
|
assert_static(clpbn_horus:set_horus_flag(_,_)).
|
||||||
|
|
||||||
@ -23,3 +29,24 @@ warning :-
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
%:- set_horus_flag(inf_alg, ve).
|
||||||
|
:- set_horus_flag(inf_alg, bn_bp).
|
||||||
|
%:- set_horus_flag(inf_alg, fg_bp).
|
||||||
|
%: -set_horus_flag(inf_alg, cbp).
|
||||||
|
|
||||||
|
:- set_horus_flag(schedule, seq_fixed).
|
||||||
|
%:- set_horus_flag(schedule, seq_random).
|
||||||
|
%:- set_horus_flag(schedule, parallel).
|
||||||
|
%:- set_horus_flag(schedule, max_residual).
|
||||||
|
|
||||||
|
:- set_horus_flag(accuracy, 0.0001).
|
||||||
|
|
||||||
|
:- set_horus_flag(max_iter, 1000).
|
||||||
|
|
||||||
|
:- set_horus_flag(order_factor_variables, false).
|
||||||
|
%:- set_horus_flag(order_factor_variables, true).
|
||||||
|
|
||||||
|
|
||||||
|
:- set_horus_flag(use_logarithms, false).
|
||||||
|
% :- set_horus_flag(use_logarithms, true).
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user