Merge branch 'master' of https://github.com/tacgomes/yap6.3
This commit is contained in:
commit
28ce2da3dc
@ -10,7 +10,7 @@
|
|||||||
check_if_bp_done/1,
|
check_if_bp_done/1,
|
||||||
init_bp_solver/4,
|
init_bp_solver/4,
|
||||||
run_bp_solver/3,
|
run_bp_solver/3,
|
||||||
call_bp_ground/5,
|
call_bp_ground/5,
|
||||||
finalize_bp_solver/1
|
finalize_bp_solver/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
@ -31,80 +31,109 @@
|
|||||||
[check_for_agg_vars/2]).
|
[check_for_agg_vars/2]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library(clpbn/horus)).
|
:- use_module(library(charsio),
|
||||||
|
[term_to_atom/2]).
|
||||||
|
|
||||||
|
|
||||||
|
:- use_module(library(pfl),
|
||||||
|
[skolem/2,
|
||||||
|
get_pfl_parameters/2
|
||||||
|
]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library(lists)).
|
:- use_module(library(lists)).
|
||||||
|
|
||||||
:- use_module(library(atts)).
|
:- use_module(library(atts)).
|
||||||
|
|
||||||
:- attribute id/1.
|
|
||||||
|
|
||||||
|
|
||||||
%:- set_horus_flag(inf_alg, ve).
|
|
||||||
:- set_horus_flag(inf_alg, bn_bp).
|
|
||||||
%:- set_horus_flag(inf_alg, fg_bp).
|
|
||||||
%: -set_horus_flag(inf_alg, cbp).
|
|
||||||
|
|
||||||
:- set_horus_flag(schedule, seq_fixed).
|
|
||||||
%:- set_horus_flag(schedule, seq_random).
|
|
||||||
%:- set_horus_flag(schedule, parallel).
|
|
||||||
%:- set_horus_flag(schedule, max_residual).
|
|
||||||
|
|
||||||
:- set_horus_flag(accuracy, 0.0001).
|
|
||||||
|
|
||||||
:- use_module(library(charsio),
|
|
||||||
[term_to_atom/2]).
|
|
||||||
|
|
||||||
:- use_module(library(bhash)).
|
:- use_module(library(bhash)).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(horus,
|
:- use_module(horus,
|
||||||
[create_ground_network/2,
|
[create_ground_network/4,
|
||||||
set_bayes_net_params/2,
|
set_factors_params/2,
|
||||||
run_ground_solver/3,
|
run_ground_solver/3,
|
||||||
set_extra_vars_info/2,
|
set_vars_information/2,
|
||||||
free_bayesian_network/1
|
free_ground_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
|
||||||
:- attribute id/1.
|
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
||||||
|
b_hash_new(Hash0),
|
||||||
|
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
||||||
|
get_factors_type(Factors, Type),
|
||||||
|
evidence_to_ids(Evidence, Hash, EvidenceIds),
|
||||||
|
factors_to_ids(Factors, Hash, FactorIds),
|
||||||
|
writeln(type:Type), writeln(''),
|
||||||
|
writeln(allKeys:AllKeys), writeln(''),
|
||||||
|
writeln(factors:Factors), writeln(''),
|
||||||
|
writeln(factorIds:FactorIds), writeln(''),
|
||||||
|
writeln(evidence:Evidence), writeln(''),
|
||||||
|
writeln(evidenceIds:EvidenceIds), writeln(''),
|
||||||
|
create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
||||||
|
%get_vars_information(AllKeys, StatesNames),
|
||||||
|
%set_vars_information(AllKeys, StatesNames),
|
||||||
|
run_solver(ground(Network,Hash), QueryKeys, Solutions),
|
||||||
|
writeln(answer:Solutions),
|
||||||
|
%clpbn_bind_vals([QueryKeys], Solutions, Output).
|
||||||
|
free_ground_network(Network).
|
||||||
|
|
||||||
|
|
||||||
|
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||||
|
%get_dists_parameters(DistIds, DistsParams),
|
||||||
|
%set_factors_params(Network, DistsParams),
|
||||||
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||||
|
writeln(queryKeys:QueryKeys), writeln(''),
|
||||||
|
writeln(queryIds:QueryIds), writeln(''),
|
||||||
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||||
|
run_ground_solver(Network, [QueryIds], Solutions).
|
||||||
|
|
||||||
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Solutions) :-
|
|
||||||
b_hash_new(Hash0),
|
|
||||||
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
|
||||||
InvMap =.. [view|AllKeys],
|
|
||||||
list_of_keys_to_ids(QueryKeys, Hash, QueryVarsIds),
|
|
||||||
evidence_to_ids(Evidence, Hash, EvIds, EvIdNames),
|
|
||||||
factors_to_ids(Factors, Hash, FactorIds),
|
|
||||||
init_graphical_model(FactorIds, Network, InvMap, EvIdNames),
|
|
||||||
run_ground_solver(Network, QueryVarsIds, EvIds, Solutions),
|
|
||||||
free_graphical_model(Network).
|
|
||||||
|
|
||||||
keys_to_ids([], _, Hash, Hash).
|
keys_to_ids([], _, Hash, Hash).
|
||||||
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
|
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
|
||||||
b_hash_insert(Hash0, Key, I0, HashI),
|
b_hash_insert(Hash0, Key, I0, HashI),
|
||||||
I is I0+1,
|
I is I0+1,
|
||||||
keys_to_ids(AllKeys, I, HashI, Hash).
|
keys_to_ids(AllKeys, I, HashI, Hash).
|
||||||
|
|
||||||
|
|
||||||
|
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
|
||||||
|
get_factors_type([f(markov, _, _)|_], markov) :- ! .
|
||||||
|
|
||||||
|
|
||||||
list_of_keys_to_ids([], _, []).
|
list_of_keys_to_ids([], _, []).
|
||||||
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
|
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
|
||||||
b_hash_lookup(Key, Id, Hash),
|
b_hash_lookup(Key, Id, Hash),
|
||||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
|
||||||
|
|
||||||
evidence_to_ids([], _, [], []).
|
|
||||||
evidence_to_ids([Key=V|QueryKeys], Hash, [Id=V|QueryIds], [Id=Name|QueryNames]) :-
|
|
||||||
b_hash_lookup(Key, Id, Hash),
|
|
||||||
pfl:skolem(Key,Dom),
|
|
||||||
nth0(V, Dom, Name),
|
|
||||||
evidence_to_ids(QueryKeys, Hash, QueryIds, QueryNames).
|
|
||||||
|
|
||||||
factors_to_ids([], _, []).
|
factors_to_ids([], _, []).
|
||||||
factors_to_ids([f(markov, Keys, CPT)|Fs], Hash, [markov(Ids, CPT)|NFs]) :-
|
factors_to_ids([f(_, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
|
||||||
list_of_keys_to_ids(Keys, Hash, Ids),
|
list_of_keys_to_ids(Keys, Hash, Ids),
|
||||||
factors_to_ids(Fs, Hash, NFs).
|
DistId = 0,
|
||||||
factors_to_ids([f(bayes, Keys, CPT)|Fs], Hash, [bayes(Ids, CPT)|NFs]) :-
|
get_ranges(Keys, Ranges),
|
||||||
list_of_keys_to_ids(Keys, Hash, Ids),
|
factors_to_ids(Fs, Hash, NFs).
|
||||||
factors_to_ids(Fs, Hash, NFs).
|
|
||||||
|
|
||||||
|
get_ranges([],[]).
|
||||||
|
get_ranges(K.Ks, Range.Rs) :- !,
|
||||||
|
skolem(K,Domain),
|
||||||
|
length(Domain,Range),
|
||||||
|
get_ranges(Ks, Rs).
|
||||||
|
|
||||||
|
|
||||||
|
evidence_to_ids([], _, []).
|
||||||
|
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
|
||||||
|
b_hash_lookup(Key, Id, Hash),
|
||||||
|
evidence_to_ids(QueryKeys, Hash, QueryIds).
|
||||||
|
|
||||||
|
|
||||||
|
get_vars_information([], []).
|
||||||
|
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
|
||||||
|
pfl:skolem(Key, Domain),
|
||||||
|
get_vars_information(QueryKeys, StatesNames).
|
||||||
|
|
||||||
|
|
||||||
|
finalize_bp_solver(bp(Network, _)) :-
|
||||||
|
free_ground_network(Network).
|
||||||
|
|
||||||
|
|
||||||
bp([[]],_,_) :- !.
|
bp([[]],_,_) :- !.
|
||||||
@ -116,102 +145,22 @@ bp([QueryVars], AllVars, Output) :-
|
|||||||
|
|
||||||
|
|
||||||
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||||
%writeln('init_bp_solver'),
|
%check_for_agg_vars(AllVars0, AllVars),
|
||||||
check_for_agg_vars(AllVars0, AllVars),
|
|
||||||
%writeln('clpbn_vars:'), print_clpbn_vars(AllVars),
|
|
||||||
assign_ids(AllVars, 0),
|
|
||||||
get_vars_info(AllVars, VarsInfo, DistIds0),
|
get_vars_info(AllVars, VarsInfo, DistIds0),
|
||||||
sort(DistIds0, DistIds),
|
sort(DistIds0, DistIds),
|
||||||
create_ground_network(VarsInfo, BayesNet),
|
create_ground_network(VarsInfo, BayesNet),
|
||||||
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
|
||||||
%set_extra_vars_info(BayesNet, ExtraVarsInfo),
|
|
||||||
%writeln(extravarsinfo:ExtraVarsInfo),
|
|
||||||
true.
|
true.
|
||||||
|
|
||||||
|
|
||||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||||
%writeln('-> run_bp_solver'),
|
|
||||||
get_dists_parameters(DistIds, DistsParams),
|
get_dists_parameters(DistIds, DistsParams),
|
||||||
set_bayes_net_params(Network, DistsParams),
|
set_factors_params(Network, DistsParams),
|
||||||
vars_to_ids(QueryVars, QueryVarsIds),
|
vars_to_ids(QueryVars, QueryVarsIds),
|
||||||
run_ground_solver(Network, QueryVarsIds, Solutions).
|
run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||||
|
|
||||||
|
|
||||||
finalize_bp_solver(bp(Network, _)) :-
|
|
||||||
free_bayesian_network(Network).
|
|
||||||
|
|
||||||
|
|
||||||
assign_ids([], _).
|
|
||||||
assign_ids([V|Vs], Count) :-
|
|
||||||
put_atts(V, [id(Count)]),
|
|
||||||
Count1 is Count + 1,
|
|
||||||
assign_ids(Vs, Count1).
|
|
||||||
|
|
||||||
|
|
||||||
get_vars_info([], [], []).
|
|
||||||
get_vars_info(V.Vs,
|
|
||||||
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
|
|
||||||
DistId.DistIds) :-
|
|
||||||
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
|
|
||||||
get_atts(V, [id(VarId)]),
|
|
||||||
get_dist_domain_size(DistId, DS),
|
|
||||||
get_evidence(V, Ev),
|
|
||||||
vars_to_ids(Parents, PIds),
|
|
||||||
get_vars_info(Vs, VarsInfo, DistIds).
|
|
||||||
|
|
||||||
|
|
||||||
get_evidence(V, Ev) :-
|
|
||||||
clpbn:get_atts(V, [evidence(Ev)]), !.
|
|
||||||
get_evidence(_V, -1). % no evidence !!!
|
|
||||||
|
|
||||||
|
|
||||||
vars_to_ids([], []).
|
|
||||||
vars_to_ids([L|Vars], [LIds|Ids]) :-
|
|
||||||
is_list(L), !,
|
|
||||||
vars_to_ids(L, LIds),
|
|
||||||
vars_to_ids(Vars, Ids).
|
|
||||||
vars_to_ids([V|Vars], [VarId|Ids]) :-
|
|
||||||
get_atts(V, [id(VarId)]),
|
|
||||||
vars_to_ids(Vars, Ids).
|
|
||||||
|
|
||||||
|
|
||||||
get_extra_vars_info([], []).
|
|
||||||
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
|
|
||||||
get_atts(V, [id(VarId)]), !,
|
|
||||||
clpbn:get_atts(V, [key(Key), dist(DistId, _)]),
|
|
||||||
term_to_atom(Key, Label),
|
|
||||||
get_dist_domain(DistId, Domain0),
|
|
||||||
numbers_to_atoms(Domain0, Domain),
|
|
||||||
get_extra_vars_info(Vs, VarsInfo).
|
|
||||||
get_extra_vars_info([_|Vs], VarsInfo) :-
|
|
||||||
get_extra_vars_info(Vs, VarsInfo).
|
|
||||||
|
|
||||||
|
|
||||||
get_dists_parameters([],[]).
|
get_dists_parameters([],[]).
|
||||||
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
||||||
get_dist_params(Id, Params),
|
get_dist_params(Id, Params),
|
||||||
get_dists_parameters(Ids, DistsInfo).
|
get_dists_parameters(Ids, DistsInfo).
|
||||||
|
|
||||||
|
|
||||||
numbers_to_atoms([], []).
|
|
||||||
numbers_to_atoms([Atom|L0], [Atom|L]) :-
|
|
||||||
atom(Atom), !,
|
|
||||||
numbers_to_atoms(L0, L).
|
|
||||||
numbers_to_atoms([Number|L0], [Atom|L]) :-
|
|
||||||
number_atom(Number, Atom),
|
|
||||||
numbers_to_atoms(L0, L).
|
|
||||||
|
|
||||||
|
|
||||||
print_clpbn_vars(Var.AllVars) :-
|
|
||||||
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
|
|
||||||
parents_to_keys(Parents, ParentKeys),
|
|
||||||
writeln(Var:Key:ParentKeys:DistId),
|
|
||||||
print_clpbn_vars(AllVars).
|
|
||||||
print_clpbn_vars([]).
|
|
||||||
|
|
||||||
|
|
||||||
parents_to_keys([], []).
|
|
||||||
parents_to_keys(Var.Parents, Key.Keys) :-
|
|
||||||
clpbn:get_atts(Var, [key(Key)]),
|
|
||||||
parents_to_keys(Parents, Keys).
|
|
||||||
|
|
||||||
|
77
packages/CLPBN/clpbn/bp/BayesBall.cpp
Normal file
77
packages/CLPBN/clpbn/bp/BayesBall.cpp
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "BayesBall.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||||
|
{
|
||||||
|
assert (fg_.isFromBayesNetwork());
|
||||||
|
|
||||||
|
Scheduling scheduling;
|
||||||
|
for (unsigned i = 0; i < queryIds.size(); i++) {
|
||||||
|
assert (dag_.getNode (queryIds[i]));
|
||||||
|
DAGraphNode* n = dag_.getNode (queryIds[i]);
|
||||||
|
scheduling.push (ScheduleInfo (n, false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!scheduling.empty()) {
|
||||||
|
ScheduleInfo& sch = scheduling.front();
|
||||||
|
DAGraphNode* n = sch.node;
|
||||||
|
n->setAsVisited();
|
||||||
|
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||||
|
if (n->isMarkedOnTop() == false) {
|
||||||
|
n->markOnTop();
|
||||||
|
scheduleParents (n, scheduling);
|
||||||
|
}
|
||||||
|
if (n->isMarkedOnBottom() == false) {
|
||||||
|
n->markOnBottom();
|
||||||
|
scheduleChilds (n, scheduling);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (sch.visitedFromParent) {
|
||||||
|
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
|
||||||
|
n->markOnTop();
|
||||||
|
scheduleParents (n, scheduling);
|
||||||
|
}
|
||||||
|
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
|
||||||
|
n->markOnBottom();
|
||||||
|
scheduleChilds (n, scheduling);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scheduling.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
constructGraph (fg);
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BayesBall::constructGraph (FactorGraph* fg) const
|
||||||
|
{
|
||||||
|
const FacNodes& facNodes = fg_.facNodes();
|
||||||
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
|
const DAGraphNode* n = dag_.getNode (
|
||||||
|
facNodes[i]->factor().argument (0));
|
||||||
|
if (n->isMarkedOnTop()) {
|
||||||
|
fg->addFactor (Factor (facNodes[i]->factor()));
|
||||||
|
} else if (n->hasEvidence() && n->isVisited()) {
|
||||||
|
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||||
|
Ranges ranges = { facNodes[i]->factor().range (0) };
|
||||||
|
Params params (ranges[0], LogAware::noEvidence());
|
||||||
|
params[n->getEvidence()] = LogAware::withEvidence();
|
||||||
|
fg->addFactor (Factor (varIds, ranges, params));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
85
packages/CLPBN/clpbn/bp/BayesBall.h
Normal file
85
packages/CLPBN/clpbn/bp/BayesBall.h
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
#ifndef HORUS_BAYESBALL_H
|
||||||
|
#define HORUS_BAYESBALL_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <queue>
|
||||||
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "BayesNet.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
struct ScheduleInfo
|
||||||
|
{
|
||||||
|
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
|
||||||
|
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||||
|
|
||||||
|
DAGraphNode* node;
|
||||||
|
bool visitedFromParent;
|
||||||
|
bool visitedFromChild;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
|
||||||
|
|
||||||
|
|
||||||
|
class BayesBall
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BayesBall (FactorGraph& fg)
|
||||||
|
: fg_(fg) , dag_(fg.getStructure())
|
||||||
|
{
|
||||||
|
dag_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||||
|
|
||||||
|
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||||
|
{
|
||||||
|
BayesBall bb (fg);
|
||||||
|
return bb.getMinimalFactorGraph (vids);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
void constructGraph (FactorGraph* fg) const;
|
||||||
|
|
||||||
|
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
|
||||||
|
|
||||||
|
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
|
||||||
|
|
||||||
|
FactorGraph& fg_;
|
||||||
|
|
||||||
|
DAGraph& dag_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline void
|
||||||
|
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
||||||
|
{
|
||||||
|
const vector<DAGraphNode*>& ps = n->parents();
|
||||||
|
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
|
||||||
|
it != ps.end(); it++) {
|
||||||
|
sch.push (ScheduleInfo (*it, false, true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline void
|
||||||
|
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
|
||||||
|
{
|
||||||
|
const vector<DAGraphNode*>& cs = n->childs();
|
||||||
|
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
|
||||||
|
it != cs.end(); it++) {
|
||||||
|
sch.push (ScheduleInfo (*it, true, false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // HORUS_BAYESBALL_H
|
||||||
|
|
@ -5,354 +5,57 @@
|
|||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "xmlParser/xmlParser.h"
|
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNet::~BayesNet (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
delete nodes_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::readFromBifFormat (const char* fileName)
|
DAGraph::addNode (DAGraphNode* n)
|
||||||
{
|
{
|
||||||
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
assert (Util::contains (varMap_, n->varId()) == false);
|
||||||
// only the first network is parsed, others are ignored
|
|
||||||
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
|
||||||
unsigned nVars = xNode.nChildNode ("VARIABLE");
|
|
||||||
for (unsigned i = 0; i < nVars; i++) {
|
|
||||||
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
|
||||||
if (string (var.getAttribute ("TYPE")) != "nature") {
|
|
||||||
cerr << "error: only \"nature\" variables are supported" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
States states;
|
|
||||||
string label = var.getChildNode("NAME").getText();
|
|
||||||
unsigned nrStates = var.nChildNode ("OUTCOME");
|
|
||||||
for (unsigned j = 0; j < nrStates; j++) {
|
|
||||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << j + 1;
|
|
||||||
states.push_back (ss.str());
|
|
||||||
} else {
|
|
||||||
states.push_back (var.getChildNode("OUTCOME", j).getText());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
addNode (label, states);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned nDefs = xNode.nChildNode ("DEFINITION");
|
|
||||||
if (nVars != nDefs) {
|
|
||||||
cerr << "error: different number of variables and definitions" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < nDefs; i++) {
|
|
||||||
XMLNode def = xNode.getChildNode ("DEFINITION", i);
|
|
||||||
string label = def.getChildNode("FOR").getText();
|
|
||||||
BayesNode* node = getBayesNode (label);
|
|
||||||
if (!node) {
|
|
||||||
cerr << "error: unknow variable `" << label << "'" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
BnNodeSet parents;
|
|
||||||
unsigned nParams = node->nrStates();
|
|
||||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
|
||||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
|
||||||
BayesNode* parentNode = getBayesNode (parentLabel);
|
|
||||||
if (!parentNode) {
|
|
||||||
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
nParams *= parentNode->nrStates();
|
|
||||||
parents.push_back (parentNode);
|
|
||||||
}
|
|
||||||
node->setParents (parents);
|
|
||||||
unsigned count = 0;
|
|
||||||
Params params (nParams);
|
|
||||||
stringstream s (def.getChildNode("TABLE").getText());
|
|
||||||
while (!s.eof() && count < nParams) {
|
|
||||||
s >> params[count];
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
if (count != nParams) {
|
|
||||||
cerr << "error: invalid number of parameters " ;
|
|
||||||
cerr << "for variable `" << label << "'" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
params = reorderParameters (params, node->nrStates());
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::toLog (params);
|
|
||||||
}
|
|
||||||
node->setParams (params);
|
|
||||||
}
|
|
||||||
setIndexes();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
|
||||||
BayesNet::addNode (BayesNode* n)
|
|
||||||
{
|
|
||||||
varMap_.insert (make_pair (n->varId(), nodes_.size()));
|
|
||||||
nodes_.push_back (n);
|
nodes_.push_back (n);
|
||||||
return nodes_.back();
|
varMap_[n->varId()] = n;
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
|
||||||
BayesNet::addNode (string label, const States& states)
|
|
||||||
{
|
|
||||||
VarId vid = nodes_.size();
|
|
||||||
varMap_.insert (make_pair (vid, nodes_.size()));
|
|
||||||
GraphicalModel::addVariableInformation (vid, label, states);
|
|
||||||
BayesNode* node = new BayesNode (VarNode (vid, states.size()));
|
|
||||||
nodes_.push_back (node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
|
||||||
BayesNet::getBayesNode (VarId vid) const
|
|
||||||
{
|
|
||||||
IndexMap::const_iterator it = varMap_.find (vid);
|
|
||||||
if (it == varMap_.end()) {
|
|
||||||
return 0;
|
|
||||||
} else {
|
|
||||||
return nodes_[it->second];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
|
||||||
BayesNet::getBayesNode (string label) const
|
|
||||||
{
|
|
||||||
BayesNode* node = 0;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (nodes_[i]->label() == label) {
|
|
||||||
node = nodes_[i];
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNode*
|
|
||||||
BayesNet::getVariableNode (VarId vid) const
|
|
||||||
{
|
|
||||||
BayesNode* node = getBayesNode (vid);
|
|
||||||
assert (node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNodes
|
|
||||||
BayesNet::getVariableNodes (void) const
|
|
||||||
{
|
|
||||||
VarNodes vars;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
vars.push_back (nodes_[i]);
|
|
||||||
}
|
|
||||||
return vars;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const BnNodeSet&
|
|
||||||
BayesNet::getBayesNodes (void) const
|
|
||||||
{
|
|
||||||
return nodes_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
BayesNet::nrNodes (void) const
|
|
||||||
{
|
|
||||||
return nodes_.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BnNodeSet
|
|
||||||
BayesNet::getRootNodes (void) const
|
|
||||||
{
|
|
||||||
BnNodeSet roots;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (nodes_[i]->isRoot()) {
|
|
||||||
roots.push_back (nodes_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return roots;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BnNodeSet
|
|
||||||
BayesNet::getLeafNodes (void) const
|
|
||||||
{
|
|
||||||
BnNodeSet leafs;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (nodes_[i]->isLeaf()) {
|
|
||||||
leafs.push_back (nodes_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return leafs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNet*
|
|
||||||
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
|
|
||||||
{
|
|
||||||
return getMinimalRequesiteNetwork (VarIds() = {vid});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNet*
|
|
||||||
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
|
|
||||||
{
|
|
||||||
BnNodeSet queryVars;
|
|
||||||
Scheduling scheduling;
|
|
||||||
for (unsigned i = 0; i < queryVarIds.size(); i++) {
|
|
||||||
BayesNode* n = getBayesNode (queryVarIds[i]);
|
|
||||||
assert (n);
|
|
||||||
queryVars.push_back (n);
|
|
||||||
scheduling.push (ScheduleInfo (n, false, true));
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<StateInfo*> states (nodes_.size(), 0);
|
|
||||||
|
|
||||||
while (!scheduling.empty()) {
|
|
||||||
ScheduleInfo& sch = scheduling.front();
|
|
||||||
StateInfo* state = states[sch.node->getIndex()];
|
|
||||||
if (!state) {
|
|
||||||
state = new StateInfo();
|
|
||||||
states[sch.node->getIndex()] = state;
|
|
||||||
} else {
|
|
||||||
state->visited = true;
|
|
||||||
}
|
|
||||||
if (!sch.node->hasEvidence() && sch.visitedFromChild) {
|
|
||||||
if (!state->markedOnTop) {
|
|
||||||
state->markedOnTop = true;
|
|
||||||
scheduleParents (sch.node, scheduling);
|
|
||||||
}
|
|
||||||
if (!state->markedOnBottom) {
|
|
||||||
state->markedOnBottom = true;
|
|
||||||
scheduleChilds (sch.node, scheduling);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (sch.visitedFromParent) {
|
|
||||||
if (sch.node->hasEvidence() && !state->markedOnTop) {
|
|
||||||
state->markedOnTop = true;
|
|
||||||
scheduleParents (sch.node, scheduling);
|
|
||||||
}
|
|
||||||
if (!sch.node->hasEvidence() && !state->markedOnBottom) {
|
|
||||||
state->markedOnBottom = true;
|
|
||||||
scheduleChilds (sch.node, scheduling);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
scheduling.pop();
|
|
||||||
}
|
|
||||||
/*
|
|
||||||
cout << "\t\ttop\tbottom" << endl;
|
|
||||||
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
|
|
||||||
Util::printDashedLine();
|
|
||||||
cout << endl;
|
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
|
||||||
cout << nodes_[i]->label() << ":\t\t" ;
|
|
||||||
if (states[i]) {
|
|
||||||
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
|
|
||||||
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
|
|
||||||
states[i]->visited ? cout << "yes\t" : cout << "no\t" ;
|
|
||||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
|
||||||
cout << endl;
|
|
||||||
} else {
|
|
||||||
cout << "no\tno\tno\t" ;
|
|
||||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
*/
|
|
||||||
BayesNet* bn = new BayesNet();
|
|
||||||
constructGraph (bn, states);
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
delete states[i];
|
|
||||||
}
|
|
||||||
return bn;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::constructGraph (BayesNet* bn,
|
DAGraph::addEdge (VarId vid1, VarId vid2)
|
||||||
const vector<StateInfo*>& states) const
|
|
||||||
{
|
{
|
||||||
BnNodeSet mrnNodes;
|
unordered_map<VarId, DAGraphNode*>::iterator it1;
|
||||||
vector<VarIds> parents;
|
unordered_map<VarId, DAGraphNode*>::iterator it2;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
it1 = varMap_.find (vid1);
|
||||||
bool isRequired = false;
|
it2 = varMap_.find (vid2);
|
||||||
if (states[i]) {
|
assert (it1 != varMap_.end());
|
||||||
isRequired = (nodes_[i]->hasEvidence() && states[i]->visited)
|
assert (it2 != varMap_.end());
|
||||||
||
|
it1->second->addChild (it2->second);
|
||||||
states[i]->markedOnTop;
|
it2->second->addParent (it1->second);
|
||||||
}
|
|
||||||
if (isRequired) {
|
|
||||||
parents.push_back (VarIds());
|
|
||||||
if (states[i]->markedOnTop) {
|
|
||||||
const BnNodeSet& ps = nodes_[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < ps.size(); j++) {
|
|
||||||
parents.back().push_back (ps[j]->varId());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
|
|
||||||
BayesNode* mrnNode = new BayesNode (nodes_[i]);
|
|
||||||
bn->addNode (mrnNode);
|
|
||||||
mrnNodes.push_back (mrnNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < mrnNodes.size(); i++) {
|
|
||||||
BnNodeSet ps;
|
|
||||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
|
||||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
|
||||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
|
||||||
}
|
|
||||||
mrnNodes[i]->setParents (ps);
|
|
||||||
}
|
|
||||||
bn->setIndexes();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
const DAGraphNode*
|
||||||
BayesNet::isPolyTree (void) const
|
DAGraph::getNode (VarId vid) const
|
||||||
{
|
{
|
||||||
return !containsUndirectedCycle();
|
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||||
|
it = varMap_.find (vid);
|
||||||
|
return it != varMap_.end() ? it->second : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DAGraphNode*
|
||||||
|
DAGraph::getNode (VarId vid)
|
||||||
|
{
|
||||||
|
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||||
|
it = varMap_.find (vid);
|
||||||
|
return it != varMap_.end() ? it->second : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::setIndexes (void)
|
DAGraph::setIndexes (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
nodes_[i]->setIndex (i);
|
nodes_[i]->setIndex (i);
|
||||||
@ -362,213 +65,43 @@ BayesNet::setIndexes (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::printGraphicalModel (void) const
|
DAGraph::clear (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
cout << *nodes_[i];
|
nodes_[i]->clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::exportToGraphViz (const char* fileName,
|
DAGraph::exportToGraphViz (const char* fileName)
|
||||||
bool showNeighborless,
|
|
||||||
const VarIds& highlightVarIds) const
|
|
||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "BayesNet::exportToDotFile()" << endl;
|
cerr << "DAGraph::exportToDotFile()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "digraph {" << endl;
|
out << "digraph {" << endl;
|
||||||
out << "ranksep=1" << endl;
|
out << "ranksep=1" << endl;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
if (showNeighborless || nodes_[i]->hasNeighbors()) {
|
out << nodes_[i]->varId() ;
|
||||||
out << nodes_[i]->varId() ;
|
out << " [" ;
|
||||||
if (nodes_[i]->hasEvidence()) {
|
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||||
out << " [" ;
|
if (nodes_[i]->hasEvidence()) {
|
||||||
out << "label=\"" << nodes_[i]->label() << "\"," ;
|
out << ",style=filled, fillcolor=yellow" ;
|
||||||
out << "style=filled, fillcolor=yellow" ;
|
|
||||||
out << "]" ;
|
|
||||||
} else {
|
|
||||||
out << " [" ;
|
|
||||||
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
|
||||||
out << "]" ;
|
|
||||||
}
|
|
||||||
out << endl;
|
|
||||||
}
|
}
|
||||||
|
out << "]" << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
|
||||||
BayesNode* node = getBayesNode (highlightVarIds[i]);
|
|
||||||
if (node) {
|
|
||||||
out << node->varId() ;
|
|
||||||
out << " [shape=box3d]" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
const BnNodeSet& childs = nodes_[i]->getChilds();
|
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
for (unsigned j = 0; j < childs.size(); j++) {
|
||||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId() << " [style=bold]" << endl ;
|
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||||
|
out << " [style=bold]" << endl ;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "}" << endl;
|
out << "}" << endl;
|
||||||
out.close();
|
out.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::exportToBifFormat (const char* fileName) const
|
|
||||||
{
|
|
||||||
ofstream out (fileName);
|
|
||||||
if(!out.is_open()) {
|
|
||||||
cerr << "error: cannot open file to write at " ;
|
|
||||||
cerr << "BayesNet::exportToBifFile()" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
out << "<?xml version=\"1.0\" encoding=\"US-ASCII\"?>" << endl;
|
|
||||||
out << "<BIF VERSION=\"0.3\">" << endl;
|
|
||||||
out << "<NETWORK>" << endl;
|
|
||||||
out << "<NAME>" << fileName << "</NAME>" << endl << endl;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
out << "<VARIABLE TYPE=\"nature\">" << endl;
|
|
||||||
out << "\t<NAME>" << nodes_[i]->label() << "</NAME>" << endl;
|
|
||||||
const States& states = nodes_[i]->states();
|
|
||||||
for (unsigned j = 0; j < states.size(); j++) {
|
|
||||||
out << "\t<OUTCOME>" << states[j] << "</OUTCOME>" << endl;
|
|
||||||
}
|
|
||||||
out << "</VARIABLE>" << endl << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
out << "<DEFINITION>" << endl;
|
|
||||||
out << "\t<FOR>" << nodes_[i]->label() << "</FOR>" << endl;
|
|
||||||
const BnNodeSet& parents = nodes_[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < parents.size(); j++) {
|
|
||||||
out << "\t<GIVEN>" << parents[j]->label();
|
|
||||||
out << "</GIVEN>" << endl;
|
|
||||||
}
|
|
||||||
Params params = revertParameterReorder (
|
|
||||||
nodes_[i]->params(), nodes_[i]->nrStates());
|
|
||||||
out << "\t<TABLE>" ;
|
|
||||||
for (unsigned j = 0; j < params.size(); j++) {
|
|
||||||
out << " " << params[j];
|
|
||||||
}
|
|
||||||
out << " </TABLE>" << endl;
|
|
||||||
out << "</DEFINITION>" << endl << endl;
|
|
||||||
}
|
|
||||||
out << "</NETWORK>" << endl;
|
|
||||||
out << "</BIF>" << endl << endl;
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNet::containsUndirectedCycle (void) const
|
|
||||||
{
|
|
||||||
vector<bool> visited (nodes_.size(), false);
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
int v = nodes_[i]->getIndex();
|
|
||||||
if (!visited[v]) {
|
|
||||||
if (containsUndirectedCycle (v, -1, visited)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNet::containsUndirectedCycle (int v, int p, vector<bool>& visited) const
|
|
||||||
{
|
|
||||||
visited[v] = true;
|
|
||||||
vector<int> adjacencies = getAdjacentNodes (v);
|
|
||||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
|
||||||
int w = adjacencies[i];
|
|
||||||
if (!visited[w]) {
|
|
||||||
if (containsUndirectedCycle (w, v, visited)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (visited[w] && w != p) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false; // no cycle detected in this component
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<int>
|
|
||||||
BayesNet::getAdjacentNodes (int v) const
|
|
||||||
{
|
|
||||||
vector<int> adjacencies;
|
|
||||||
const BnNodeSet& parents = nodes_[v]->getParents();
|
|
||||||
const BnNodeSet& childs = nodes_[v]->getChilds();
|
|
||||||
for (unsigned i = 0; i < parents.size(); i++) {
|
|
||||||
adjacencies.push_back (parents[i]->getIndex());
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < childs.size(); i++) {
|
|
||||||
adjacencies.push_back (childs[i]->getIndex());
|
|
||||||
}
|
|
||||||
return adjacencies;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BayesNet::reorderParameters (const Params& params, unsigned dsize) const
|
|
||||||
{
|
|
||||||
// the interchange format for bayesian networks keeps the probabilities
|
|
||||||
// in the following order:
|
|
||||||
// p(a1|b1,c1) p(a2|b1,c1) p(a1|b1,c2) p(a2|b1,c2) p(a1|b2,c1) p(a2|b2,c1)
|
|
||||||
// p(a1|b2,c2) p(a2|b2,c2).
|
|
||||||
//
|
|
||||||
// however, in clpbn we keep the probabilities in this order:
|
|
||||||
// p(a1|b1,c1) p(a1|b1,c2) p(a1|b2,c1) p(a1|b2,c2) p(a2|b1,c1) p(a2|b1,c2)
|
|
||||||
// p(a2|b2,c1) p(a2|b2,c2).
|
|
||||||
unsigned count = 0;
|
|
||||||
unsigned rowSize = params.size() / dsize;
|
|
||||||
Params reordered;
|
|
||||||
while (reordered.size() < params.size()) {
|
|
||||||
unsigned idx = count;
|
|
||||||
for (unsigned i = 0; i < rowSize; i++) {
|
|
||||||
reordered.push_back (params[idx]);
|
|
||||||
idx += dsize ;
|
|
||||||
}
|
|
||||||
count++;
|
|
||||||
}
|
|
||||||
return reordered;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BayesNet::revertParameterReorder (const Params& params, unsigned dsize) const
|
|
||||||
{
|
|
||||||
unsigned count = 0;
|
|
||||||
unsigned rowSize = params.size() / dsize;
|
|
||||||
Params reordered;
|
|
||||||
while (reordered.size() < params.size()) {
|
|
||||||
unsigned idx = count;
|
|
||||||
for (unsigned i = 0; i < dsize; i++) {
|
|
||||||
reordered.push_back (params[idx]);
|
|
||||||
idx += rowSize;
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
return reordered;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -6,127 +6,83 @@
|
|||||||
#include <list>
|
#include <list>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "Var.h"
|
||||||
#include "BayesNode.h"
|
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
struct ScheduleInfo
|
class Var;
|
||||||
{
|
|
||||||
ScheduleInfo (BayesNode* n, bool vfp, bool vfc) :
|
|
||||||
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
|
||||||
BayesNode* node;
|
|
||||||
bool visitedFromParent;
|
|
||||||
bool visitedFromChild;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
class DAGraphNode : public Var
|
||||||
struct StateInfo
|
|
||||||
{
|
|
||||||
StateInfo (void) : visited(false), markedOnTop(false),
|
|
||||||
markedOnBottom(false) { }
|
|
||||||
bool visited;
|
|
||||||
bool markedOnTop;
|
|
||||||
bool markedOnBottom;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
|
||||||
|
|
||||||
|
|
||||||
class BayesNet : public GraphicalModel
|
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BayesNet (void) { };
|
DAGraphNode (Var* v) : Var (v) , visited_(false),
|
||||||
|
markedOnTop_(false), markedOnBottom_(false) { }
|
||||||
|
|
||||||
~BayesNet (void);
|
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
||||||
|
|
||||||
void readFromBifFormat (const char*);
|
vector<DAGraphNode*>& childs (void) { return childs_; }
|
||||||
|
|
||||||
BayesNode* addNode (BayesNode*);
|
const vector<DAGraphNode*>& parents (void) const { return parents_; }
|
||||||
|
|
||||||
BayesNode* addNode (string, const States&);
|
vector<DAGraphNode*>& parents (void) { return parents_; }
|
||||||
|
|
||||||
BayesNode* getBayesNode (VarId) const;
|
void addParent (DAGraphNode* p) { parents_.push_back (p); }
|
||||||
|
|
||||||
BayesNode* getBayesNode (string) const;
|
void addChild (DAGraphNode* c) { childs_.push_back (c); }
|
||||||
|
|
||||||
VarNode* getVariableNode (VarId) const;
|
bool isVisited (void) const { return visited_; }
|
||||||
|
|
||||||
VarNodes getVariableNodes (void) const;
|
void setAsVisited (void) { visited_ = true; }
|
||||||
|
|
||||||
const BnNodeSet& getBayesNodes (void) const;
|
bool isMarkedOnTop (void) const { return markedOnTop_; }
|
||||||
|
|
||||||
unsigned nrNodes (void) const;
|
void markOnTop (void) { markedOnTop_ = true; }
|
||||||
|
|
||||||
BnNodeSet getRootNodes (void) const;
|
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
|
||||||
|
|
||||||
BnNodeSet getLeafNodes (void) const;
|
void markOnBottom (void) { markedOnBottom_ = true; }
|
||||||
|
|
||||||
BayesNet* getMinimalRequesiteNetwork (VarId) const;
|
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
|
||||||
|
|
||||||
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
|
private:
|
||||||
|
bool visited_;
|
||||||
|
bool markedOnTop_;
|
||||||
|
bool markedOnBottom_;
|
||||||
|
|
||||||
void constructGraph (BayesNet*, const vector<StateInfo*>&) const;
|
vector<DAGraphNode*> childs_;
|
||||||
|
vector<DAGraphNode*> parents_;
|
||||||
|
};
|
||||||
|
|
||||||
bool isPolyTree (void) const;
|
|
||||||
|
class DAGraph
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
DAGraph (void) { }
|
||||||
|
|
||||||
|
void addNode (DAGraphNode* n);
|
||||||
|
|
||||||
|
void addEdge (VarId vid1, VarId vid2);
|
||||||
|
|
||||||
|
const DAGraphNode* getNode (VarId vid) const;
|
||||||
|
|
||||||
|
DAGraphNode* getNode (VarId vid);
|
||||||
|
|
||||||
|
bool empty (void) const { return nodes_.empty(); }
|
||||||
|
|
||||||
void setIndexes (void);
|
void setIndexes (void);
|
||||||
|
|
||||||
void printGraphicalModel (void) const;
|
void clear (void);
|
||||||
|
|
||||||
void exportToGraphViz (const char*, bool = true,
|
void exportToGraphViz (const char*);
|
||||||
const VarIds& = VarIds()) const;
|
|
||||||
|
|
||||||
void exportToBifFormat (const char*) const;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BayesNet);
|
vector<DAGraphNode*> nodes_;
|
||||||
|
|
||||||
bool containsUndirectedCycle (void) const;
|
unordered_map<VarId, DAGraphNode*> varMap_;
|
||||||
|
|
||||||
bool containsUndirectedCycle (int, int, vector<bool>&)const;
|
|
||||||
|
|
||||||
vector<int> getAdjacentNodes (int) const;
|
|
||||||
|
|
||||||
Params reorderParameters (const Params&, unsigned) const;
|
|
||||||
|
|
||||||
Params revertParameterReorder (const Params&, unsigned) const;
|
|
||||||
|
|
||||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
|
||||||
|
|
||||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
|
||||||
|
|
||||||
BnNodeSet nodes_;
|
|
||||||
|
|
||||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
|
||||||
IndexMap varMap_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
|
||||||
{
|
|
||||||
const BnNodeSet& ps = n->getParents();
|
|
||||||
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
|
||||||
sch.push (ScheduleInfo (*it, false, true));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
|
||||||
{
|
|
||||||
const BnNodeSet& cs = n->getChilds();
|
|
||||||
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
|
||||||
sch.push (ScheduleInfo (*it, true, false));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // HORUS_BAYESNET_H
|
#endif // HORUS_BAYESNET_H
|
||||||
|
|
||||||
|
@ -1,247 +0,0 @@
|
|||||||
#include <cstdlib>
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include <iomanip>
|
|
||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "BayesNode.h"
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNode::setParams (const Params& params)
|
|
||||||
{
|
|
||||||
params_ = params;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNode::setParents (const BnNodeSet& parents)
|
|
||||||
{
|
|
||||||
parents_ = parents;
|
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
|
||||||
parents[i]->addChild (this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNode::addChild (BayesNode* node)
|
|
||||||
{
|
|
||||||
childs_.push_back (node);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BayesNode::getRow (int rowIndex) const
|
|
||||||
{
|
|
||||||
int rowSize = getRowSize();
|
|
||||||
int offset = rowSize * rowIndex;
|
|
||||||
Params row (rowSize);
|
|
||||||
for (int i = 0; i < rowSize; i++) {
|
|
||||||
row[i] = params_[offset + i] ;
|
|
||||||
}
|
|
||||||
return row;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNode::isRoot (void)
|
|
||||||
{
|
|
||||||
return getParents().empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNode::isLeaf (void)
|
|
||||||
{
|
|
||||||
return getChilds().empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNode::hasNeighbors (void) const
|
|
||||||
{
|
|
||||||
return childs_.size() != 0 || parents_.size() != 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
BayesNode::getCptSize (void)
|
|
||||||
{
|
|
||||||
return params_.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
BayesNode::indexOfParent (const BayesNode* parent) const
|
|
||||||
{
|
|
||||||
for (unsigned int i = 0; i < parents_.size(); i++) {
|
|
||||||
if (parents_[i] == parent) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
BayesNode::cptEntryToString (
|
|
||||||
int row,
|
|
||||||
const vector<unsigned>& stateConf) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << "p(" ;
|
|
||||||
ss << states()[row];
|
|
||||||
if (parents_.size() > 0) {
|
|
||||||
ss << "|" ;
|
|
||||||
for (unsigned int i = 0; i < stateConf.size(); i++) {
|
|
||||||
if (i != 0) {
|
|
||||||
ss << ",";
|
|
||||||
}
|
|
||||||
ss << parents_[i]->states()[stateConf[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ss << ")" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<string>
|
|
||||||
BayesNode::getDomainHeaders (void) const
|
|
||||||
{
|
|
||||||
unsigned nParents = parents_.size();
|
|
||||||
unsigned rowSize = getRowSize();
|
|
||||||
unsigned nReps = 1;
|
|
||||||
vector<string> headers (rowSize);
|
|
||||||
for (int i = nParents - 1; i >= 0; i--) {
|
|
||||||
States states = parents_[i]->states();
|
|
||||||
unsigned index = 0;
|
|
||||||
while (index < rowSize) {
|
|
||||||
for (unsigned j = 0; j < parents_[i]->nrStates(); j++) {
|
|
||||||
for (unsigned r = 0; r < nReps; r++) {
|
|
||||||
if (headers[index] != "") {
|
|
||||||
headers[index] = states[j] + "," + headers[index];
|
|
||||||
} else {
|
|
||||||
headers[index] = states[j];
|
|
||||||
}
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nReps *= parents_[i]->nrStates();
|
|
||||||
}
|
|
||||||
return headers;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream&
|
|
||||||
operator << (ostream& o, const BayesNode& node)
|
|
||||||
{
|
|
||||||
o << "variable " << node.getIndex() << endl;
|
|
||||||
o << "Var Id: " << node.varId() << endl;
|
|
||||||
o << "Label: " << node.label() << endl;
|
|
||||||
|
|
||||||
o << "Evidence: " ;
|
|
||||||
if (node.hasEvidence()) {
|
|
||||||
o << node.getEvidence();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
o << "no" ;
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
o << "Parents: " ;
|
|
||||||
const BnNodeSet& parents = node.getParents();
|
|
||||||
if (parents.size() != 0) {
|
|
||||||
for (unsigned int i = 0; i < parents.size() - 1; i++) {
|
|
||||||
o << parents[i]->label() << ", " ;
|
|
||||||
}
|
|
||||||
o << parents[parents.size() - 1]->label();
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
o << "Childs: " ;
|
|
||||||
const BnNodeSet& childs = node.getChilds();
|
|
||||||
if (childs.size() != 0) {
|
|
||||||
for (unsigned int i = 0; i < childs.size() - 1; i++) {
|
|
||||||
o << childs[i]->label() << ", " ;
|
|
||||||
}
|
|
||||||
o << childs[childs.size() - 1]->label();
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
o << "Domain: " ;
|
|
||||||
States states = node.states();
|
|
||||||
for (unsigned int i = 0; i < states.size() - 1; i++) {
|
|
||||||
o << states[i] << ", " ;
|
|
||||||
}
|
|
||||||
if (states.size() != 0) {
|
|
||||||
o << states[states.size() - 1];
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
// min width of first column
|
|
||||||
const unsigned int MIN_DOMAIN_WIDTH = 4;
|
|
||||||
// min width of following columns
|
|
||||||
const unsigned int MIN_COMBO_WIDTH = 12;
|
|
||||||
|
|
||||||
unsigned int domainWidth = states[0].length();
|
|
||||||
for (unsigned int i = 1; i < states.size(); i++) {
|
|
||||||
if (states[i].length() > domainWidth) {
|
|
||||||
domainWidth = states[i].length();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
domainWidth = (domainWidth < MIN_DOMAIN_WIDTH)
|
|
||||||
? MIN_DOMAIN_WIDTH
|
|
||||||
: domainWidth;
|
|
||||||
|
|
||||||
o << left << setw (domainWidth) << "cpt" << right;
|
|
||||||
|
|
||||||
vector<int> widths;
|
|
||||||
int lineWidth = domainWidth;
|
|
||||||
vector<string> headers = node.getDomainHeaders();
|
|
||||||
|
|
||||||
if (!headers.empty()) {
|
|
||||||
for (unsigned int i = 0; i < headers.size(); i++) {
|
|
||||||
unsigned int len = headers[i].length();
|
|
||||||
int w = (len < MIN_COMBO_WIDTH) ? MIN_COMBO_WIDTH : len;
|
|
||||||
widths.push_back (w);
|
|
||||||
o << setw (w) << headers[i];
|
|
||||||
lineWidth += w;
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
} else {
|
|
||||||
cout << endl;
|
|
||||||
widths.push_back (domainWidth);
|
|
||||||
lineWidth += MIN_COMBO_WIDTH;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < lineWidth; i++) {
|
|
||||||
o << "-" ;
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
for (unsigned int i = 0; i < states.size(); i++) {
|
|
||||||
Params row = node.getRow (i);
|
|
||||||
o << left << setw (domainWidth) << states[i] << right;
|
|
||||||
for (unsigned j = 0; j < node.getRowSize(); j++) {
|
|
||||||
o << setw (widths[j]) << row[j];
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
return o;
|
|
||||||
}
|
|
||||||
|
|
@ -1,81 +0,0 @@
|
|||||||
#ifndef HORUS_BAYESNODE_H
|
|
||||||
#define HORUS_BAYESNODE_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "VarNode.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
class BayesNode : public VarNode
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
|
|
||||||
BayesNode (const VarNode& v) : VarNode (v) { }
|
|
||||||
|
|
||||||
BayesNode (const BayesNode* n) :
|
|
||||||
VarNode (n->varId(), n->nrStates(), n->getEvidence()),
|
|
||||||
params_(n->params()), distId_(n->distId()) { }
|
|
||||||
|
|
||||||
BayesNode (VarId vid, unsigned nrStates, int ev,
|
|
||||||
const Params& ps, unsigned id)
|
|
||||||
: VarNode (vid, nrStates, ev) , params_(ps), distId_(id) { }
|
|
||||||
|
|
||||||
const BnNodeSet& getParents (void) const { return parents_; }
|
|
||||||
|
|
||||||
const BnNodeSet& getChilds (void) const { return childs_; }
|
|
||||||
|
|
||||||
const Params& params (void) const { return params_; }
|
|
||||||
|
|
||||||
unsigned distId (void) const { return distId_; }
|
|
||||||
|
|
||||||
unsigned getRowSize (void) const
|
|
||||||
{
|
|
||||||
return params_.size() / nrStates();
|
|
||||||
}
|
|
||||||
|
|
||||||
double getProbability (int row, unsigned col)
|
|
||||||
{
|
|
||||||
int idx = (row * getRowSize()) + col;
|
|
||||||
return params_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
void setParams (const Params& params);
|
|
||||||
|
|
||||||
void setParents (const BnNodeSet&);
|
|
||||||
|
|
||||||
void addChild (BayesNode*);
|
|
||||||
|
|
||||||
const Params& getParameters (void);
|
|
||||||
|
|
||||||
Params getRow (int) const;
|
|
||||||
|
|
||||||
bool isRoot (void);
|
|
||||||
|
|
||||||
bool isLeaf (void);
|
|
||||||
|
|
||||||
bool hasNeighbors (void) const;
|
|
||||||
|
|
||||||
int getCptSize (void);
|
|
||||||
|
|
||||||
int indexOfParent (const BayesNode*) const;
|
|
||||||
|
|
||||||
string cptEntryToString (int, const vector<unsigned>&) const;
|
|
||||||
|
|
||||||
friend ostream& operator << (ostream&, const BayesNode&);
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
|
||||||
|
|
||||||
States getDomainHeaders (void) const;
|
|
||||||
|
|
||||||
BnNodeSet parents_;
|
|
||||||
BnNodeSet childs_;
|
|
||||||
Params params_;
|
|
||||||
unsigned distId_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_BAYESNODE_H
|
|
||||||
|
|
@ -1,790 +0,0 @@
|
|||||||
#include <cstdlib>
|
|
||||||
#include <limits>
|
|
||||||
#include <time.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
#include <iomanip>
|
|
||||||
|
|
||||||
#include "BnBpSolver.h"
|
|
||||||
#include "Indexer.h"
|
|
||||||
|
|
||||||
BnBpSolver::BnBpSolver (const BayesNet& bn) : Solver (&bn)
|
|
||||||
{
|
|
||||||
bayesNet_ = &bn;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BnBpSolver::~BnBpSolver (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
|
||||||
delete nodesI_[i];
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
delete links_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::runSolver (void)
|
|
||||||
{
|
|
||||||
clock_t start;
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
start = clock();
|
|
||||||
}
|
|
||||||
initializeSolver();
|
|
||||||
runLoopySolver();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
|
||||||
cout << "Belief propagation converged in " ;
|
|
||||||
cout << nIters_ << " iterations" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned size = bayesNet_->nrNodes();
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool loopy = bayesNet_->isPolyTree() == false;
|
|
||||||
if (loopy) nIters = nIters_;
|
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BnBpSolver::getPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
BayesNode* node = bayesNet_->getBayesNode (vid);
|
|
||||||
assert (node);
|
|
||||||
return nodesI_[node->getIndex()]->getBeliefs();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << "calculating joint distribution on: " ;
|
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
|
|
||||||
cout << var->label() << " " ;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
return getJointByConditioning (jointVarIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::initializeSolver (void)
|
|
||||||
{
|
|
||||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
|
||||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
|
||||||
delete nodesI_[i];
|
|
||||||
}
|
|
||||||
nodesI_.clear();
|
|
||||||
nodesI_.reserve (nodes.size());
|
|
||||||
links_.clear();
|
|
||||||
sortedOrder_.clear();
|
|
||||||
linkMap_.clear();
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
nodesI_.push_back (new BpNodeInfo (nodes[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
BnNodeSet roots = bayesNet_->getRootNodes();
|
|
||||||
for (unsigned i = 0; i < roots.size(); i++) {
|
|
||||||
const Params& params = roots[i]->params();
|
|
||||||
Params& piVals = ninf(roots[i])->getPiValues();
|
|
||||||
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
|
|
||||||
piVals[ri] = params[ri];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
const BnNodeSet& parents = nodes[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < parents.size(); j++) {
|
|
||||||
BpLink* newLink = new BpLink (
|
|
||||||
parents[j], nodes[i], LinkOrientation::DOWN);
|
|
||||||
links_.push_back (newLink);
|
|
||||||
ninf(nodes[i])->addIncomingParentLink (newLink);
|
|
||||||
ninf(parents[j])->addOutcomingChildLink (newLink);
|
|
||||||
}
|
|
||||||
const BnNodeSet& childs = nodes[i]->getChilds();
|
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
|
||||||
BpLink* newLink = new BpLink (
|
|
||||||
childs[j], nodes[i], LinkOrientation::UP);
|
|
||||||
links_.push_back (newLink);
|
|
||||||
ninf(nodes[i])->addIncomingChildLink (newLink);
|
|
||||||
ninf(childs[j])->addOutcomingParentLink (newLink);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
if (nodes[i]->hasEvidence()) {
|
|
||||||
Params& piVals = ninf(nodes[i])->getPiValues();
|
|
||||||
Params& ldVals = ninf(nodes[i])->getLambdaValues();
|
|
||||||
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
|
|
||||||
piVals[xi] = LogAware::noEvidence();
|
|
||||||
ldVals[xi] = LogAware::noEvidence();
|
|
||||||
}
|
|
||||||
piVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
|
|
||||||
ldVals[nodes[i]->getEvidence()] = LogAware::withEvidence();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::runLoopySolver()
|
|
||||||
{
|
|
||||||
nIters_ = 0;
|
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
|
||||||
|
|
||||||
nIters_++;
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printHeader ("Iteration " + nIters_);
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (BpOptions::schedule) {
|
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
|
||||||
random_shuffle (links_.begin(), links_.end());
|
|
||||||
// no break
|
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateAndUpdateMessage (links_[i]);
|
|
||||||
updateValues (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
updateMessage (links_[i]);
|
|
||||||
updateValues (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
|
||||||
maxResidualSchedule();
|
|
||||||
break;
|
|
||||||
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BnBpSolver::converged (void) const
|
|
||||||
{
|
|
||||||
// this can happen if the graph is fully disconnected
|
|
||||||
if (links_.size() == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (nIters_ == 0 || nIters_ == 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool converged = true;
|
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
|
||||||
if (maxResidual < BpOptions::accuracy) {
|
|
||||||
converged = true;
|
|
||||||
} else {
|
|
||||||
converged = false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
double residual = links_[i]->getResidual();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << links_[i]->toString() + " residual change = " ;
|
|
||||||
cout << residual << endl;
|
|
||||||
}
|
|
||||||
if (residual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return converged;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::maxResidualSchedule (void)
|
|
||||||
{
|
|
||||||
if (nIters_ == 1) {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << "current residuals:" << endl;
|
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
it != sortedOrder_.end(); it ++) {
|
|
||||||
cout << " " << setw (30) << left << (*it)->toString();
|
|
||||||
cout << "residual = " << (*it)->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
BpLink* link = *it;
|
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
it = sortedOrder_.begin();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
updateMessage (link);
|
|
||||||
updateValues (link);
|
|
||||||
link->clearResidual();
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
|
||||||
|
|
||||||
const BpLinkSet& outParentLinks =
|
|
||||||
ninf(link->getDestination())->getOutcomingParentLinks();
|
|
||||||
for (unsigned i = 0; i < outParentLinks.size(); i++) {
|
|
||||||
if (outParentLinks[i]->getDestination() != link->getSource()
|
|
||||||
&& outParentLinks[i]->getDestination()->hasEvidence() == false) {
|
|
||||||
calculateMessage (outParentLinks[i]);
|
|
||||||
BpLinkMap::iterator iter = linkMap_.find (outParentLinks[i]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (outParentLinks[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const BpLinkSet& outChildLinks =
|
|
||||||
ninf(link->getDestination())->getOutcomingChildLinks();
|
|
||||||
for (unsigned i = 0; i < outChildLinks.size(); i++) {
|
|
||||||
if (outChildLinks[i]->getDestination() != link->getSource()) {
|
|
||||||
calculateMessage (outChildLinks[i]);
|
|
||||||
BpLinkMap::iterator iter = linkMap_.find (outChildLinks[i]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (outChildLinks[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printDashedLine();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::updatePiValues (BayesNode* x)
|
|
||||||
{
|
|
||||||
// π(Xi)
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
|
|
||||||
}
|
|
||||||
Params& piValues = ninf(x)->getPiValues();
|
|
||||||
const BpLinkSet& parentLinks = ninf(x)->getIncomingParentLinks();
|
|
||||||
const BnNodeSet& ps = x->getParents();
|
|
||||||
Ranges ranges;
|
|
||||||
for (unsigned i = 0; i < ps.size(); i++) {
|
|
||||||
ranges.push_back (ps[i]->nrStates());
|
|
||||||
}
|
|
||||||
StatesIndexer indexer (ranges, false);
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
Params messageProducts (indexer.size());
|
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double messageProduct = LogAware::multIdenty();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
|
||||||
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
|
||||||
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (i != 0) *calcs1 << " + " ;
|
|
||||||
if (i != 0) *calcs2 << " + " ;
|
|
||||||
*calcs1 << parentLinks[i]->toString (indexer[i]);
|
|
||||||
*calcs2 << parentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messageProducts[k] = messageProduct;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " mp" << k;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (parentLinks.size() == 1) {
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
|
||||||
double sum = LogAware::addIdenty();
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
indexer.reset();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
|
||||||
sum = Util::logSum (sum,
|
|
||||||
x->getProbability(xi, indexer) + messageProducts[k]);
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
|
||||||
sum += x->getProbability (xi, indexer) * messageProducts[k];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (k != 0) *calcs1 << " + " ;
|
|
||||||
if (k != 0) *calcs2 << " + " ;
|
|
||||||
*calcs1 << x->cptEntryToString (xi, indexer.indices());
|
|
||||||
*calcs1 << ".mp" << k;
|
|
||||||
*calcs2 << LogAware::fl (x->getProbability (xi, indexer));
|
|
||||||
*calcs2 << "*" << messageProducts[k];
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
piValues[xi] = sum;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
|
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << piValues[xi] << endl;
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::updateLambdaValues (BayesNode* x)
|
|
||||||
{
|
|
||||||
// λ(Xi)
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
|
|
||||||
}
|
|
||||||
Params& lambdaValues = ninf(x)->getLambdaValues();
|
|
||||||
const BpLinkSet& childLinks = ninf(x)->getIncomingChildLinks();
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double product = LogAware::multIdenty();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
|
||||||
product += childLinks[i]->getMessage()[xi];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
|
||||||
product *= childLinks[i]->getMessage()[xi];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (i != 0) *calcs1 << "." ;
|
|
||||||
if (i != 0) *calcs2 << "*" ;
|
|
||||||
*calcs1 << childLinks[i]->toString (xi);
|
|
||||||
*calcs2 << childLinks[i]->getMessage()[xi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lambdaValues[xi] = product;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
|
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (childLinks.size() == 1) {
|
|
||||||
cout << " = " << product << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << lambdaValues[xi] << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::calculatePiMessage (BpLink* link)
|
|
||||||
{
|
|
||||||
// πX(Zi)
|
|
||||||
BayesNode* z = link->getSource();
|
|
||||||
BayesNode* x = link->getDestination();
|
|
||||||
Params& zxPiNextMessage = link->getNextMessage();
|
|
||||||
const BpLinkSet& zChildLinks = ninf(z)->getIncomingChildLinks();
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
const Params& zPiValues = ninf(z)->getPiValues();
|
|
||||||
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
|
|
||||||
double product = zPiValues[zi];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
|
|
||||||
*calcs1 << "[" << z->states()[zi] << "]" ;
|
|
||||||
*calcs2 << product;
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
|
||||||
if (zChildLinks[i]->getSource() != x) {
|
|
||||||
product += zChildLinks[i]->getMessage()[zi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
|
||||||
if (zChildLinks[i]->getSource() != x) {
|
|
||||||
product *= zChildLinks[i]->getMessage()[zi];
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
*calcs1 << "." << zChildLinks[i]->toString (zi);
|
|
||||||
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
zxPiNextMessage[zi] = product;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " " << link->toString();
|
|
||||||
cout << "[" << z->states()[zi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (zChildLinks.size() == 1) {
|
|
||||||
cout << " = " << product << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << product << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
LogAware::normalize (zxPiNextMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::calculateLambdaMessage (BpLink* link)
|
|
||||||
{
|
|
||||||
// λY(Xi)
|
|
||||||
BayesNode* y = link->getSource();
|
|
||||||
BayesNode* x = link->getDestination();
|
|
||||||
if (x->hasEvidence()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Params& yxLambdaNextMessage = link->getNextMessage();
|
|
||||||
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
|
|
||||||
const Params& yLambdaValues = ninf(y)->getLambdaValues();
|
|
||||||
int parentIndex = y->indexOfParent (x);
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
const BnNodeSet& ps = y->getParents();
|
|
||||||
Ranges ranges;
|
|
||||||
for (unsigned i = 0; i < ps.size(); i++) {
|
|
||||||
ranges.push_back (ps[i]->nrStates());
|
|
||||||
}
|
|
||||||
StatesIndexer indexer (ranges, false);
|
|
||||||
|
|
||||||
|
|
||||||
unsigned N = indexer.size() / x->nrStates();
|
|
||||||
Params messageProducts (N);
|
|
||||||
for (unsigned k = 0; k < N; k++) {
|
|
||||||
while (indexer[parentIndex] != 0) {
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double messageProduct = LogAware::multIdenty();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
|
||||||
if (yParentLinks[i]->getSource() != x) {
|
|
||||||
messageProduct += yParentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
|
||||||
if (yParentLinks[i]->getSource() != x) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (messageProduct != LogAware::multIdenty()) *calcs1 << "*" ;
|
|
||||||
if (messageProduct != LogAware::multIdenty()) *calcs2 << "*" ;
|
|
||||||
*calcs1 << yParentLinks[i]->toString (indexer[i]);
|
|
||||||
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
messageProduct *= yParentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messageProducts[k] = messageProduct;
|
|
||||||
++ indexer;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " mp" << k;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (yParentLinks.size() == 1) {
|
|
||||||
cout << 1 << endl;
|
|
||||||
} else if (yParentLinks.size() == 2) {
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double outerSum = LogAware::addIdenty();
|
|
||||||
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
|
||||||
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
|
||||||
}
|
|
||||||
double innerSum = LogAware::addIdenty();
|
|
||||||
indexer.reset();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned k = 0; k < N; k++) {
|
|
||||||
while (indexer[parentIndex] != xi) {
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
innerSum = Util::logSum (innerSum,
|
|
||||||
y->getProbability (yi, indexer) + messageProducts[k]);
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
outerSum = Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
|
|
||||||
} else {
|
|
||||||
for (unsigned k = 0; k < N; k++) {
|
|
||||||
while (indexer[parentIndex] != xi) {
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
if (k != 0) *calcs1 << " + " ;
|
|
||||||
if (k != 0) *calcs2 << " + " ;
|
|
||||||
*calcs1 << y->cptEntryToString (yi, indexer.indices());
|
|
||||||
*calcs1 << ".mp" << k;
|
|
||||||
*calcs2 << y->getProbability (yi, indexer);
|
|
||||||
*calcs2 << "*" << messageProducts[k];
|
|
||||||
}
|
|
||||||
innerSum += y->getProbability (yi, indexer) * messageProducts[k];
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
outerSum += innerSum * yLambdaValues[yi];
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
|
|
||||||
*calcs1 << "[" << y->states()[yi] << "]";
|
|
||||||
*calcs2 << "}*" << yLambdaValues[yi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
yxLambdaNextMessage[xi] = outerSum;
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " " << link->toString();
|
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << yxLambdaNextMessage[xi] << endl;
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
LogAware::normalize (yxLambdaNextMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|
||||||
{
|
|
||||||
BnNodeSet jointVars;
|
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
assert (bayesNet_->getBayesNode (jointVarIds[i]));
|
|
||||||
jointVars.push_back (bayesNet_->getBayesNode (jointVarIds[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNet* mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[0]);
|
|
||||||
BnBpSolver solver (*mrn);
|
|
||||||
solver.runSolver();
|
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
|
||||||
delete mrn;
|
|
||||||
|
|
||||||
VarIds observedVids = {jointVars[0]->varId()};
|
|
||||||
|
|
||||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
|
||||||
VarIds reqVars = {jointVarIds[i]};
|
|
||||||
Util::addToVector (reqVars, observedVids);
|
|
||||||
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
|
|
||||||
Params newBeliefs;
|
|
||||||
VarNodes observedVars;
|
|
||||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
|
||||||
observedVars.push_back (mrn->getBayesNode (observedVids[j]));
|
|
||||||
}
|
|
||||||
StatesIndexer idx (observedVars, false);
|
|
||||||
while (idx.valid()) {
|
|
||||||
for (unsigned j = 0; j < observedVars.size(); j++) {
|
|
||||||
observedVars[j]->setEvidence (idx[j]);
|
|
||||||
}
|
|
||||||
BnBpSolver solver (*mrn);
|
|
||||||
solver.runSolver();
|
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
|
||||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
|
||||||
newBeliefs.push_back (beliefs[k]);
|
|
||||||
}
|
|
||||||
++ idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
int count = -1;
|
|
||||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
|
||||||
if (j % jointVars[i]->nrStates() == 0) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
newBeliefs[j] *= prevBeliefs[count];
|
|
||||||
}
|
|
||||||
prevBeliefs = newBeliefs;
|
|
||||||
observedVids.push_back (jointVars[i]->varId());
|
|
||||||
delete mrn;
|
|
||||||
}
|
|
||||||
return prevBeliefs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
|
||||||
{
|
|
||||||
cout << left;
|
|
||||||
cout << setw (10) << "states" ;
|
|
||||||
cout << setw (20) << PI_SYMBOL << "(" + var->label() + ")" ;
|
|
||||||
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
|
|
||||||
cout << setw (16) << "belief" ;
|
|
||||||
cout << endl;
|
|
||||||
Util::printDashedLine();
|
|
||||||
cout << endl;
|
|
||||||
const States& states = var->states();
|
|
||||||
const Params& piVals = ninf(var)->getPiValues();
|
|
||||||
const Params& ldVals = ninf(var)->getLambdaValues();
|
|
||||||
const Params& beliefs = ninf(var)->getBeliefs();
|
|
||||||
for (unsigned xi = 0; xi < var->nrStates(); xi++) {
|
|
||||||
cout << setw (10) << states[xi];
|
|
||||||
cout << setw (19) << piVals[xi];
|
|
||||||
cout << setw (19) << ldVals[xi];
|
|
||||||
cout.precision (Constants::PRECISION);
|
|
||||||
cout << setw (16) << beliefs[xi];
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::printAllMessageStatus (void) const
|
|
||||||
{
|
|
||||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
printPiLambdaValues (nodes[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BpNodeInfo::BpNodeInfo (BayesNode* node)
|
|
||||||
{
|
|
||||||
node_ = node;
|
|
||||||
piVals_.resize (node->nrStates(), LogAware::one());
|
|
||||||
ldVals_.resize (node->nrStates(), LogAware::one());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BpNodeInfo::getBeliefs (void) const
|
|
||||||
{
|
|
||||||
double sum = 0.0;
|
|
||||||
Params beliefs (node_->nrStates());
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
|
||||||
beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]);
|
|
||||||
sum += beliefs[xi];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
|
||||||
beliefs[xi] = piVals_[xi] * ldVals_[xi];
|
|
||||||
sum += beliefs[xi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (sum);
|
|
||||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
|
||||||
beliefs[xi] /= sum;
|
|
||||||
}
|
|
||||||
return beliefs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BpNodeInfo::receivedBottomInfluence (void) const
|
|
||||||
{
|
|
||||||
// if all lambda values are equal, then neither
|
|
||||||
// this node neither its descendents have evidence,
|
|
||||||
// we can use this to don't send lambda messages his parents
|
|
||||||
bool childInfluenced = false;
|
|
||||||
for (unsigned xi = 1; xi < node_->nrStates(); xi++) {
|
|
||||||
if (ldVals_[xi] != ldVals_[0]) {
|
|
||||||
childInfluenced = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return childInfluenced;
|
|
||||||
}
|
|
||||||
|
|
@ -1,271 +0,0 @@
|
|||||||
#ifndef HORUS_BNBPSOLVER_H
|
|
||||||
#define HORUS_BNBPSOLVER_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
#include "Solver.h"
|
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
class BpNodeInfo;
|
|
||||||
|
|
||||||
static const string PI_SYMBOL = "pi" ;
|
|
||||||
static const string LD_SYMBOL = "ld" ;
|
|
||||||
|
|
||||||
enum LinkOrientation {UP, DOWN};
|
|
||||||
|
|
||||||
class BpLink
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BpLink (BayesNode* s, BayesNode* d, LinkOrientation o)
|
|
||||||
{
|
|
||||||
source_ = s;
|
|
||||||
destin_ = d;
|
|
||||||
orientation_ = o;
|
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
|
||||||
v1_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
|
|
||||||
v2_.resize (s->nrStates(), LogAware::tl (1.0 / s->nrStates()));
|
|
||||||
} else {
|
|
||||||
v1_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
|
|
||||||
v2_.resize (d->nrStates(), LogAware::tl (1.0 / d->nrStates()));
|
|
||||||
}
|
|
||||||
currMsg_ = &v1_;
|
|
||||||
nextMsg_ = &v2_;
|
|
||||||
residual_ = 0;
|
|
||||||
msgSended_ = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* getSource (void) const { return source_; }
|
|
||||||
|
|
||||||
BayesNode* getDestination (void) const { return destin_; }
|
|
||||||
|
|
||||||
LinkOrientation getOrientation (void) const { return orientation_; }
|
|
||||||
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
|
||||||
|
|
||||||
Params& getNextMessage (void) { return *nextMsg_;}
|
|
||||||
|
|
||||||
bool messageWasSended (void) const { return msgSended_; }
|
|
||||||
|
|
||||||
double getResidual (void) const { return residual_; }
|
|
||||||
|
|
||||||
void clearResidual (void) { residual_ = 0;}
|
|
||||||
|
|
||||||
void updateMessage (void)
|
|
||||||
{
|
|
||||||
swap (currMsg_, nextMsg_);
|
|
||||||
msgSended_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateResidual (void)
|
|
||||||
{
|
|
||||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
|
||||||
}
|
|
||||||
|
|
||||||
string toString (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
|
||||||
ss << PI_SYMBOL;
|
|
||||||
} else {
|
|
||||||
ss << LD_SYMBOL;
|
|
||||||
}
|
|
||||||
ss << "(" << source_->label();
|
|
||||||
ss << " --> " << destin_->label() << ")" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
string toString (unsigned stateIndex) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << toString() << "[" ;
|
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
|
||||||
ss << source_->states()[stateIndex] << "]" ;
|
|
||||||
} else {
|
|
||||||
ss << destin_->states()[stateIndex] << "]" ;
|
|
||||||
}
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
BayesNode* source_;
|
|
||||||
BayesNode* destin_;
|
|
||||||
LinkOrientation orientation_;
|
|
||||||
Params v1_;
|
|
||||||
Params v2_;
|
|
||||||
Params* currMsg_;
|
|
||||||
Params* nextMsg_;
|
|
||||||
bool msgSended_;
|
|
||||||
double residual_;
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef vector<BpLink*> BpLinkSet;
|
|
||||||
|
|
||||||
|
|
||||||
class BpNodeInfo
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BpNodeInfo (BayesNode*);
|
|
||||||
|
|
||||||
Params& getPiValues (void) { return piVals_; }
|
|
||||||
|
|
||||||
Params& getLambdaValues (void) { return ldVals_; }
|
|
||||||
|
|
||||||
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
|
|
||||||
|
|
||||||
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
|
|
||||||
|
|
||||||
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
|
|
||||||
|
|
||||||
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
|
|
||||||
|
|
||||||
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
|
|
||||||
|
|
||||||
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
|
|
||||||
|
|
||||||
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
|
|
||||||
|
|
||||||
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
|
|
||||||
|
|
||||||
Params getBeliefs (void) const;
|
|
||||||
|
|
||||||
bool receivedBottomInfluence (void) const;
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
|
|
||||||
|
|
||||||
const BayesNode* node_;
|
|
||||||
Params piVals_;
|
|
||||||
Params ldVals_;
|
|
||||||
BpLinkSet inParentLinks_;
|
|
||||||
BpLinkSet inChildLinks_;
|
|
||||||
BpLinkSet outParentLinks_;
|
|
||||||
BpLinkSet outChildLinks_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BnBpSolver : public Solver
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BnBpSolver (const BayesNet&);
|
|
||||||
|
|
||||||
~BnBpSolver (void);
|
|
||||||
|
|
||||||
void runSolver (void);
|
|
||||||
Params getPosterioriOf (VarId);
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
|
|
||||||
|
|
||||||
void initializeSolver (void);
|
|
||||||
|
|
||||||
void runLoopySolver (void);
|
|
||||||
|
|
||||||
void maxResidualSchedule (void);
|
|
||||||
|
|
||||||
bool converged (void) const;
|
|
||||||
|
|
||||||
void updatePiValues (BayesNode*);
|
|
||||||
|
|
||||||
void updateLambdaValues (BayesNode*);
|
|
||||||
|
|
||||||
void calculateLambdaMessage (BpLink*);
|
|
||||||
|
|
||||||
void calculatePiMessage (BpLink*);
|
|
||||||
|
|
||||||
Params getJointByJunctionNode (const VarIds&);
|
|
||||||
|
|
||||||
Params getJointByConditioning (const VarIds&) const;
|
|
||||||
|
|
||||||
void printPiLambdaValues (const BayesNode*) const;
|
|
||||||
|
|
||||||
void printAllMessageStatus (void) const;
|
|
||||||
|
|
||||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
|
||||||
calculatePiMessage (link);
|
|
||||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
|
||||||
calculateLambdaMessage (link);
|
|
||||||
}
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
}
|
|
||||||
|
|
||||||
void calculateMessage (BpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "calculating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
|
||||||
calculatePiMessage (link);
|
|
||||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
|
||||||
calculateLambdaMessage (link);
|
|
||||||
}
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateMessage (BpLink* link)
|
|
||||||
{
|
|
||||||
if (Constants::DEBUG >= 3) {
|
|
||||||
cout << "updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateValues (BpLink* link)
|
|
||||||
{
|
|
||||||
if (!link->getDestination()->hasEvidence()) {
|
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
|
||||||
updatePiValues (link->getDestination());
|
|
||||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
|
||||||
updateLambdaValues (link->getDestination());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BpNodeInfo* ninf (const BayesNode* node) const
|
|
||||||
{
|
|
||||||
assert (node);
|
|
||||||
assert (node == bayesNet_->getBayesNode (node->varId()));
|
|
||||||
assert (node->getIndex() < nodesI_.size());
|
|
||||||
return nodesI_[node->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
const BayesNet* bayesNet_;
|
|
||||||
vector<BpLink*> links_;
|
|
||||||
vector<BpNodeInfo*> nodesI_;
|
|
||||||
unsigned nIters_;
|
|
||||||
|
|
||||||
struct compare
|
|
||||||
{
|
|
||||||
inline bool operator() (const BpLink* e1, const BpLink* e2)
|
|
||||||
{
|
|
||||||
return e1->getResidual() > e2->getResidual();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef multiset<BpLink*, compare> SortedOrder;
|
|
||||||
SortedOrder sortedOrder_;
|
|
||||||
|
|
||||||
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
|
||||||
BpLinkMap linkMap_;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_BNBPSOLVER_H
|
|
||||||
|
|
@ -5,21 +5,22 @@
|
|||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "FgBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg)
|
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||||
{
|
{
|
||||||
factorGraph_ = &fg;
|
fg_ = &fg;
|
||||||
|
runned_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgBpSolver::~FgBpSolver (void)
|
BpSolver::~BpSolver (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||||
delete varsI_[i];
|
delete varsI_[i];
|
||||||
@ -34,47 +35,33 @@ FgBpSolver::~FgBpSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
Params
|
||||||
FgBpSolver::runSolver (void)
|
BpSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
clock_t start;
|
assert (queryVids.empty() == false);
|
||||||
if (Constants::COLLECT_STATS) {
|
if (queryVids.size() == 1) {
|
||||||
start = clock();
|
return getPosterioriOf (queryVids[0]);
|
||||||
}
|
} else {
|
||||||
runLoopySolver();
|
return getJointDistributionOf (queryVids);
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
|
||||||
cout << "Sum-Product converged in " ;
|
|
||||||
cout << nIters_ << " iterations" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
unsigned size = factorGraph_->getVarNodes().size();
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool loopy = factorGraph_->isTree() == false;
|
|
||||||
if (loopy) nIters = nIters_;
|
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FgBpSolver::getPosterioriOf (VarId vid)
|
BpSolver::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
assert (factorGraph_->getFgVarNode (vid));
|
if (runned_ == false) {
|
||||||
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
runSolver();
|
||||||
|
}
|
||||||
|
assert (fg_->getVarNode (vid));
|
||||||
|
VarNode* var = fg_->getVarNode (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->nrStates(), LogAware::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->nrStates(), LogAware::multIdenty());
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const SpLinkSet& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
@ -95,13 +82,16 @@ FgBpSolver::getPosterioriOf (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||||
{
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
int idx = -1;
|
int idx = -1;
|
||||||
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
|
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
|
||||||
const FgFacSet& factorNodes = vn->neighbors();
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
if (facNodes[i]->factor().contains (jointVarIds)) {
|
||||||
idx = i;
|
idx = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -109,11 +99,11 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
return getJointByConditioning (jointVarIds);
|
return getJointByConditioning (jointVarIds);
|
||||||
} else {
|
} else {
|
||||||
Factor res (*factorNodes[idx]->factor());
|
Factor res (facNodes[idx]->factor());
|
||||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Factor msg (links[i]->getVariable()->varId(),
|
Factor msg ({links[i]->getVariable()->varId()},
|
||||||
links[i]->getVariable()->nrStates(),
|
{links[i]->getVariable()->range()},
|
||||||
getVar2FactorMsg (links[i]));
|
getVar2FactorMsg (links[i]));
|
||||||
res.multiply (msg);
|
res.multiply (msg);
|
||||||
}
|
}
|
||||||
@ -131,30 +121,29 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::runLoopySolver (void)
|
BpSolver::runSolver (void)
|
||||||
{
|
{
|
||||||
|
clock_t start;
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
start = clock();
|
||||||
|
}
|
||||||
initializeSolver();
|
initializeSolver();
|
||||||
nIters_ = 0;
|
nIters_ = 0;
|
||||||
|
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||||
|
|
||||||
nIters_ ++;
|
nIters_ ++;
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
Util::printHeader (" Iteration " + nIters_);
|
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||||
cout << endl;
|
// cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (BpOptions::schedule) {
|
switch (BpOptions::schedule) {
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
random_shuffle (links_.begin(), links_.end());
|
random_shuffle (links_.begin(), links_.end());
|
||||||
// no break
|
// no break
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
case BpOptions::Schedule::SEQ_FIXED:
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
calculateAndUpdateMessage (links_[i]);
|
calculateAndUpdateMessage (links_[i]);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
case BpOptions::Schedule::PARALLEL:
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
calculateMessage (links_[i]);
|
calculateMessage (links_[i]);
|
||||||
@ -163,7 +152,6 @@ FgBpSolver::runLoopySolver (void)
|
|||||||
updateMessage(links_[i]);
|
updateMessage(links_[i]);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||||
maxResidualSchedule();
|
maxResidualSchedule();
|
||||||
break;
|
break;
|
||||||
@ -172,52 +160,35 @@ FgBpSolver::runLoopySolver (void)
|
|||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << endl;
|
||||||
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
|
cout << "Sum-Product converged in " ;
|
||||||
|
cout << nIters_ << " iterations" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unsigned size = fg_->varNodes().size();
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
unsigned nIters = 0;
|
||||||
|
bool loopy = fg_->isTree() == false;
|
||||||
|
if (loopy) nIters = nIters_;
|
||||||
|
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||||
|
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||||
|
}
|
||||||
|
runned_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::initializeSolver (void)
|
BpSolver::createLinks (void)
|
||||||
{
|
{
|
||||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
const FacNodes& facNodes = fg_->facNodes();
|
||||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
|
||||||
delete varsI_[i];
|
|
||||||
}
|
|
||||||
varsI_.reserve (varNodes.size());
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
varsI_.push_back (new SPNodeInfo());
|
|
||||||
}
|
|
||||||
|
|
||||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
|
||||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
|
||||||
delete facsI_[i];
|
|
||||||
}
|
|
||||||
facsI_.reserve (facNodes.size());
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
facsI_.push_back (new SPNodeInfo());
|
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
delete links_[i];
|
|
||||||
}
|
|
||||||
createLinks();
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
FgFacNode* src = links_[i]->getFactor();
|
|
||||||
FgVarNode* dst = links_[i]->getVariable();
|
|
||||||
ninf (dst)->addSpLink (links_[i]);
|
|
||||||
ninf (src)->addSpLink (links_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FgBpSolver::createLinks (void)
|
|
||||||
{
|
|
||||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
const FgVarSet& neighbors = facNodes[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
||||||
}
|
}
|
||||||
@ -226,42 +197,8 @@ FgBpSolver::createLinks (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FgBpSolver::converged (void)
|
|
||||||
{
|
|
||||||
if (links_.size() == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (nIters_ == 0 || nIters_ == 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool converged = true;
|
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
|
||||||
if (maxResidual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
} else {
|
|
||||||
converged = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
double residual = links_[i]->getResidual();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
|
||||||
}
|
|
||||||
if (residual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
if (Constants::DEBUG == 0) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return converged;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::maxResidualSchedule (void)
|
BpSolver::maxResidualSchedule (void)
|
||||||
{
|
{
|
||||||
if (nIters_ == 1) {
|
if (nIters_ == 1) {
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
@ -293,7 +230,7 @@ FgBpSolver::maxResidualSchedule (void)
|
|||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
// update the messages that depend on message source --> destin
|
||||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||||
if (factorNeighbors[i] != link->getFactor()) {
|
if (factorNeighbors[i] != link->getFactor()) {
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
@ -316,16 +253,16 @@ FgBpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||||
{
|
{
|
||||||
const FgFacNode* src = link->getFactor();
|
FacNode* src = link->getFactor();
|
||||||
const FgVarNode* dst = link->getVariable();
|
const VarNode* dst = link->getVariable();
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
const SpLinkSet& links = ninf(src)->getLinks();
|
||||||
// calculate the product of messages that were sent
|
// calculate the product of messages that were sent
|
||||||
// to factor `src', except from var `dst'
|
// to factor `src', except from var `dst'
|
||||||
unsigned msgSize = 1;
|
unsigned msgSize = 1;
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
msgSize *= links[i]->getVariable()->nrStates();
|
msgSize *= links[i]->getVariable()->range();
|
||||||
}
|
}
|
||||||
unsigned repetitions = 1;
|
unsigned repetitions = 1;
|
||||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
@ -333,9 +270,9 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
|||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
if (links[i]->getVariable() != dst) {
|
if (links[i]->getVariable() != dst) {
|
||||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
repetitions *= links[i]->getVariable()->nrStates();
|
repetitions *= links[i]->getVariable()->range();
|
||||||
} else {
|
} else {
|
||||||
unsigned ds = links[i]->getVariable()->nrStates();
|
unsigned ds = links[i]->getVariable()->range();
|
||||||
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
||||||
repetitions *= ds;
|
repetitions *= ds;
|
||||||
}
|
}
|
||||||
@ -348,22 +285,21 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
|||||||
cout << ": " << endl;
|
cout << ": " << endl;
|
||||||
}
|
}
|
||||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
repetitions *= links[i]->getVariable()->nrStates();
|
repetitions *= links[i]->getVariable()->range();
|
||||||
} else {
|
} else {
|
||||||
unsigned ds = links[i]->getVariable()->nrStates();
|
unsigned ds = links[i]->getVariable()->range();
|
||||||
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
||||||
repetitions *= ds;
|
repetitions *= ds;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Factor result (src->factor()->arguments(),
|
Factor result (src->factor().arguments(),
|
||||||
src->factor()->ranges(),
|
src->factor().ranges(), msgProduct);
|
||||||
msgProduct);
|
result.multiply (src->factor());
|
||||||
result.multiply (*(src->factor()));
|
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " message product: " << msgProduct << endl;
|
cout << " message product: " << msgProduct << endl;
|
||||||
cout << " original factor: " << src->params() << endl;
|
cout << " original factor: " << src->factor().params() << endl;
|
||||||
cout << " factor product: " << result.params() << endl;
|
cout << " factor product: " << result.params() << endl;
|
||||||
}
|
}
|
||||||
result.sumOutAllExcept (dst->varId());
|
result.sumOutAllExcept (dst->varId());
|
||||||
@ -386,19 +322,19 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
BpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||||
{
|
{
|
||||||
const FgVarNode* src = link->getVariable();
|
const VarNode* src = link->getVariable();
|
||||||
const FgFacNode* dst = link->getFactor();
|
const FacNode* dst = link->getFactor();
|
||||||
Params msg;
|
Params msg;
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->nrStates(), LogAware::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << msg;
|
cout << msg;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
msg.resize (src->nrStates(), LogAware::one());
|
msg.resize (src->range(), LogAware::one());
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << msg;
|
cout << msg;
|
||||||
@ -429,16 +365,16 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
{
|
{
|
||||||
FgVarSet jointVars;
|
VarNodes jointVars;
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||||
assert (factorGraph_->getFgVarNode (jointVarIds[i]));
|
assert (fg_->getVarNode (jointVarIds[i]));
|
||||||
jointVars.push_back (factorGraph_->getFgVarNode (jointVarIds[i]));
|
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
FactorGraph* fg = new FactorGraph (*factorGraph_);
|
FactorGraph* fg = new FactorGraph (*fg_);
|
||||||
FgBpSolver solver (*fg);
|
BpSolver solver (*fg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||||
|
|
||||||
@ -447,9 +383,9 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
assert (jointVars[i]->hasEvidence() == false);
|
||||||
Params newBeliefs;
|
Params newBeliefs;
|
||||||
VarNodes observedVars;
|
Vars observedVars;
|
||||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
for (unsigned j = 0; j < observedVids.size(); j++) {
|
||||||
observedVars.push_back (fg->getFgVarNode (observedVids[j]));
|
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
||||||
}
|
}
|
||||||
StatesIndexer idx (observedVars, false);
|
StatesIndexer idx (observedVars, false);
|
||||||
while (idx.valid()) {
|
while (idx.valid()) {
|
||||||
@ -457,7 +393,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
observedVars[j]->setEvidence (idx[j]);
|
observedVars[j]->setEvidence (idx[j]);
|
||||||
}
|
}
|
||||||
++ idx;
|
++ idx;
|
||||||
FgBpSolver solver (*fg);
|
BpSolver solver (*fg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
for (unsigned k = 0; k < beliefs.size(); k++) {
|
||||||
@ -467,7 +403,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
|
|
||||||
int count = -1;
|
int count = -1;
|
||||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||||
if (j % jointVars[i]->nrStates() == 0) {
|
if (j % jointVars[i]->range() == 0) {
|
||||||
count ++;
|
count ++;
|
||||||
}
|
}
|
||||||
newBeliefs[j] *= prevBeliefs[count];
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
@ -481,7 +417,68 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::printLinkInformation (void) const
|
BpSolver::initializeSolver (void)
|
||||||
|
{
|
||||||
|
const VarNodes& varNodes = fg_->varNodes();
|
||||||
|
varsI_.reserve (varNodes.size());
|
||||||
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
|
varsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg_->facNodes();
|
||||||
|
facsI_.reserve (facNodes.size());
|
||||||
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
|
facsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
createLinks();
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
FacNode* src = links_[i]->getFactor();
|
||||||
|
VarNode* dst = links_[i]->getVariable();
|
||||||
|
ninf (dst)->addSpLink (links_[i]);
|
||||||
|
ninf (src)->addSpLink (links_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BpSolver::converged (void)
|
||||||
|
{
|
||||||
|
if (links_.size() == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (nIters_ <= 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool converged = true;
|
||||||
|
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||||
|
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||||
|
if (maxResidual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
} else {
|
||||||
|
converged = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
double residual = links_[i]->getResidual();
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
|
}
|
||||||
|
if (residual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
if (Constants::DEBUG == 0) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return converged;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::printLinkInformation (void) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
SpLink* l = links_[i];
|
SpLink* l = links_[i];
|
@ -1,5 +1,5 @@
|
|||||||
#ifndef HORUS_FGBPSOLVER_H
|
#ifndef HORUS_BPSOLVER_H
|
||||||
#define HORUS_FGBPSOLVER_H
|
#define HORUS_BPSOLVER_H
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -16,12 +16,12 @@ using namespace std;
|
|||||||
class SpLink
|
class SpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
SpLink (FgFacNode* fn, FgVarNode* vn)
|
SpLink (FacNode* fn, VarNode* vn)
|
||||||
{
|
{
|
||||||
fac_ = fn;
|
fac_ = fn;
|
||||||
var_ = vn;
|
var_ = vn;
|
||||||
v1_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
|
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||||
v2_.resize (vn->nrStates(), LogAware::tl (1.0 / vn->nrStates()));
|
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||||
currMsg_ = &v1_;
|
currMsg_ = &v1_;
|
||||||
nextMsg_ = &v2_;
|
nextMsg_ = &v2_;
|
||||||
msgSended_ = false;
|
msgSended_ = false;
|
||||||
@ -30,9 +30,9 @@ class SpLink
|
|||||||
|
|
||||||
virtual ~SpLink (void) { };
|
virtual ~SpLink (void) { };
|
||||||
|
|
||||||
FgFacNode* getFactor (void) const { return fac_; }
|
FacNode* getFactor (void) const { return fac_; }
|
||||||
|
|
||||||
FgVarNode* getVariable (void) const { return var_; }
|
VarNode* getVariable (void) const { return var_; }
|
||||||
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
const Params& getMessage (void) const { return *currMsg_; }
|
||||||
|
|
||||||
@ -65,14 +65,14 @@ class SpLink
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
FgFacNode* fac_;
|
FacNode* fac_;
|
||||||
FgVarNode* var_;
|
VarNode* var_;
|
||||||
Params v1_;
|
Params v1_;
|
||||||
Params v2_;
|
Params v2_;
|
||||||
Params* currMsg_;
|
Params* currMsg_;
|
||||||
Params* nextMsg_;
|
Params* nextMsg_;
|
||||||
bool msgSended_;
|
bool msgSended_;
|
||||||
double residual_;
|
double residual_;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef vector<SpLink*> SpLinkSet;
|
typedef vector<SpLink*> SpLinkSet;
|
||||||
@ -88,40 +88,38 @@ class SPNodeInfo
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FgBpSolver : public Solver
|
class BpSolver : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgBpSolver (const FactorGraph&);
|
BpSolver (const FactorGraph&);
|
||||||
|
|
||||||
virtual ~FgBpSolver (void);
|
virtual ~BpSolver (void);
|
||||||
|
|
||||||
void runSolver (void);
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
virtual Params getPosterioriOf (VarId);
|
virtual Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
virtual Params getJointDistributionOf (const VarIds&);
|
virtual Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual void initializeSolver (void);
|
void runSolver (void);
|
||||||
|
|
||||||
virtual void createLinks (void);
|
virtual void createLinks (void);
|
||||||
|
|
||||||
virtual void maxResidualSchedule (void);
|
virtual void maxResidualSchedule (void);
|
||||||
|
|
||||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
virtual void calculateFactor2VariableMsg (SpLink*);
|
||||||
|
|
||||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||||
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
virtual void printLinkInformation (void) const;
|
SPNodeInfo* ninf (const VarNode* var) const
|
||||||
|
|
||||||
SPNodeInfo* ninf (const FgVarNode* var) const
|
|
||||||
{
|
{
|
||||||
return varsI_[var->getIndex()];
|
return varsI_[var->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
SPNodeInfo* ninf (const FgFacNode* fac) const
|
SPNodeInfo* ninf (const FacNode* fac) const
|
||||||
{
|
{
|
||||||
return facsI_[fac->getIndex()];
|
return facsI_[fac->getIndex()];
|
||||||
}
|
}
|
||||||
@ -169,7 +167,8 @@ class FgBpSolver : public Solver
|
|||||||
unsigned nIters_;
|
unsigned nIters_;
|
||||||
vector<SPNodeInfo*> varsI_;
|
vector<SPNodeInfo*> varsI_;
|
||||||
vector<SPNodeInfo*> facsI_;
|
vector<SPNodeInfo*> facsI_;
|
||||||
const FactorGraph* factorGraph_;
|
bool runned_;
|
||||||
|
const FactorGraph* fg_;
|
||||||
|
|
||||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||||
SortedOrder sortedOrder_;
|
SortedOrder sortedOrder_;
|
||||||
@ -178,9 +177,12 @@ class FgBpSolver : public Solver
|
|||||||
SpLinkMap linkMap_;
|
SpLinkMap linkMap_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runLoopySolver (void);
|
void initializeSolver (void);
|
||||||
|
|
||||||
bool converged (void);
|
bool converged (void);
|
||||||
|
|
||||||
|
void printLinkInformation (void) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_FGBPSOLVER_H
|
#endif // HORUS_BPSOLVER_H
|
||||||
|
|
@ -10,22 +10,22 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
|||||||
groundFg_ = &fg;
|
groundFg_ = &fg;
|
||||||
freeColor_ = 0;
|
freeColor_ = 0;
|
||||||
|
|
||||||
const FgVarSet& varNodes = fg.getVarNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
varSignatures_.reserve (varNodes.size());
|
varSignatures_.reserve (varNodes.size());
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
||||||
varSignatures_.push_back (Signature (c));
|
varSignatures_.push_back (Signature (c));
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& facNodes = fg.getFactorNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
factorSignatures_.reserve (facNodes.size());
|
facSignatures_.reserve (facNodes.size());
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||||
factorSignatures_.push_back (Signature (c));
|
facSignatures_.push_back (Signature (c));
|
||||||
}
|
}
|
||||||
|
|
||||||
varColors_.resize (varNodes.size());
|
varColors_.resize (varNodes.size());
|
||||||
factorColors_.resize (facNodes.size());
|
facColors_.resize (facNodes.size());
|
||||||
setInitialColors();
|
setInitialColors();
|
||||||
createGroups();
|
createGroups();
|
||||||
}
|
}
|
||||||
@ -49,9 +49,9 @@ CFactorGraph::setInitialColors (void)
|
|||||||
{
|
{
|
||||||
// create the initial variable colors
|
// create the initial variable colors
|
||||||
VarColorMap colorMap;
|
VarColorMap colorMap;
|
||||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
const VarNodes& varNodes = groundFg_->varNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
unsigned dsize = varNodes[i]->nrStates();
|
unsigned dsize = varNodes[i]->range();
|
||||||
VarColorMap::iterator it = colorMap.find (dsize);
|
VarColorMap::iterator it = colorMap.find (dsize);
|
||||||
if (it == colorMap.end()) {
|
if (it == colorMap.end()) {
|
||||||
it = colorMap.insert (make_pair (
|
it = colorMap.insert (make_pair (
|
||||||
@ -70,24 +70,28 @@ CFactorGraph::setInitialColors (void)
|
|||||||
setColor (varNodes[i], stateColors[idx]);
|
setColor (varNodes[i], stateColors[idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
const FacNodes& facNodes = groundFg_->facNodes();
|
||||||
if (checkForIdenticalFactors) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
|
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
||||||
|
}
|
||||||
|
// FIXME FIXME FIXME : pfl should give correct dist ids.
|
||||||
|
if (checkForIdenticalFactors || true) {
|
||||||
unsigned groupCount = 1;
|
unsigned groupCount = 1;
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
Factor* f1 = facNodes[i]->factor();
|
Factor& f1 = facNodes[i]->factor();
|
||||||
if (f1->distId() != Util::maxUnsigned()) {
|
if (f1.distId() != Util::maxUnsigned()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
f1->setDistId (groupCount);
|
f1.setDistId (groupCount);
|
||||||
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
||||||
Factor* f2 = facNodes[j]->factor();
|
Factor& f2 = facNodes[j]->factor();
|
||||||
if (f2->distId() != Util::maxUnsigned()) {
|
if (f2.distId() != Util::maxUnsigned()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (f1->size() == f2->size() &&
|
if (f1.size() == f2.size() &&
|
||||||
f1->ranges() == f2->ranges() &&
|
f1.ranges() == f2.ranges() &&
|
||||||
f1->params() == f2->params()) {
|
f1.params() == f2.params()) {
|
||||||
f2->setDistId (groupCount);
|
f2.setDistId (groupCount);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
groupCount ++;
|
groupCount ++;
|
||||||
@ -96,7 +100,7 @@ CFactorGraph::setInitialColors (void)
|
|||||||
// create the initial factor colors
|
// create the initial factor colors
|
||||||
DistColorMap distColors;
|
DistColorMap distColors;
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
unsigned distId = facNodes[i]->factor()->distId();
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
DistColorMap::iterator it = distColors.find (distId);
|
DistColorMap::iterator it = distColors.find (distId);
|
||||||
if (it == distColors.end()) {
|
if (it == distColors.end()) {
|
||||||
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
||||||
@ -111,30 +115,30 @@ void
|
|||||||
CFactorGraph::createGroups (void)
|
CFactorGraph::createGroups (void)
|
||||||
{
|
{
|
||||||
VarSignMap varGroups;
|
VarSignMap varGroups;
|
||||||
FacSignMap factorGroups;
|
FacSignMap facGroups;
|
||||||
unsigned nIters = 0;
|
unsigned nIters = 0;
|
||||||
bool groupsHaveChanged = true;
|
bool groupsHaveChanged = true;
|
||||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
const VarNodes& varNodes = groundFg_->varNodes();
|
||||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
const FacNodes& facNodes = groundFg_->facNodes();
|
||||||
|
|
||||||
while (groupsHaveChanged || nIters == 1) {
|
while (groupsHaveChanged || nIters == 1) {
|
||||||
nIters ++;
|
nIters ++;
|
||||||
|
|
||||||
unsigned prevFactorGroupsSize = factorGroups.size();
|
unsigned prevFactorGroupsSize = facGroups.size();
|
||||||
factorGroups.clear();
|
facGroups.clear();
|
||||||
// set a new color to the factors with the same signature
|
// set a new color to the factors with the same signature
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
const Signature& signature = getSignature (facNodes[i]);
|
const Signature& signature = getSignature (facNodes[i]);
|
||||||
FacSignMap::iterator it = factorGroups.find (signature);
|
FacSignMap::iterator it = facGroups.find (signature);
|
||||||
if (it == factorGroups.end()) {
|
if (it == facGroups.end()) {
|
||||||
it = factorGroups.insert (make_pair (signature, FgFacSet())).first;
|
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||||
}
|
}
|
||||||
it->second.push_back (facNodes[i]);
|
it->second.push_back (facNodes[i]);
|
||||||
}
|
}
|
||||||
for (FacSignMap::iterator it = factorGroups.begin();
|
for (FacSignMap::iterator it = facGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != facGroups.end(); it++) {
|
||||||
Color newColor = getFreeColor();
|
Color newColor = getFreeColor();
|
||||||
FgFacSet& groupMembers = it->second;
|
FacNodes& groupMembers = it->second;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
setColor (groupMembers[i], newColor);
|
setColor (groupMembers[i], newColor);
|
||||||
}
|
}
|
||||||
@ -147,24 +151,24 @@ CFactorGraph::createGroups (void)
|
|||||||
const Signature& signature = getSignature (varNodes[i]);
|
const Signature& signature = getSignature (varNodes[i]);
|
||||||
VarSignMap::iterator it = varGroups.find (signature);
|
VarSignMap::iterator it = varGroups.find (signature);
|
||||||
if (it == varGroups.end()) {
|
if (it == varGroups.end()) {
|
||||||
it = varGroups.insert (make_pair (signature, FgVarSet())).first;
|
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||||
}
|
}
|
||||||
it->second.push_back (varNodes[i]);
|
it->second.push_back (varNodes[i]);
|
||||||
}
|
}
|
||||||
for (VarSignMap::iterator it = varGroups.begin();
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
it != varGroups.end(); it++) {
|
it != varGroups.end(); it++) {
|
||||||
Color newColor = getFreeColor();
|
Color newColor = getFreeColor();
|
||||||
FgVarSet& groupMembers = it->second;
|
VarNodes& groupMembers = it->second;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
setColor (groupMembers[i], newColor);
|
setColor (groupMembers[i], newColor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||||
|| prevFactorGroupsSize != factorGroups.size();
|
|| prevFactorGroupsSize != facGroups.size();
|
||||||
}
|
}
|
||||||
//printGroups (varGroups, factorGroups);
|
printGroups (varGroups, facGroups);
|
||||||
createClusters (varGroups, factorGroups);
|
createClusters (varGroups, facGroups);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -172,12 +176,12 @@ CFactorGraph::createGroups (void)
|
|||||||
void
|
void
|
||||||
CFactorGraph::createClusters (
|
CFactorGraph::createClusters (
|
||||||
const VarSignMap& varGroups,
|
const VarSignMap& varGroups,
|
||||||
const FacSignMap& factorGroups)
|
const FacSignMap& facGroups)
|
||||||
{
|
{
|
||||||
varClusters_.reserve (varGroups.size());
|
varClusters_.reserve (varGroups.size());
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
it != varGroups.end(); it++) {
|
it != varGroups.end(); it++) {
|
||||||
const FgVarSet& groupVars = it->second;
|
const VarNodes& groupVars = it->second;
|
||||||
VarCluster* vc = new VarCluster (groupVars);
|
VarCluster* vc = new VarCluster (groupVars);
|
||||||
for (unsigned i = 0; i < groupVars.size(); i++) {
|
for (unsigned i = 0; i < groupVars.size(); i++) {
|
||||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||||
@ -185,12 +189,12 @@ CFactorGraph::createClusters (
|
|||||||
varClusters_.push_back (vc);
|
varClusters_.push_back (vc);
|
||||||
}
|
}
|
||||||
|
|
||||||
facClusters_.reserve (factorGroups.size());
|
facClusters_.reserve (facGroups.size());
|
||||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != facGroups.end(); it++) {
|
||||||
FgFacNode* groupFactor = it->second[0];
|
FacNode* groupFactor = it->second[0];
|
||||||
const FgVarSet& neighs = groupFactor->neighbors();
|
const VarNodes& neighs = groupFactor->neighbors();
|
||||||
VarClusterSet varClusters;
|
VarClusters varClusters;
|
||||||
varClusters.reserve (neighs.size());
|
varClusters.reserve (neighs.size());
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
VarId vid = neighs[i]->varId();
|
VarId vid = neighs[i]->varId();
|
||||||
@ -203,15 +207,15 @@ CFactorGraph::createClusters (
|
|||||||
|
|
||||||
|
|
||||||
const Signature&
|
const Signature&
|
||||||
CFactorGraph::getSignature (const FgVarNode* varNode)
|
CFactorGraph::getSignature (const VarNode* varNode)
|
||||||
{
|
{
|
||||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
Signature& sign = varSignatures_[varNode->getIndex()];
|
||||||
vector<Color>::iterator it = sign.colors.begin();
|
vector<Color>::iterator it = sign.colors.begin();
|
||||||
const FgFacSet& neighs = varNode->neighbors();
|
const FacNodes& neighs = varNode->neighbors();
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
*it = getColor (neighs[i]);
|
*it = getColor (neighs[i]);
|
||||||
it ++;
|
it ++;
|
||||||
*it = neighs[i]->factor()->indexOf (varNode->varId());
|
*it = neighs[i]->factor().indexOf (varNode->varId());
|
||||||
it ++;
|
it ++;
|
||||||
}
|
}
|
||||||
*it = getColor (varNode);
|
*it = getColor (varNode);
|
||||||
@ -221,11 +225,11 @@ CFactorGraph::getSignature (const FgVarNode* varNode)
|
|||||||
|
|
||||||
|
|
||||||
const Signature&
|
const Signature&
|
||||||
CFactorGraph::getSignature (const FgFacNode* facNode)
|
CFactorGraph::getSignature (const FacNode* facNode)
|
||||||
{
|
{
|
||||||
Signature& sign = factorSignatures_[facNode->getIndex()];
|
Signature& sign = facSignatures_[facNode->getIndex()];
|
||||||
vector<Color>::iterator it = sign.colors.begin();
|
vector<Color>::iterator it = sign.colors.begin();
|
||||||
const FgVarSet& neighs = facNode->neighbors();
|
const VarNodes& neighs = facNode->neighbors();
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
*it = getColor (neighs[i]);
|
*it = getColor (neighs[i]);
|
||||||
it ++;
|
it ++;
|
||||||
@ -237,55 +241,53 @@ CFactorGraph::getSignature (const FgFacNode* facNode)
|
|||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
FactorGraph*
|
||||||
CFactorGraph::getCompressedFactorGraph (void)
|
CFactorGraph::getGroundFactorGraph (void) const
|
||||||
{
|
{
|
||||||
FactorGraph* fg = new FactorGraph();
|
FactorGraph* fg = new FactorGraph();
|
||||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
||||||
FgVarNode* newVar = new FgVarNode (var);
|
VarNode* newVar = new VarNode (var);
|
||||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||||
fg->addVariable (newVar);
|
fg->addVarNode (newVar);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||||
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
|
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
|
||||||
VarNodes myGroundVars;
|
Vars myGroundVars;
|
||||||
myGroundVars.reserve (myVarClusters.size());
|
myGroundVars.reserve (myVarClusters.size());
|
||||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||||
FgVarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||||
myGroundVars.push_back (v);
|
myGroundVars.push_back (v);
|
||||||
}
|
}
|
||||||
Factor* newFactor = new Factor (myGroundVars,
|
FacNode* fn = new FacNode (Factor (myGroundVars,
|
||||||
facClusters_[i]->getGroundFactors()[0]->params());
|
facClusters_[i]->getGroundFactors()[0]->factor().params()));
|
||||||
FgFacNode* fn = new FgFacNode (newFactor);
|
|
||||||
facClusters_[i]->setRepresentativeFactor (fn);
|
facClusters_[i]->setRepresentativeFactor (fn);
|
||||||
fg->addFactor (fn);
|
fg->addFacNode (fn);
|
||||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||||
fg->addEdge (fn, static_cast<FgVarNode*> (myGroundVars[j]));
|
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fg->setIndexes();
|
|
||||||
return fg;
|
return fg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
CFactorGraph::getGroundEdgeCount (
|
CFactorGraph::getEdgeCount (
|
||||||
const FacCluster* fc,
|
const FacCluster* fc,
|
||||||
const VarCluster* vc) const
|
const VarCluster* vc) const
|
||||||
{
|
{
|
||||||
const FgFacSet& clusterGroundFactors = fc->getGroundFactors();
|
|
||||||
FgVarNode* varNode = vc->getGroundFgVarNodes()[0];
|
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
|
VarId vid = vc->getGroundVarNodes().front()->varId();
|
||||||
|
const FacNodes& clusterGroundFactors = fc->getGroundFactors();
|
||||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
|
if (clusterGroundFactors[i]->factor().contains (vid)) {
|
||||||
count ++;
|
count ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// CFgVarSet vars = vc->getGroundFgVarNodes();
|
// CVarNodes vars = vc->getGroundVarNodes();
|
||||||
// for (unsigned i = 1; i < vars.size(); i++) {
|
// for (unsigned i = 1; i < vars.size(); i++) {
|
||||||
// FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
// VarNode* var = vc->getGroundVarNodes()[i];
|
||||||
// unsigned count2 = 0;
|
// unsigned count2 = 0;
|
||||||
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
||||||
@ -302,13 +304,13 @@ CFactorGraph::getGroundEdgeCount (
|
|||||||
void
|
void
|
||||||
CFactorGraph::printGroups (
|
CFactorGraph::printGroups (
|
||||||
const VarSignMap& varGroups,
|
const VarSignMap& varGroups,
|
||||||
const FacSignMap& factorGroups) const
|
const FacSignMap& facGroups) const
|
||||||
{
|
{
|
||||||
unsigned count = 1;
|
unsigned count = 1;
|
||||||
cout << "variable groups:" << endl;
|
cout << "variable groups:" << endl;
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
it != varGroups.end(); it++) {
|
it != varGroups.end(); it++) {
|
||||||
const FgVarSet& groupMembers = it->second;
|
const VarNodes& groupMembers = it->second;
|
||||||
if (groupMembers.size() > 0) {
|
if (groupMembers.size() > 0) {
|
||||||
cout << count << ": " ;
|
cout << count << ": " ;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
@ -321,9 +323,9 @@ CFactorGraph::printGroups (
|
|||||||
|
|
||||||
count = 1;
|
count = 1;
|
||||||
cout << endl << "factor groups:" << endl;
|
cout << endl << "factor groups:" << endl;
|
||||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != facGroups.end(); it++) {
|
||||||
const FgFacSet& groupMembers = it->second;
|
const FacNodes& groupMembers = it->second;
|
||||||
if (groupMembers.size() > 0) {
|
if (groupMembers.size() > 0) {
|
||||||
cout << ++count << ": " ;
|
cout << ++count << ": " ;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
@ -22,11 +22,11 @@ typedef unordered_map<unsigned, vector<Color>> VarColorMap;
|
|||||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||||
|
|
||||||
typedef vector<VarCluster*> VarClusterSet;
|
typedef vector<VarCluster*> VarClusters;
|
||||||
typedef vector<FacCluster*> FacClusterSet;
|
typedef vector<FacCluster*> FacClusters;
|
||||||
|
|
||||||
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
||||||
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ struct SignatureHash
|
|||||||
class VarCluster
|
class VarCluster
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarCluster (const FgVarSet& vs)
|
VarCluster (const VarNodes& vs)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < vs.size(); i++) {
|
for (unsigned i = 0; i < vs.size(); i++) {
|
||||||
groundVars_.push_back (vs[i]);
|
groundVars_.push_back (vs[i]);
|
||||||
@ -99,26 +99,28 @@ class VarCluster
|
|||||||
facClusters_.push_back (fc);
|
facClusters_.push_back (fc);
|
||||||
}
|
}
|
||||||
|
|
||||||
const FacClusterSet& getFacClusters (void) const
|
const FacClusters& getFacClusters (void) const
|
||||||
{
|
{
|
||||||
return facClusters_;
|
return facClusters_;
|
||||||
}
|
}
|
||||||
|
|
||||||
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||||
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
|
||||||
const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; }
|
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
||||||
|
|
||||||
|
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FgVarSet groundVars_;
|
VarNodes groundVars_;
|
||||||
FacClusterSet facClusters_;
|
FacClusters facClusters_;
|
||||||
FgVarNode* representVar_;
|
VarNode* representVar_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FacCluster
|
class FacCluster
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs)
|
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
|
||||||
{
|
{
|
||||||
groundFactors_ = groundFactors;
|
groundFactors_ = groundFactors;
|
||||||
varClusters_ = vcs;
|
varClusters_ = vcs;
|
||||||
@ -127,12 +129,12 @@ class FacCluster
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const VarClusterSet& getVarClusters (void) const
|
const VarClusters& getVarClusters (void) const
|
||||||
{
|
{
|
||||||
return varClusters_;
|
return varClusters_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool containsGround (const FgFacNode* fn)
|
bool containsGround (const FacNode* fn)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||||
if (groundFactors_[i] == fn) {
|
if (groundFactors_[i] == fn) {
|
||||||
@ -142,26 +144,26 @@ class FacCluster
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
FgFacNode* getRepresentativeFactor (void) const
|
FacNode* getRepresentativeFactor (void) const
|
||||||
{
|
{
|
||||||
return representFactor_;
|
return representFactor_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setRepresentativeFactor (FgFacNode* fn)
|
void setRepresentativeFactor (FacNode* fn)
|
||||||
{
|
{
|
||||||
representFactor_ = fn;
|
representFactor_ = fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& getGroundFactors (void) const
|
const FacNodes& getGroundFactors (void) const
|
||||||
{
|
{
|
||||||
return groundFactors_;
|
return groundFactors_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FgFacSet groundFactors_;
|
FacNodes groundFactors_;
|
||||||
VarClusterSet varClusters_;
|
VarClusters varClusters_;
|
||||||
FgFacNode* representFactor_;
|
FacNode* representFactor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -172,19 +174,19 @@ class CFactorGraph
|
|||||||
|
|
||||||
~CFactorGraph (void);
|
~CFactorGraph (void);
|
||||||
|
|
||||||
const VarClusterSet& getVarClusters (void) { return varClusters_; }
|
const VarClusters& getVarClusters (void) { return varClusters_; }
|
||||||
|
|
||||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
const FacClusters& getFacClusters (void) { return facClusters_; }
|
||||||
|
|
||||||
FgVarNode* getEquivalentVariable (VarId vid)
|
VarNode* getEquivalentVariable (VarId vid)
|
||||||
{
|
{
|
||||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
return vc->getRepresentativeVariable();
|
return vc->getRepresentativeVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
FactorGraph* getCompressedFactorGraph (void);
|
FactorGraph* getGroundFactorGraph (void) const;
|
||||||
|
|
||||||
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||||
|
|
||||||
static bool checkForIdenticalFactors;
|
static bool checkForIdenticalFactors;
|
||||||
|
|
||||||
@ -195,22 +197,22 @@ class CFactorGraph
|
|||||||
return freeColor_ - 1;
|
return freeColor_ - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Color getColor (const FgVarNode* vn) const
|
Color getColor (const VarNode* vn) const
|
||||||
{
|
{
|
||||||
return varColors_[vn->getIndex()];
|
return varColors_[vn->getIndex()];
|
||||||
}
|
}
|
||||||
Color getColor (const FgFacNode* fn) const {
|
Color getColor (const FacNode* fn) const {
|
||||||
return factorColors_[fn->getIndex()];
|
return facColors_[fn->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
void setColor (const FgVarNode* vn, Color c)
|
void setColor (const VarNode* vn, Color c)
|
||||||
{
|
{
|
||||||
varColors_[vn->getIndex()] = c;
|
varColors_[vn->getIndex()] = c;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setColor (const FgFacNode* fn, Color c)
|
void setColor (const FacNode* fn, Color c)
|
||||||
{
|
{
|
||||||
factorColors_[fn->getIndex()] = c;
|
facColors_[fn->getIndex()] = c;
|
||||||
}
|
}
|
||||||
|
|
||||||
VarCluster* getVariableCluster (VarId vid) const
|
VarCluster* getVariableCluster (VarId vid) const
|
||||||
@ -218,25 +220,25 @@ class CFactorGraph
|
|||||||
return vid2VarCluster_.find (vid)->second;
|
return vid2VarCluster_.find (vid)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setInitialColors (void);
|
void setInitialColors (void);
|
||||||
|
|
||||||
void createGroups (void);
|
void createGroups (void);
|
||||||
|
|
||||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||||
|
|
||||||
const Signature& getSignature (const FgVarNode*);
|
const Signature& getSignature (const VarNode*);
|
||||||
|
|
||||||
const Signature& getSignature (const FgFacNode*);
|
const Signature& getSignature (const FacNode*);
|
||||||
|
|
||||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
Color freeColor_;
|
Color freeColor_;
|
||||||
vector<Color> varColors_;
|
vector<Color> varColors_;
|
||||||
vector<Color> factorColors_;
|
vector<Color> facColors_;
|
||||||
vector<Signature> varSignatures_;
|
vector<Signature> varSignatures_;
|
||||||
vector<Signature> factorSignatures_;
|
vector<Signature> facSignatures_;
|
||||||
VarClusterSet varClusters_;
|
VarClusters varClusters_;
|
||||||
FacClusterSet facClusters_;
|
FacClusters facClusters_;
|
||||||
VarId2VarCluster vid2VarCluster_;
|
VarId2VarCluster vid2VarCluster_;
|
||||||
const FactorGraph* groundFg_;
|
const FactorGraph* groundFg_;
|
||||||
};
|
};
|
||||||
|
@ -1,10 +1,41 @@
|
|||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
|
||||||
|
|
||||||
|
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||||
|
{
|
||||||
|
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
nGroundVars = fg_->varNodes().size();
|
||||||
|
nGroundFacs = fg_->facNodes().size();
|
||||||
|
const VarNodes& vars = fg_->varNodes();
|
||||||
|
nWithoutNeighs = 0;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
const FacNodes& factors = vars[i]->neighbors();
|
||||||
|
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||||
|
nWithoutNeighs ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg_ = new CFactorGraph (fg);
|
||||||
|
fg_ = cfg_->getGroundFactorGraph();
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
unsigned nClusterVars = fg_->varNodes().size();
|
||||||
|
unsigned nClusterFacs = fg_->facNodes().size();
|
||||||
|
Statistics::updateCompressingStatistics (nGroundVars,
|
||||||
|
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
||||||
|
}
|
||||||
|
Util::printHeader ("Uncompressed Factor Graph");
|
||||||
|
fg.print();
|
||||||
|
Util::printHeader ("Compressed Factor Graph");
|
||||||
|
fg_->print();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CbpSolver::~CbpSolver (void)
|
CbpSolver::~CbpSolver (void)
|
||||||
{
|
{
|
||||||
delete lfg_;
|
delete cfg_;
|
||||||
delete factorGraph_;
|
delete fg_;
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
delete links_[i];
|
delete links_[i];
|
||||||
}
|
}
|
||||||
@ -16,26 +47,29 @@ CbpSolver::~CbpSolver (void)
|
|||||||
Params
|
Params
|
||||||
CbpSolver::getPosterioriOf (VarId vid)
|
CbpSolver::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
assert (lfg_->getEquivalentVariable (vid));
|
if (runned_ == false) {
|
||||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
runSolver();
|
||||||
|
}
|
||||||
|
assert (cfg_->getEquivalentVariable (vid));
|
||||||
|
VarNode* var = cfg_->getEquivalentVariable (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->nrStates(), LogAware::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->nrStates(), LogAware::multIdenty());
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const SpLinkSet& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::add (probs, l->getPoweredMessage());
|
Util::add (probs, l->poweredMessage());
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
Util::fromLog (probs);
|
Util::fromLog (probs);
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::multiply (probs, l->getPoweredMessage());
|
Util::multiply (probs, l->poweredMessage());
|
||||||
}
|
}
|
||||||
LogAware::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
}
|
}
|
||||||
@ -46,55 +80,14 @@ CbpSolver::getPosterioriOf (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||||
{
|
{
|
||||||
VarIds eqVarIds;
|
VarIds eqVarIds;
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||||
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
|
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
|
||||||
|
eqVarIds.push_back (vn->varId());
|
||||||
}
|
}
|
||||||
return FgBpSolver::getJointDistributionOf (eqVarIds);
|
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CbpSolver::initializeSolver (void)
|
|
||||||
{
|
|
||||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
nGroundVars = factorGraph_->getVarNodes().size();
|
|
||||||
nGroundFacs = factorGraph_->getFactorNodes().size();
|
|
||||||
const FgVarSet& vars = factorGraph_->getVarNodes();
|
|
||||||
nWithoutNeighs = 0;
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
const FgFacSet& factors = vars[i]->neighbors();
|
|
||||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
|
||||||
nWithoutNeighs ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
lfg_ = new CFactorGraph (*factorGraph_);
|
|
||||||
|
|
||||||
// cout << "Uncompressed Factor Graph" << endl;
|
|
||||||
// factorGraph_->printGraphicalModel();
|
|
||||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
|
||||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
|
||||||
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nClusterVars = factorGraph_->getVarNodes().size();
|
|
||||||
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
|
|
||||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
|
||||||
nClusterVars, nClusterFacs,
|
|
||||||
nWithoutNeighs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// cout << "Compressed Factor Graph" << endl;
|
|
||||||
// factorGraph_->printGraphicalModel();
|
|
||||||
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
|
|
||||||
// abort();
|
|
||||||
FgBpSolver::initializeSolver();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -102,12 +95,13 @@ CbpSolver::initializeSolver (void)
|
|||||||
void
|
void
|
||||||
CbpSolver::createLinks (void)
|
CbpSolver::createLinks (void)
|
||||||
{
|
{
|
||||||
const FacClusterSet fcs = lfg_->getFacClusters();
|
const FacClusters& fcs = cfg_->getFacClusters();
|
||||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||||
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
const VarClusters& vcs = fcs[i]->getVarClusters();
|
||||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||||
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
|
||||||
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(),
|
links_.push_back (new CbpSolverLink (
|
||||||
|
fcs[i]->getRepresentativeFactor(),
|
||||||
vcs[j]->getRepresentativeVariable(), c));
|
vcs[j]->getRepresentativeVariable(), c));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,7 +148,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
// update the messages that depend on message source --> destin
|
||||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
for (unsigned j = 0; j < links.size(); j++) {
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
@ -192,16 +186,16 @@ Params
|
|||||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||||
{
|
{
|
||||||
Params msg;
|
Params msg;
|
||||||
const FgVarNode* src = link->getVariable();
|
const VarNode* src = link->getVariable();
|
||||||
const FgFacNode* dst = link->getFactor();
|
const FacNode* dst = link->getFactor();
|
||||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->nrStates(), LogAware::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
double value = link->getMessage()[src->getEvidence()];
|
double value = link->getMessage()[src->getEvidence()];
|
||||||
msg[src->getEvidence()] = LogAware::pow (value, l->getNumberOfEdges() - 1);
|
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
|
||||||
} else {
|
} else {
|
||||||
msg = link->getMessage();
|
msg = link->getMessage();
|
||||||
LogAware::pow (msg, l->getNumberOfEdges() - 1);
|
LogAware::pow (msg, l->nrEdges() - 1);
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " " << "init: " << msg << endl;
|
cout << " " << "init: " << msg << endl;
|
||||||
@ -211,17 +205,17 @@ CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::add (msg, l->getPoweredMessage());
|
Util::add (msg, l->poweredMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::multiply (msg, l->getPoweredMessage());
|
Util::multiply (msg, l->poweredMessage());
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||||
cout << l->getPoweredMessage() << endl;
|
cout << l->poweredMessage() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -243,7 +237,7 @@ CbpSolver::printLinkInformation (void) const
|
|||||||
cout << l->toString() << ":" << endl;
|
cout << l->toString() << ":" << endl;
|
||||||
cout << " curr msg = " << l->getMessage() << endl;
|
cout << " curr msg = " << l->getMessage() << endl;
|
||||||
cout << " next msg = " << l->getNextMessage() << endl;
|
cout << " next msg = " << l->getNextMessage() << endl;
|
||||||
cout << " powered = " << l->getPoweredMessage() << endl;
|
cout << " powered = " << l->poweredMessage() << endl;
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
cout << " residual = " << l->getResidual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#ifndef HORUS_CBP_H
|
#ifndef HORUS_CBP_H
|
||||||
#define HORUS_CBP_H
|
#define HORUS_CBP_H
|
||||||
|
|
||||||
#include "FgBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "CFactorGraph.h"
|
#include "CFactorGraph.h"
|
||||||
|
|
||||||
class Factor;
|
class Factor;
|
||||||
@ -9,35 +9,33 @@ class Factor;
|
|||||||
class CbpSolverLink : public SpLink
|
class CbpSolverLink : public SpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
|
||||||
{
|
: SpLink (fn, vn), nrEdges_(c),
|
||||||
edgeCount_ = c;
|
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||||
poweredMsg_.resize (vn->nrStates(), LogAware::one());
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
unsigned nrEdges (void) const { return nrEdges_; }
|
||||||
|
|
||||||
const Params& getPoweredMessage (void) const { return poweredMsg_; }
|
const Params& poweredMessage (void) const { return pwdMsg_; }
|
||||||
|
|
||||||
void updateMessage (void)
|
void updateMessage (void)
|
||||||
{
|
{
|
||||||
poweredMsg_ = *nextMsg_;
|
pwdMsg_ = *nextMsg_;
|
||||||
swap (currMsg_, nextMsg_);
|
swap (currMsg_, nextMsg_);
|
||||||
msgSended_ = true;
|
msgSended_ = true;
|
||||||
LogAware::pow (poweredMsg_, edgeCount_);
|
LogAware::pow (pwdMsg_, nrEdges_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Params poweredMsg_;
|
unsigned nrEdges_;
|
||||||
unsigned edgeCount_;
|
Params pwdMsg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CbpSolver : public FgBpSolver
|
class CbpSolver : public BpSolver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
|
CbpSolver (const FactorGraph& fg);
|
||||||
|
|
||||||
~CbpSolver (void);
|
~CbpSolver (void);
|
||||||
|
|
||||||
@ -46,14 +44,16 @@ class CbpSolver : public FgBpSolver
|
|||||||
Params getJointDistributionOf (const VarIds&);
|
Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void initializeSolver (void);
|
|
||||||
void createLinks (void);
|
void createLinks (void);
|
||||||
|
|
||||||
void maxResidualSchedule (void);
|
void maxResidualSchedule (void);
|
||||||
|
|
||||||
Params getVar2FactorMsg (const SpLink*) const;
|
Params getVar2FactorMsg (const SpLink*) const;
|
||||||
|
|
||||||
void printLinkInformation (void) const;
|
void printLinkInformation (void) const;
|
||||||
|
|
||||||
CFactorGraph* lfg_;
|
CFactorGraph* cfg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_CBP_H
|
#endif // HORUS_CBP_H
|
||||||
|
@ -43,6 +43,14 @@ CTNode::removeChild (CTNode* child)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CTNode::removeChilds (void)
|
||||||
|
{
|
||||||
|
childs_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CTNode::removeAndDeleteChild (CTNode* child)
|
CTNode::removeAndDeleteChild (CTNode* child)
|
||||||
{
|
{
|
||||||
@ -897,19 +905,19 @@ ConstraintTree::getNodesAtLevel (unsigned level) const
|
|||||||
void
|
void
|
||||||
ConstraintTree::swapLogVar (LogVar X)
|
ConstraintTree::swapLogVar (LogVar X)
|
||||||
{
|
{
|
||||||
|
TupleSet before = tupleSet();
|
||||||
LogVars::iterator it =
|
LogVars::iterator it =
|
||||||
std::find (logVars_.begin(),logVars_.end(), X);
|
std::find (logVars_.begin(),logVars_.end(), X);
|
||||||
assert (it != logVars_.end());
|
assert (it != logVars_.end());
|
||||||
unsigned pos = std::distance (logVars_.begin(), it);
|
unsigned pos = std::distance (logVars_.begin(), it);
|
||||||
|
|
||||||
const CTNodes& nodes = getNodesAtLevel (pos);
|
const CTNodes& nodes = getNodesAtLevel (pos);
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
const CTNodes childs = nodes[i]->childs();
|
CTNodes childsCopy = nodes[i]->childs();
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
nodes[i]->removeChilds();
|
||||||
nodes[i]->removeChild (childs[j]);
|
for (unsigned j = 0; j < childsCopy.size(); j++) {
|
||||||
const CTNodes grandsons = childs[j]->childs();
|
const CTNodes grandsons = childsCopy[j]->childs();
|
||||||
for (unsigned k = 0; k < grandsons.size(); k++) {
|
for (unsigned k = 0; k < grandsons.size(); k++) {
|
||||||
CTNode* childCopy = new CTNode (*childs[j]);
|
CTNode* childCopy = new CTNode (*childsCopy[j]);
|
||||||
const CTNodes greatGrandsons = grandsons[k]->childs();
|
const CTNodes greatGrandsons = grandsons[k]->childs();
|
||||||
for (unsigned t = 0; t < greatGrandsons.size(); t++) {
|
for (unsigned t = 0; t < greatGrandsons.size(); t++) {
|
||||||
grandsons[k]->removeChild (greatGrandsons[t]);
|
grandsons[k]->removeChild (greatGrandsons[t]);
|
||||||
@ -920,10 +928,9 @@ ConstraintTree::swapLogVar (LogVar X)
|
|||||||
grandsons[k]->setLevel (grandsons[k]->level() - 1);
|
grandsons[k]->setLevel (grandsons[k]->level() - 1);
|
||||||
nodes[i]->addChild (grandsons[k], false);
|
nodes[i]->addChild (grandsons[k], false);
|
||||||
}
|
}
|
||||||
delete childs[j];
|
delete childsCopy[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::swap (logVars_[pos], logVars_[pos + 1]);
|
std::swap (logVars_[pos], logVars_[pos + 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,6 +50,8 @@ class CTNode
|
|||||||
|
|
||||||
void removeChild (CTNode*);
|
void removeChild (CTNode*);
|
||||||
|
|
||||||
|
void removeChilds (void);
|
||||||
|
|
||||||
void removeAndDeleteChild (CTNode*);
|
void removeAndDeleteChild (CTNode*);
|
||||||
|
|
||||||
void removeAndDeleteAllChilds (void);
|
void removeAndDeleteAllChilds (void);
|
||||||
|
@ -3,53 +3,37 @@
|
|||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
#include "BayesNet.h"
|
|
||||||
|
|
||||||
|
|
||||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||||
|
|
||||||
|
|
||||||
ElimGraph::ElimGraph (const BayesNet& bayesNet)
|
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||||
{
|
{
|
||||||
const BnNodeSet& bnNodes = bayesNet.getBayesNodes();
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
const VarIds& vids = factors[i]->arguments();
|
||||||
if (bnNodes[i]->hasEvidence() == false) {
|
for (unsigned j = 0; j < vids.size() - 1; j++) {
|
||||||
addNode (new EgNode (bnNodes[i]));
|
EgNode* n1 = getEgNode (vids[j]);
|
||||||
}
|
if (n1 == 0) {
|
||||||
}
|
n1 = new EgNode (vids[j], factors[i]->range (j));
|
||||||
|
addNode (n1);
|
||||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
}
|
||||||
if (bnNodes[i]->hasEvidence() == false) {
|
for (unsigned k = j + 1; k < vids.size(); k++) {
|
||||||
EgNode* n = getEgNode (bnNodes[i]->varId());
|
EgNode* n2 = getEgNode (vids[k]);
|
||||||
const BnNodeSet& childs = bnNodes[i]->getChilds();
|
if (n2 == 0) {
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
n2 = new EgNode (vids[k], factors[i]->range (k));
|
||||||
if (childs[j]->hasEvidence() == false) {
|
addNode (n2);
|
||||||
addEdge (n, getEgNode (childs[j]->varId()));
|
}
|
||||||
|
if (neighbors (n1, n2) == false) {
|
||||||
|
addEdge (n1, n2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
if (vids.size() == 1) {
|
||||||
|
if (getEgNode (vids[0]) == 0) {
|
||||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
addNode (new EgNode (vids[0], factors[i]->range (0)));
|
||||||
vector<EgNode*> neighs;
|
|
||||||
const vector<BayesNode*>& parents = bnNodes[i]->getParents();
|
|
||||||
for (unsigned i = 0; i < parents.size(); i++) {
|
|
||||||
if (parents[i]->hasEvidence() == false) {
|
|
||||||
neighs.push_back (getEgNode (parents[i]->varId()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (neighs.size() > 0) {
|
|
||||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
|
||||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
|
||||||
if (!neighbors (neighs[i], neighs[j])) {
|
|
||||||
addEdge (neighs[i], neighs[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
setIndexes();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -63,40 +47,16 @@ ElimGraph::~ElimGraph (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::addNode (EgNode* n)
|
|
||||||
{
|
|
||||||
nodes_.push_back (n);
|
|
||||||
varMap_.insert (make_pair (n->varId(), n));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
EgNode*
|
|
||||||
ElimGraph::getEgNode (VarId vid) const
|
|
||||||
{
|
|
||||||
unordered_map<VarId,EgNode*>::const_iterator it =varMap_.find (vid);
|
|
||||||
if (it ==varMap_.end()) {
|
|
||||||
return 0;
|
|
||||||
} else {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarIds
|
VarIds
|
||||||
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||||
{
|
{
|
||||||
VarIds elimOrder;
|
VarIds elimOrder;
|
||||||
marked_.resize (nodes_.size(), false);
|
marked_.resize (nodes_.size(), false);
|
||||||
|
|
||||||
for (unsigned i = 0; i < exclude.size(); i++) {
|
for (unsigned i = 0; i < exclude.size(); i++) {
|
||||||
|
assert (getEgNode (exclude[i]));
|
||||||
EgNode* node = getEgNode (exclude[i]);
|
EgNode* node = getEgNode (exclude[i]);
|
||||||
assert (node);
|
|
||||||
marked_[*node] = true;
|
marked_[*node] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
||||||
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
||||||
EgNode* node = getLowestCostNode();
|
EgNode* node = getLowestCostNode();
|
||||||
@ -109,6 +69,100 @@ ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::print (void) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
|
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||||
|
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||||
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
|
cout << " " << neighs[j]->label();
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::exportToGraphViz (
|
||||||
|
const char* fileName,
|
||||||
|
bool showNeighborless,
|
||||||
|
const VarIds& highlightVarIds) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "error: cannot open file to write at " ;
|
||||||
|
cerr << "Markov::exportToDotFile()" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "strict graph {" << endl;
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
|
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||||
|
out << '"' << nodes_[i]->label() << '"' << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||||
|
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||||
|
if (node) {
|
||||||
|
out << '"' << node->label() << '"' ;
|
||||||
|
out << " [shape=box3d]" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
|
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||||
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
|
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||||
|
out << '"' << neighs[j]->label() << '"' << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "}" << endl;
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarIds
|
||||||
|
ElimGraph::getEliminationOrder (
|
||||||
|
const vector<Factor*> factors,
|
||||||
|
VarIds excludedVids)
|
||||||
|
{
|
||||||
|
ElimGraph graph (factors);
|
||||||
|
// graph.print();
|
||||||
|
// graph.exportToGraphViz ("_egg.dot");
|
||||||
|
return graph.getEliminatingOrder (excludedVids);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::addNode (EgNode* n)
|
||||||
|
{
|
||||||
|
nodes_.push_back (n);
|
||||||
|
n->setIndex (nodes_.size() - 1);
|
||||||
|
varMap_.insert (make_pair (n->varId(), n));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
EgNode*
|
||||||
|
ElimGraph::getEgNode (VarId vid) const
|
||||||
|
{
|
||||||
|
unordered_map<VarId, EgNode*>::const_iterator it;
|
||||||
|
it = varMap_.find (vid);
|
||||||
|
return (it != varMap_.end()) ? it->second : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
EgNode*
|
EgNode*
|
||||||
ElimGraph::getLowestCostNode (void) const
|
ElimGraph::getLowestCostNode (void) const
|
||||||
{
|
{
|
||||||
@ -166,7 +220,7 @@ ElimGraph::getWeightCost (const EgNode* n) const
|
|||||||
const vector<EgNode*>& neighs = n->neighbors();
|
const vector<EgNode*>& neighs = n->neighbors();
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
if (marked_[*neighs[i]] == false) {
|
if (marked_[*neighs[i]] == false) {
|
||||||
cost *= neighs[i]->nrStates();
|
cost *= neighs[i]->range();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cost;
|
return cost;
|
||||||
@ -206,7 +260,7 @@ ElimGraph::getWeightedFillCost (const EgNode* n) const
|
|||||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||||
if (marked_[*neighs[j]] == true) continue;
|
if (marked_[*neighs[j]] == true) continue;
|
||||||
if (!neighbors (neighs[i], neighs[j])) {
|
if (!neighbors (neighs[i], neighs[j])) {
|
||||||
cost += neighs[i]->nrStates() * neighs[j]->nrStates();
|
cost += neighs[i]->range() * neighs[j]->range();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -247,78 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::setIndexes (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
nodes_[i]->setIndex (i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::printGraphicalModel (void) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
|
||||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
cout << " " << neighs[j]->label();
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::exportToGraphViz (const char* fileName,
|
|
||||||
bool showNeighborless,
|
|
||||||
const VarIds& highlightVarIds) const
|
|
||||||
{
|
|
||||||
ofstream out (fileName);
|
|
||||||
if (!out.is_open()) {
|
|
||||||
cerr << "error: cannot open file to write at " ;
|
|
||||||
cerr << "Markov::exportToDotFile()" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
out << "strict graph {" << endl;
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
|
||||||
out << '"' << nodes_[i]->label() << '"' ;
|
|
||||||
if (nodes_[i]->hasEvidence()) {
|
|
||||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
|
||||||
} else {
|
|
||||||
out << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
|
||||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
|
||||||
if (node) {
|
|
||||||
out << '"' << node->label() << '"' ;
|
|
||||||
out << " [shape=box3d]" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
|
||||||
out << '"' << neighs[j]->label() << '"' << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
out << "}" << endl;
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -17,10 +17,10 @@ enum ElimHeuristic
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class EgNode : public VarNode
|
class EgNode : public Var
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
EgNode (VarNode* var) : VarNode (var) { }
|
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||||
|
|
||||||
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
||||||
|
|
||||||
@ -34,10 +34,26 @@ class EgNode : public VarNode
|
|||||||
class ElimGraph
|
class ElimGraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ElimGraph (const BayesNet&);
|
ElimGraph (const vector<Factor*>&); // TODO
|
||||||
|
|
||||||
~ElimGraph (void);
|
~ElimGraph (void);
|
||||||
|
|
||||||
|
VarIds getEliminatingOrder (const VarIds&);
|
||||||
|
|
||||||
|
void print (void) const;
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*, bool = true,
|
||||||
|
const VarIds& = VarIds()) const;
|
||||||
|
|
||||||
|
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
|
||||||
|
|
||||||
|
static void setEliminationHeuristic (ElimHeuristic h)
|
||||||
|
{
|
||||||
|
elimHeuristic_ = h;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
void addEdge (EgNode* n1, EgNode* n2)
|
void addEdge (EgNode* n1, EgNode* n2)
|
||||||
{
|
{
|
||||||
assert (n1 != n2);
|
assert (n1 != n2);
|
||||||
@ -48,22 +64,6 @@ class ElimGraph
|
|||||||
void addNode (EgNode*);
|
void addNode (EgNode*);
|
||||||
|
|
||||||
EgNode* getEgNode (VarId) const;
|
EgNode* getEgNode (VarId) const;
|
||||||
|
|
||||||
VarIds getEliminatingOrder (const VarIds&);
|
|
||||||
|
|
||||||
void printGraphicalModel (void) const;
|
|
||||||
|
|
||||||
void exportToGraphViz (const char*, bool = true,
|
|
||||||
const VarIds& = VarIds()) const;
|
|
||||||
|
|
||||||
void setIndexes();
|
|
||||||
|
|
||||||
static void setEliminationHeuristic (ElimHeuristic h)
|
|
||||||
{
|
|
||||||
elimHeuristic_ = h;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
EgNode* getLowestCostNode (void) const;
|
EgNode* getLowestCostNode (void) const;
|
||||||
|
|
||||||
unsigned getNeighborsCost (const EgNode*) const;
|
unsigned getNeighborsCost (const EgNode*) const;
|
||||||
@ -80,7 +80,7 @@ class ElimGraph
|
|||||||
|
|
||||||
vector<EgNode*> nodes_;
|
vector<EgNode*> nodes_;
|
||||||
vector<bool> marked_;
|
vector<bool> marked_;
|
||||||
unordered_map<VarId,EgNode*> varMap_;
|
unordered_map<VarId, EgNode*> varMap_;
|
||||||
static ElimHeuristic elimHeuristic_;
|
static ElimHeuristic elimHeuristic_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -18,56 +18,14 @@ Factor::Factor (const Factor& g)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (VarId vid, unsigned nrStates)
|
|
||||||
{
|
|
||||||
args_.push_back (vid);
|
|
||||||
ranges_.push_back (nrStates);
|
|
||||||
params_.resize (nrStates, 1.0);
|
|
||||||
distId_ = Util::maxUnsigned();
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const VarNodes& vars)
|
|
||||||
{
|
|
||||||
int nrParams = 1;
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
args_.push_back (vars[i]->varId());
|
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
|
||||||
nrParams *= vars[i]->nrStates();
|
|
||||||
}
|
|
||||||
double val = 1.0 / nrParams;
|
|
||||||
params_.resize (nrParams, val);
|
|
||||||
distId_ = Util::maxUnsigned();
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (
|
Factor::Factor (
|
||||||
VarId vid,
|
const VarIds& vids,
|
||||||
unsigned nrStates,
|
const Ranges& ranges,
|
||||||
const Params& params)
|
|
||||||
{
|
|
||||||
args_.push_back (vid);
|
|
||||||
ranges_.push_back (nrStates);
|
|
||||||
params_ = params;
|
|
||||||
distId_ = Util::maxUnsigned();
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (
|
|
||||||
const VarNodes& vars,
|
|
||||||
const Params& params,
|
const Params& params,
|
||||||
unsigned distId)
|
unsigned distId)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
args_ = vids;
|
||||||
args_.push_back (vars[i]->varId());
|
ranges_ = ranges;
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
|
||||||
}
|
|
||||||
params_ = params;
|
params_ = params;
|
||||||
distId_ = distId;
|
distId_ = distId;
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
@ -76,14 +34,16 @@ Factor::Factor (
|
|||||||
|
|
||||||
|
|
||||||
Factor::Factor (
|
Factor::Factor (
|
||||||
const VarIds& vids,
|
const Vars& vars,
|
||||||
const Ranges& ranges,
|
const Params& params,
|
||||||
const Params& params)
|
unsigned distId)
|
||||||
{
|
{
|
||||||
args_ = vids;
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
ranges_ = ranges;
|
args_.push_back (vars[i]->varId());
|
||||||
|
ranges_.push_back (vars[i]->range());
|
||||||
|
}
|
||||||
params_ = params;
|
params_ = params;
|
||||||
distId_ = Util::maxUnsigned();
|
distId_ = distId;
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,8 +145,8 @@ Factor::sumOut (VarId vid)
|
|||||||
void
|
void
|
||||||
Factor::sumOutFirstVariable (void)
|
Factor::sumOutFirstVariable (void)
|
||||||
{
|
{
|
||||||
unsigned nStates = ranges_.front();
|
unsigned range = ranges_.front();
|
||||||
unsigned sep = params_.size() / nStates;
|
unsigned sep = params_.size() / range;
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = sep; i < params_.size(); i++) {
|
for (unsigned i = sep; i < params_.size(); i++) {
|
||||||
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
||||||
@ -206,14 +166,14 @@ Factor::sumOutFirstVariable (void)
|
|||||||
void
|
void
|
||||||
Factor::sumOutLastVariable (void)
|
Factor::sumOutLastVariable (void)
|
||||||
{
|
{
|
||||||
unsigned nStates = ranges_.back();
|
unsigned range = ranges_.back();
|
||||||
unsigned idx1 = 0;
|
unsigned idx1 = 0;
|
||||||
unsigned idx2 = 0;
|
unsigned idx2 = 0;
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
while (idx1 < params_.size()) {
|
while (idx1 < params_.size()) {
|
||||||
params_[idx2] = params_[idx1];
|
params_[idx2] = params_[idx1];
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
for (unsigned j = 1; j < nStates; j++) {
|
for (unsigned j = 1; j < range; j++) {
|
||||||
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
}
|
}
|
||||||
@ -223,7 +183,7 @@ Factor::sumOutLastVariable (void)
|
|||||||
while (idx1 < params_.size()) {
|
while (idx1 < params_.size()) {
|
||||||
params_[idx2] = params_[idx1];
|
params_[idx2] = params_[idx1];
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
for (unsigned j = 1; j < nStates; j++) {
|
for (unsigned j = 1; j < range; j++) {
|
||||||
params_[idx2] += params_[idx1];
|
params_[idx2] += params_[idx1];
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
}
|
}
|
||||||
@ -266,7 +226,7 @@ Factor::getLabel (void) const
|
|||||||
ss << "f(" ;
|
ss << "f(" ;
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) ss << "," ;
|
if (i != 0) ss << "," ;
|
||||||
ss << VarNode (args_[i], ranges_[i]).label();
|
ss << Var (args_[i], ranges_[i]).label();
|
||||||
}
|
}
|
||||||
ss << ")" ;
|
ss << ")" ;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
@ -277,13 +237,13 @@ Factor::getLabel (void) const
|
|||||||
void
|
void
|
||||||
Factor::print (void) const
|
Factor::print (void) const
|
||||||
{
|
{
|
||||||
VarNodes vars;
|
Vars vars;
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
vars.push_back (new VarNode (args_[i], ranges_[i]));
|
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||||
}
|
}
|
||||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
vector<string> jointStrings = Util::getStateLines (vars);
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
cout << "f(" << jointStrings[i] << ")" ;
|
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
|
||||||
cout << " = " << params_[i] << endl;
|
cout << " = " << params_[i] << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "VarNode.h"
|
#include "Var.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
@ -33,17 +33,14 @@ class TFactor
|
|||||||
|
|
||||||
void setDistId (unsigned id) { distId_ = id; }
|
void setDistId (unsigned id) { distId_ = id; }
|
||||||
|
|
||||||
|
void normalize (void) { LogAware::normalize (params_); }
|
||||||
|
|
||||||
void setParams (const Params& newParams)
|
void setParams (const Params& newParams)
|
||||||
{
|
{
|
||||||
params_ = newParams;
|
params_ = newParams;
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void normalize (void)
|
|
||||||
{
|
|
||||||
LogAware::normalize (params_);
|
|
||||||
}
|
|
||||||
|
|
||||||
int indexOf (const T& t) const
|
int indexOf (const T& t) const
|
||||||
{
|
{
|
||||||
int idx = -1;
|
int idx = -1;
|
||||||
@ -258,16 +255,11 @@ class Factor : public TFactor<VarId>
|
|||||||
|
|
||||||
Factor (const Factor&);
|
Factor (const Factor&);
|
||||||
|
|
||||||
Factor (VarId, unsigned);
|
Factor (const VarIds&, const Ranges&, const Params&,
|
||||||
|
|
||||||
Factor (const VarNodes&);
|
|
||||||
|
|
||||||
Factor (VarId, unsigned, const Params&);
|
|
||||||
|
|
||||||
Factor (const VarNodes&, const Params&,
|
|
||||||
unsigned = Util::maxUnsigned());
|
unsigned = Util::maxUnsigned());
|
||||||
|
|
||||||
Factor (const VarIds&, const Ranges&, const Params&);
|
Factor (const Vars&, const Params&,
|
||||||
|
unsigned = Util::maxUnsigned());
|
||||||
|
|
||||||
void sumOutAllExcept (VarId);
|
void sumOutAllExcept (VarId);
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
|
#include "BayesBall.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
@ -17,140 +18,92 @@ bool FactorGraph::orderFactorVariables = false;
|
|||||||
|
|
||||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||||
{
|
{
|
||||||
const FgVarSet& vars = fg.getVarNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
FgVarNode* varNode = new FgVarNode (vars[i]);
|
addVarNode (new VarNode (varNodes[i]));
|
||||||
addVariable (varNode);
|
|
||||||
}
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
const FgFacSet& facs = fg.getFactorNodes();
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
for (unsigned i = 0; i < facs.size(); i++) {
|
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||||
FgFacNode* facNode = new FgFacNode (facs[i]);
|
addFacNode (facNode);
|
||||||
addFactor (facNode);
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
const FgVarSet& neighs = facs[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph::FactorGraph (const BayesNet& bn)
|
|
||||||
{
|
|
||||||
const BnNodeSet& nodes = bn.getBayesNodes();
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
FgVarNode* varNode = new FgVarNode (nodes[i]);
|
|
||||||
addVariable (varNode);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
const BnNodeSet& parents = nodes[i]->getParents();
|
|
||||||
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
|
|
||||||
VarNodes neighs;
|
|
||||||
neighs.push_back (varNodes_[nodes[i]->getIndex()]);
|
|
||||||
for (unsigned j = 0; j < parents.size(); j++) {
|
|
||||||
neighs.push_back (varNodes_[parents[j]->getIndex()]);
|
|
||||||
}
|
|
||||||
FgFacNode* fn = new FgFacNode (
|
|
||||||
new Factor (neighs, nodes[i]->params(), nodes[i]->distId()));
|
|
||||||
if (orderFactorVariables) {
|
|
||||||
sort (neighs.begin(), neighs.end(), CompVarId());
|
|
||||||
fn->factor()->reorderAccordingVarIds();
|
|
||||||
}
|
|
||||||
addFactor (fn);
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
setIndexes();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||||
{
|
{
|
||||||
ifstream is (fileName);
|
std::ifstream is (fileName);
|
||||||
if (!is.is_open()) {
|
if (!is.is_open()) {
|
||||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
cerr << "error: cannot read from file " << fileName << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
ignoreLines (is);
|
||||||
string line;
|
string line;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
|
||||||
getline (is, line);
|
getline (is, line);
|
||||||
if (line != "MARKOV") {
|
if (line != "MARKOV") {
|
||||||
cerr << "error: the network must be a MARKOV network " << endl;
|
cerr << "error: the network must be a MARKOV network " << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
// read the number of vars
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
ignoreLines (is);
|
||||||
unsigned nVars;
|
unsigned nrVars;
|
||||||
is >> nVars;
|
is >> nrVars;
|
||||||
|
// read the range of each var
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
ignoreLines (is);
|
||||||
vector<int> domainSizes (nVars);
|
Ranges ranges (nrVars);
|
||||||
for (unsigned i = 0; i < nVars; i++) {
|
for (unsigned i = 0; i < nrVars; i++) {
|
||||||
unsigned ds;
|
is >> ranges[i];
|
||||||
is >> ds;
|
|
||||||
domainSizes[i] = ds;
|
|
||||||
}
|
}
|
||||||
|
unsigned nrFactors;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
unsigned nrArgs;
|
||||||
for (unsigned i = 0; i < nVars; i++) {
|
unsigned vid;
|
||||||
addVariable (new FgVarNode (i, domainSizes[i]));
|
is >> nrFactors;
|
||||||
}
|
vector<VarIds> factorVarIds;
|
||||||
|
vector<Ranges> factorRanges;
|
||||||
unsigned nFactors;
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
is >> nFactors;
|
ignoreLines (is);
|
||||||
for (unsigned i = 0; i < nFactors; i++) {
|
is >> nrArgs;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
factorVarIds.push_back ({ });
|
||||||
unsigned nFactorVars;
|
factorRanges.push_back ({ });
|
||||||
is >> nFactorVars;
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
VarNodes neighs;
|
|
||||||
for (unsigned j = 0; j < nFactorVars; j++) {
|
|
||||||
unsigned vid;
|
|
||||||
is >> vid;
|
is >> vid;
|
||||||
FgVarNode* neigh = getFgVarNode (vid);
|
if (vid >= ranges.size()) {
|
||||||
if (!neigh) {
|
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
|
||||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||||
|
cerr << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
neighs.push_back (neigh);
|
factorVarIds.back().push_back (vid);
|
||||||
}
|
factorRanges.back().push_back (ranges[vid]);
|
||||||
FgFacNode* fn = new FgFacNode (new Factor (neighs));
|
|
||||||
addFactor (fn);
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// read the parameters
|
||||||
for (unsigned i = 0; i < nFactors; i++) {
|
unsigned nrParams;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
unsigned nParams;
|
ignoreLines (is);
|
||||||
is >> nParams;
|
is >> nrParams;
|
||||||
if (facNodes_[i]->params().size() != nParams) {
|
if (nrParams != Util::expectedSize (factorRanges[i])) {
|
||||||
cerr << "error: invalid number of parameters for factor " ;
|
cerr << "error: invalid number of parameters for factor nº " << i ;
|
||||||
cerr << facNodes_[i]->getLabel() ;
|
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
|
||||||
cerr << ", expected: " << facNodes_[i]->params().size();
|
cerr << ", given: " << nrParams << endl;
|
||||||
cerr << ", given: " << nParams << endl;
|
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
Params params (nParams);
|
Params params (nrParams);
|
||||||
for (unsigned j = 0; j < nParams; j++) {
|
for (unsigned j = 0; j < nrParams; j++) {
|
||||||
double param;
|
is >> params[j];
|
||||||
is >> param;
|
|
||||||
params[j] = param;
|
|
||||||
}
|
}
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::toLog (params);
|
Util::toLog (params);
|
||||||
}
|
}
|
||||||
facNodes_[i]->factor()->setParams (params);
|
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||||
}
|
}
|
||||||
is.close();
|
is.close();
|
||||||
setIndexes();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -158,87 +111,58 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
|||||||
void
|
void
|
||||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||||
{
|
{
|
||||||
ifstream is (fileName);
|
std::ifstream is (fileName);
|
||||||
if (!is.is_open()) {
|
if (!is.is_open()) {
|
||||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
cerr << "error: cannot read from file " << fileName << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
ignoreLines (is);
|
||||||
string line;
|
unsigned nrFactors;
|
||||||
unsigned nFactors;
|
unsigned nrArgs;
|
||||||
|
VarId vid;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
is >> nrFactors;
|
||||||
is >> nFactors;
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
|
ignoreLines (is);
|
||||||
if (is.fail()) {
|
// read the factor arguments
|
||||||
cerr << "error: cannot read the number of factors" << endl;
|
is >> nrArgs;
|
||||||
abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
getline (is, line);
|
|
||||||
if (is.fail() || line.size() > 0) {
|
|
||||||
cerr << "error: cannot read the number of factors" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nFactors; i++) {
|
|
||||||
unsigned nVars;
|
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
|
|
||||||
is >> nVars;
|
|
||||||
VarIds vids;
|
VarIds vids;
|
||||||
for (unsigned j = 0; j < nVars; j++) {
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
VarId vid;
|
ignoreLines (is);
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> vid;
|
is >> vid;
|
||||||
vids.push_back (vid);
|
vids.push_back (vid);
|
||||||
}
|
}
|
||||||
|
// read ranges
|
||||||
VarNodes neighs;
|
Ranges ranges (nrArgs);
|
||||||
unsigned nParams = 1;
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
for (unsigned j = 0; j < nVars; j++) {
|
ignoreLines (is);
|
||||||
unsigned dsize;
|
is >> ranges[j];
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
VarNode* var = getVarNode (vids[j]);
|
||||||
is >> dsize;
|
if (var != 0 && ranges[j] != var->range()) {
|
||||||
FgVarNode* var = getFgVarNode (vids[j]);
|
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||||
if (var == 0) {
|
cerr << "more factors with a different range" << endl;
|
||||||
var = new FgVarNode (vids[j], dsize);
|
|
||||||
addVariable (var);
|
|
||||||
} else {
|
|
||||||
if (var->nrStates() != dsize) {
|
|
||||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
|
||||||
cerr << "more factors with different domain sizes" << endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
neighs.push_back (var);
|
|
||||||
nParams *= var->nrStates();
|
|
||||||
}
|
}
|
||||||
Params params (nParams, 0);
|
// read parameters
|
||||||
|
ignoreLines (is);
|
||||||
unsigned nNonzeros;
|
unsigned nNonzeros;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> nNonzeros;
|
is >> nNonzeros;
|
||||||
|
Params params (Util::expectedSize (ranges), 0);
|
||||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||||
|
ignoreLines (is);
|
||||||
unsigned index;
|
unsigned index;
|
||||||
double val;
|
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> index;
|
is >> index;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
ignoreLines (is);
|
||||||
|
double val;
|
||||||
is >> val;
|
is >> val;
|
||||||
params[index] = val;
|
params[index] = val;
|
||||||
}
|
}
|
||||||
reverse (neighs.begin(), neighs.end());
|
reverse (vids.begin(), vids.end());
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::toLog (params);
|
Util::toLog (params);
|
||||||
}
|
}
|
||||||
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
|
addFactor (Factor (vids, ranges, params));
|
||||||
addFactor (fn);
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
is.close();
|
is.close();
|
||||||
setIndexes();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -256,17 +180,41 @@ FactorGraph::~FactorGraph (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addVariable (FgVarNode* vn)
|
FactorGraph::addFactor (const Factor& factor)
|
||||||
{
|
{
|
||||||
varNodes_.push_back (vn);
|
FacNode* fn = new FacNode (factor);
|
||||||
vn->setIndex (varNodes_.size() - 1);
|
addFacNode (fn);
|
||||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
const VarIds& vids = factor.arguments();
|
||||||
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
|
bool found = false;
|
||||||
|
for (unsigned j = 0; j < varNodes_.size(); j++) {
|
||||||
|
if (varNodes_[j]->varId() == vids[i]) {
|
||||||
|
addEdge (varNodes_[j], fn);
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (found == false) {
|
||||||
|
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
||||||
|
addVarNode (vn);
|
||||||
|
addEdge (vn, fn);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addFactor (FgFacNode* fn)
|
FactorGraph::addVarNode (VarNode* vn)
|
||||||
|
{
|
||||||
|
varNodes_.push_back (vn);
|
||||||
|
vn->setIndex (varNodes_.size() - 1);
|
||||||
|
varMap_.insert (make_pair (vn->varId(), vn));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addFacNode (FacNode* fn)
|
||||||
{
|
{
|
||||||
facNodes_.push_back (fn);
|
facNodes_.push_back (fn);
|
||||||
fn->setIndex (facNodes_.size() - 1);
|
fn->setIndex (facNodes_.size() - 1);
|
||||||
@ -275,7 +223,7 @@ FactorGraph::addFactor (FgFacNode* fn)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
FactorGraph::addEdge (VarNode* vn, FacNode* fn)
|
||||||
{
|
{
|
||||||
vn->addNeighbor (fn);
|
vn->addNeighbor (fn);
|
||||||
fn->addNeighbor (vn);
|
fn->addNeighbor (vn);
|
||||||
@ -283,37 +231,6 @@ FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
|
|
||||||
{
|
|
||||||
fn->addNeighbor (vn);
|
|
||||||
vn->addNeighbor (fn);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNode*
|
|
||||||
FactorGraph::getVariableNode (VarId vid) const
|
|
||||||
{
|
|
||||||
FgVarNode* vn = getFgVarNode (vid);
|
|
||||||
assert (vn);
|
|
||||||
return vn;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNodes
|
|
||||||
FactorGraph::getVariableNodes (void) const
|
|
||||||
{
|
|
||||||
VarNodes vars;
|
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
|
||||||
vars.push_back (varNodes_[i]);
|
|
||||||
}
|
|
||||||
return vars;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::isTree (void) const
|
FactorGraph::isTree (void) const
|
||||||
{
|
{
|
||||||
@ -322,36 +239,42 @@ FactorGraph::isTree (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
DAGraph&
|
||||||
FactorGraph::setIndexes (void)
|
FactorGraph::getStructure (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
assert (fromBayesNet_);
|
||||||
varNodes_[i]->setIndex (i);
|
if (structure_.empty()) {
|
||||||
}
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||||
facNodes_[i]->setIndex (i);
|
}
|
||||||
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
|
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||||
|
for (unsigned j = 1; j < vids.size(); j++) {
|
||||||
|
structure_.addEdge (vids[j], vids[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return structure_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::printGraphicalModel (void) const
|
FactorGraph::print (void) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
cout << "VarId = " << varNodes_[i]->varId() << endl;
|
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||||
cout << "Label = " << varNodes_[i]->label() << endl;
|
cout << "label = " << varNodes_[i]->label() << endl;
|
||||||
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
|
cout << "range = " << varNodes_[i]->range() << endl;
|
||||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||||
cout << "Factors = " ;
|
cout << "factors = " ;
|
||||||
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
||||||
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
||||||
}
|
}
|
||||||
cout << endl << endl;
|
cout << endl << endl;
|
||||||
}
|
}
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
facNodes_[i]->factor()->print();
|
facNodes_[i]->factor().print();
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -366,31 +289,26 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
|||||||
cerr << "FactorGraph::exportToDotFile()" << endl;
|
cerr << "FactorGraph::exportToDotFile()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "graph \"" << fileName << "\" {" << endl;
|
out << "graph \"" << fileName << "\" {" << endl;
|
||||||
|
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
if (varNodes_[i]->hasEvidence()) {
|
if (varNodes_[i]->hasEvidence()) {
|
||||||
out << '"' << varNodes_[i]->label() << '"' ;
|
out << '"' << varNodes_[i]->label() << '"' ;
|
||||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||||
out << " [label=\"" << facNodes_[i]->getLabel();
|
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||||
out << "\"" << ", shape=box]" << endl;
|
out << "\"" << ", shape=box]" << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const FgVarSet& myVars = facNodes_[i]->neighbors();
|
const VarNodes& myVars = facNodes_[i]->neighbors();
|
||||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||||
out << " -- " ;
|
out << " -- " ;
|
||||||
out << '"' << myVars[j]->label() << '"' << endl;
|
out << '"' << myVars[j]->label() << '"' << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "}" << endl;
|
out << "}" << endl;
|
||||||
out.close();
|
out.close();
|
||||||
}
|
}
|
||||||
@ -402,30 +320,26 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
|||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file " << fileName << endl;
|
||||||
cerr << "FactorGraph::exportToUaiFormat()" << endl;
|
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "MARKOV" << endl;
|
out << "MARKOV" << endl;
|
||||||
out << varNodes_.size() << endl;
|
out << varNodes_.size() << endl;
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
out << varNodes_[i]->nrStates() << " " ;
|
out << varNodes_[i]->range() << " " ;
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
|
|
||||||
out << facNodes_.size() << endl;
|
out << facNodes_.size() << endl;
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||||
out << factorVars.size();
|
out << factorVars.size();
|
||||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||||
out << " " << factorVars[j]->getIndex();
|
out << " " << factorVars[j]->getIndex();
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
Params params = facNodes_[i]->params();
|
Params params = facNodes_[i]->factor().params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
@ -435,7 +349,6 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
|||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
out.close();
|
out.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -446,23 +359,22 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
|||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file " << fileName << endl;
|
||||||
cerr << "FactorGraph::exportToLibDaiFormat()" << endl;
|
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
out << facNodes_.size() << endl << endl;
|
out << facNodes_.size() << endl << endl;
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||||
out << factorVars.size() << endl;
|
out << factorVars.size() << endl;
|
||||||
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
||||||
out << factorVars[j]->varId() << " " ;
|
out << factorVars[j]->varId() << " " ;
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||||
out << factorVars[j]->nrStates() << " " ;
|
out << factorVars[j]->range() << " " ;
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
Params params = facNodes_[i]->factor()->params();
|
Params params = facNodes_[i]->factor().params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
@ -477,6 +389,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||||
|
{
|
||||||
|
string ignoreStr;
|
||||||
|
while (is.peek() == '#' || is.peek() == '\n') {
|
||||||
|
getline (is, ignoreStr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::containsCycle (void) const
|
FactorGraph::containsCycle (void) const
|
||||||
{
|
{
|
||||||
@ -496,13 +419,14 @@ FactorGraph::containsCycle (void) const
|
|||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::containsCycle (const FgVarNode* v,
|
FactorGraph::containsCycle (
|
||||||
const FgFacNode* p,
|
const VarNode* v,
|
||||||
vector<bool>& visitedVars,
|
const FacNode* p,
|
||||||
vector<bool>& visitedFactors) const
|
vector<bool>& visitedVars,
|
||||||
|
vector<bool>& visitedFactors) const
|
||||||
{
|
{
|
||||||
visitedVars[v->getIndex()] = true;
|
visitedVars[v->getIndex()] = true;
|
||||||
const FgFacSet& adjacencies = v->neighbors();
|
const FacNodes& adjacencies = v->neighbors();
|
||||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||||
int w = adjacencies[i]->getIndex();
|
int w = adjacencies[i]->getIndex();
|
||||||
if (!visitedFactors[w]) {
|
if (!visitedFactors[w]) {
|
||||||
@ -520,13 +444,14 @@ FactorGraph::containsCycle (const FgVarNode* v,
|
|||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::containsCycle (const FgFacNode* v,
|
FactorGraph::containsCycle (
|
||||||
const FgVarNode* p,
|
const FacNode* v,
|
||||||
vector<bool>& visitedVars,
|
const VarNode* p,
|
||||||
vector<bool>& visitedFactors) const
|
vector<bool>& visitedVars,
|
||||||
|
vector<bool>& visitedFactors) const
|
||||||
{
|
{
|
||||||
visitedFactors[v->getIndex()] = true;
|
visitedFactors[v->getIndex()] = true;
|
||||||
const FgVarSet& adjacencies = v->neighbors();
|
const VarNodes& adjacencies = v->neighbors();
|
||||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||||
int w = adjacencies[i]->getIndex();
|
int w = adjacencies[i]->getIndex();
|
||||||
if (!visitedVars[w]) {
|
if (!visitedVars[w]) {
|
||||||
|
@ -3,136 +3,109 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
|
#include "BayesNet.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class BayesNet;
|
|
||||||
class FgFacNode;
|
|
||||||
|
|
||||||
|
class FacNode;
|
||||||
|
|
||||||
class FgVarNode : public VarNode
|
class VarNode : public Var
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
|
VarNode (VarId varId, unsigned nrStates)
|
||||||
|
: Var (varId, nrStates) { }
|
||||||
|
|
||||||
FgVarNode (const VarNode* v) : VarNode (v) { }
|
VarNode (const Var* v) : Var (v) { }
|
||||||
|
|
||||||
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
|
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
|
||||||
|
|
||||||
const FgFacSet& neighbors (void) const { return neighs_; }
|
const FacNodes& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
DISALLOW_COPY_AND_ASSIGN (VarNode);
|
||||||
|
|
||||||
FgFacSet neighs_;
|
FacNodes neighs_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FgFacNode
|
class FacNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgFacNode (const FgFacNode* fn)
|
FacNode (const Factor& f) : factor_(f), index_(-1) { }
|
||||||
{
|
|
||||||
factor_ = new Factor (*fn->factor());
|
|
||||||
index_ = -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
const Factor& factor (void) const { return factor_; }
|
||||||
|
|
||||||
Factor* factor() const { return factor_; }
|
Factor& factor (void) { return factor_; }
|
||||||
|
|
||||||
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||||
|
|
||||||
const FgVarSet& neighbors (void) const { return neighs_; }
|
const VarNodes& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
int getIndex (void) const
|
int getIndex (void) const { return index_; }
|
||||||
{
|
|
||||||
assert (index_ != -1);
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void setIndex (int index)
|
void setIndex (int index) { index_ = index; }
|
||||||
{
|
|
||||||
index_ = index;
|
|
||||||
}
|
|
||||||
|
|
||||||
const Params& params (void) const
|
string getLabel (void) { return factor_.getLabel(); }
|
||||||
{
|
|
||||||
return factor_->params();
|
|
||||||
}
|
|
||||||
|
|
||||||
string getLabel (void)
|
|
||||||
{
|
|
||||||
return factor_->getLabel();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
|
DISALLOW_COPY_AND_ASSIGN (FacNode);
|
||||||
|
|
||||||
Factor* factor_;
|
VarNodes neighs_;
|
||||||
FgVarSet neighs_;
|
Factor factor_;
|
||||||
int index_;
|
int index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
struct CompVarId
|
struct CompVarId
|
||||||
{
|
{
|
||||||
bool operator() (const VarNode* vn1, const VarNode* vn2) const
|
bool operator() (const Var* v1, const Var* v2) const
|
||||||
{
|
{
|
||||||
return vn1->varId() < vn2->varId();
|
return v1->varId() < v2->varId();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FactorGraph : public GraphicalModel
|
class FactorGraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FactorGraph (void) { };
|
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
|
||||||
|
|
||||||
FactorGraph (const FactorGraph&);
|
FactorGraph (const FactorGraph&);
|
||||||
|
|
||||||
FactorGraph (const BayesNet&);
|
|
||||||
|
|
||||||
~FactorGraph (void);
|
~FactorGraph (void);
|
||||||
|
|
||||||
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||||
|
|
||||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
const FacNodes& facNodes (void) const { return facNodes_; }
|
||||||
|
|
||||||
FgVarNode* getFgVarNode (VarId vid) const
|
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
||||||
|
|
||||||
|
VarNode* getVarNode (VarId vid) const
|
||||||
{
|
{
|
||||||
IndexMap::const_iterator it = varMap_.find (vid);
|
VarMap::const_iterator it = varMap_.find (vid);
|
||||||
if (it == varMap_.end()) {
|
return it != varMap_.end() ? it->second : 0;
|
||||||
return 0;
|
|
||||||
} else {
|
|
||||||
return varNodes_[it->second];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void readFromUaiFormat (const char*);
|
void readFromUaiFormat (const char*);
|
||||||
|
|
||||||
void readFromLibDaiFormat (const char*);
|
void readFromLibDaiFormat (const char*);
|
||||||
|
|
||||||
void addVariable (FgVarNode*);
|
void addFactor (const Factor& factor);
|
||||||
|
|
||||||
void addFactor (FgFacNode*);
|
void addVarNode (VarNode*);
|
||||||
|
|
||||||
void addEdge (FgVarNode*, FgFacNode*);
|
void addFacNode (FacNode*);
|
||||||
|
|
||||||
void addEdge (FgFacNode*, FgVarNode*);
|
void addEdge (VarNode*, FacNode*);
|
||||||
|
|
||||||
VarNode* getVariableNode (unsigned) const;
|
|
||||||
|
|
||||||
VarNodes getVariableNodes (void) const;
|
|
||||||
|
|
||||||
bool isTree (void) const;
|
bool isTree (void) const;
|
||||||
|
|
||||||
void setIndexes (void);
|
DAGraph& getStructure (void);
|
||||||
|
|
||||||
void printGraphicalModel (void) const;
|
void print (void) const;
|
||||||
|
|
||||||
void exportToGraphViz (const char*) const;
|
void exportToGraphViz (const char*) const;
|
||||||
|
|
||||||
@ -145,19 +118,24 @@ class FactorGraph : public GraphicalModel
|
|||||||
private:
|
private:
|
||||||
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||||
|
|
||||||
|
void ignoreLines (std::ifstream&) const;
|
||||||
|
|
||||||
bool containsCycle (void) const;
|
bool containsCycle (void) const;
|
||||||
|
|
||||||
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
bool containsCycle (const VarNode*, const FacNode*,
|
||||||
vector<bool>&, vector<bool>&) const;
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
bool containsCycle (const FacNode*, const VarNode*,
|
||||||
vector<bool>&, vector<bool>&) const;
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
FgVarSet varNodes_;
|
VarNodes varNodes_;
|
||||||
FgFacSet facNodes_;
|
FacNodes facNodes_;
|
||||||
|
|
||||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
DAGraph structure_;
|
||||||
IndexMap varMap_;
|
bool fromBayesNet_;
|
||||||
|
|
||||||
|
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||||
|
VarMap varMap_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_FACTORGRAPH_H
|
#endif // HORUS_FACTORGRAPH_H
|
||||||
|
@ -455,7 +455,7 @@ FoveSolver::absorveEvidence (
|
|||||||
}
|
}
|
||||||
pfList.add (newPfs);
|
pfList.add (newPfs);
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG > 1 && obsFormulas.empty() == false) {
|
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
|
||||||
Util::printAsteriskLine();
|
Util::printAsteriskLine();
|
||||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||||
@ -493,7 +493,7 @@ FoveSolver::runSolver (const Grounds& query)
|
|||||||
shatterAgainstQuery (query);
|
shatterAgainstQuery (query);
|
||||||
runWeakBayesBall (query);
|
runWeakBayesBall (query);
|
||||||
while (true) {
|
while (true) {
|
||||||
if (Constants::DEBUG > 1) {
|
if (Constants::DEBUG >= 2) {
|
||||||
Util::printDashedLine();
|
Util::printDashedLine();
|
||||||
pfList_.print();
|
pfList_.print();
|
||||||
LiftedOperator::printValidOps (pfList_, query);
|
LiftedOperator::printValidOps (pfList_, query);
|
||||||
@ -502,7 +502,7 @@ FoveSolver::runSolver (const Grounds& query)
|
|||||||
if (op == 0) {
|
if (op == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG > 1) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "best operation: " << op->toString() << endl;
|
cout << "best operation: " << op->toString() << endl;
|
||||||
}
|
}
|
||||||
op->apply();
|
op->apply();
|
||||||
@ -594,7 +594,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Constants::DEBUG > 1) {
|
if (Constants::DEBUG >= 2) {
|
||||||
Util::printHeader ("REQUIRED PARFACTORS");
|
Util::printHeader ("REQUIRED PARFACTORS");
|
||||||
pfList_.print();
|
pfList_.print();
|
||||||
}
|
}
|
||||||
@ -605,15 +605,16 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
|
|||||||
void
|
void
|
||||||
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||||
{
|
{
|
||||||
return ;
|
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
if (query[i].isAtom()) {
|
if (query[i].isAtom()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
bool found = false;
|
||||||
Parfactors newPfs;
|
Parfactors newPfs;
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
while (it != pfList_.end()) {
|
while (it != pfList_.end()) {
|
||||||
if ((*it)->containsGround (query[i])) {
|
if ((*it)->containsGround (query[i])) {
|
||||||
|
found = true;
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split =
|
std::pair<ConstraintTree*, ConstraintTree*> split =
|
||||||
(*it)->constr()->split (query[i].args(), query[i].arity());
|
(*it)->constr()->split (query[i].args(), query[i].arity());
|
||||||
ConstraintTree* commCt = split.first;
|
ConstraintTree* commCt = split.first;
|
||||||
@ -629,9 +630,14 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
|
|||||||
++ it;
|
++ it;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (found == false) {
|
||||||
|
cerr << "error: could not find a parfactor with ground " ;
|
||||||
|
cerr << "`" << query[i] << "'" << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
pfList_.add (newPfs);
|
pfList_.add (newPfs);
|
||||||
}
|
}
|
||||||
if (Constants::DEBUG > 1) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
Util::printAsteriskLine();
|
Util::printAsteriskLine();
|
||||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||||
|
@ -1,64 +0,0 @@
|
|||||||
#ifndef HORUS_GRAPHICALMODEL_H
|
|
||||||
#define HORUS_GRAPHICALMODEL_H
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "VarNode.h"
|
|
||||||
#include "Util.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
struct VarInfo
|
|
||||||
{
|
|
||||||
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
|
||||||
string label;
|
|
||||||
States states;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class GraphicalModel
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual ~GraphicalModel (void) { };
|
|
||||||
|
|
||||||
virtual VarNode* getVariableNode (VarId) const = 0;
|
|
||||||
|
|
||||||
virtual VarNodes getVariableNodes (void) const = 0;
|
|
||||||
|
|
||||||
virtual void printGraphicalModel (void) const = 0;
|
|
||||||
|
|
||||||
static void addVariableInformation (
|
|
||||||
VarId vid, string label, const States& states)
|
|
||||||
{
|
|
||||||
assert (Util::contains (varsInfo_, vid) == false);
|
|
||||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
|
||||||
}
|
|
||||||
|
|
||||||
static VarInfo getVarInformation (VarId vid)
|
|
||||||
{
|
|
||||||
assert (Util::contains (varsInfo_, vid));
|
|
||||||
return varsInfo_.find (vid)->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool variablesHaveInformation (void)
|
|
||||||
{
|
|
||||||
return varsInfo_.size() != 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void clearVariablesInformation (void)
|
|
||||||
{
|
|
||||||
varsInfo_.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
static unordered_map<VarId,VarInfo> varsInfo_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_GRAPHICALMODEL_H
|
|
||||||
|
|
@ -11,30 +11,27 @@
|
|||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class VarNode;
|
class Var;
|
||||||
class BayesNode;
|
|
||||||
class FgVarNode;
|
|
||||||
class FgFacNode;
|
|
||||||
class Factor;
|
class Factor;
|
||||||
|
class VarNode;
|
||||||
|
class FacNode;
|
||||||
|
|
||||||
typedef vector<double> Params;
|
typedef vector<double> Params;
|
||||||
typedef unsigned VarId;
|
typedef unsigned VarId;
|
||||||
typedef vector<VarId> VarIds;
|
typedef vector<VarId> VarIds;
|
||||||
typedef vector<VarNode*> VarNodes;
|
typedef vector<Var*> Vars;
|
||||||
typedef vector<BayesNode*> BnNodeSet;
|
typedef vector<VarNode*> VarNodes;
|
||||||
typedef vector<FgVarNode*> FgVarSet;
|
typedef vector<FacNode*> FacNodes;
|
||||||
typedef vector<FgFacNode*> FgFacSet;
|
typedef vector<Factor*> Factors;
|
||||||
typedef vector<Factor*> FactorSet;
|
typedef vector<string> States;
|
||||||
typedef vector<string> States;
|
typedef vector<unsigned> Ranges;
|
||||||
typedef vector<unsigned> Ranges;
|
|
||||||
|
|
||||||
|
|
||||||
enum InfAlgorithms
|
enum InfAlgorithms
|
||||||
{
|
{
|
||||||
VE, // variable elimination
|
VE, // variable elimination
|
||||||
BN_BP, // bayesian network belief propagation
|
BP, // belief propagation
|
||||||
FG_BP, // factor graph belief propagation
|
CBP // counting belief propagation
|
||||||
CBP // counting bp solver
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -50,7 +47,7 @@ extern InfAlgorithms infAlgorithm;
|
|||||||
namespace Constants {
|
namespace Constants {
|
||||||
|
|
||||||
// level of debug information
|
// level of debug information
|
||||||
const unsigned DEBUG = 2;
|
const unsigned DEBUG = 0;
|
||||||
|
|
||||||
const int NO_EVIDENCE = -1;
|
const int NO_EVIDENCE = -1;
|
||||||
|
|
||||||
|
@ -3,137 +3,70 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
#include "BnBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "FgBpSolver.h"
|
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
void processArguments (BayesNet&, int, const char* []);
|
|
||||||
void processArguments (FactorGraph&, int, const char* []);
|
void processArguments (FactorGraph&, int, const char* []);
|
||||||
void runSolver (Solver*, const VarNodes&);
|
void runSolver (const FactorGraph&, const VarIds&);
|
||||||
|
|
||||||
const string USAGE = "usage: \
|
const string USAGE = "usage: \
|
||||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
main (int argc, const char* argv[])
|
main (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
if (!argv[1]) {
|
if (argc <= 1) {
|
||||||
|
cerr << "error: no solver specified" << endl;
|
||||||
cerr << "error: no graphical model specified" << endl;
|
cerr << "error: no graphical model specified" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
const string& fileName = argv[1];
|
if (argc <= 2) {
|
||||||
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
cerr << "error: no graphical model specified" << endl;
|
||||||
if (extension == "xml") {
|
cerr << USAGE << endl;
|
||||||
BayesNet bn;
|
|
||||||
bn.readFromBifFormat (argv[1]);
|
|
||||||
processArguments (bn, argc, argv);
|
|
||||||
} else if (extension == "uai") {
|
|
||||||
FactorGraph fg;
|
|
||||||
fg.readFromUaiFormat (argv[1]);
|
|
||||||
processArguments (fg, argc, argv);
|
|
||||||
} else if (extension == "fg") {
|
|
||||||
FactorGraph fg;
|
|
||||||
fg.readFromLibDaiFormat (argv[1]);
|
|
||||||
processArguments (fg, argc, argv);
|
|
||||||
} else {
|
|
||||||
cerr << "error: the graphical model must be defined either " ;
|
|
||||||
cerr << "in a xml, uai or libDAI file" << endl;
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
|
string solver (argv[1]);
|
||||||
|
if (solver == "ve") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||||
|
} else if (solver == "bp") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||||
|
} else if (solver == "cbp") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||||
|
} else {
|
||||||
|
cerr << "error: unknow solver `" << solver << "'" << endl ;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
string fileName (argv[2]);
|
||||||
|
string extension = fileName.substr (
|
||||||
|
fileName.find_last_of ('.') + 1);
|
||||||
|
FactorGraph fg;
|
||||||
|
if (extension == "uai") {
|
||||||
|
fg.readFromUaiFormat (fileName.c_str());
|
||||||
|
} else if (extension == "fg") {
|
||||||
|
fg.readFromLibDaiFormat (fileName.c_str());
|
||||||
|
} else {
|
||||||
|
cerr << "error: the graphical model must be defined either " ;
|
||||||
|
cerr << "in a UAI or libDAI file" << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
processArguments (fg, argc, argv);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
processArguments (BayesNet& bn, int argc, const char* argv[])
|
|
||||||
{
|
|
||||||
VarNodes queryVars;
|
|
||||||
for (int i = 2; i < argc; i++) {
|
|
||||||
const string& arg = argv[i];
|
|
||||||
if (arg.find ('=') == std::string::npos) {
|
|
||||||
BayesNode* queryVar = bn.getBayesNode (arg);
|
|
||||||
if (queryVar) {
|
|
||||||
queryVars.push_back (queryVar);
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable labeled of " ;
|
|
||||||
cerr << "`" << arg << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
size_t pos = arg.find ('=');
|
|
||||||
const string& label = arg.substr (0, pos);
|
|
||||||
const string& state = arg.substr (pos + 1);
|
|
||||||
if (label.empty()) {
|
|
||||||
cerr << "error: missing left argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
if (state.empty()) {
|
|
||||||
cerr << "error: missing right argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
BayesNode* node = bn.getBayesNode (label);
|
|
||||||
if (node) {
|
|
||||||
if (node->isValidState (state)) {
|
|
||||||
node->setEvidence (state);
|
|
||||||
} else {
|
|
||||||
cerr << "error: `" << state << "' " ;
|
|
||||||
cerr << "is not a valid state for " ;
|
|
||||||
cerr << "`" << node->label() << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable labeled of " ;
|
|
||||||
cerr << "`" << label << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Solver* solver = 0;
|
|
||||||
FactorGraph* fg = 0;
|
|
||||||
switch (Globals::infAlgorithm) {
|
|
||||||
case InfAlgorithms::VE:
|
|
||||||
fg = new FactorGraph (bn);
|
|
||||||
solver = new VarElimSolver (*fg);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::BN_BP:
|
|
||||||
solver = new BnBpSolver (bn);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::FG_BP:
|
|
||||||
fg = new FactorGraph (bn);
|
|
||||||
solver = new FgBpSolver (*fg);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::CBP:
|
|
||||||
fg = new FactorGraph (bn);
|
|
||||||
solver = new CbpSolver (*fg);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
assert (false);
|
|
||||||
}
|
|
||||||
runSolver (solver, queryVars);
|
|
||||||
delete fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
VarNodes queryVars;
|
VarIds queryIds;
|
||||||
for (int i = 2; i < argc; i++) {
|
for (int i = 3; i < argc; i++) {
|
||||||
const string& arg = argv[i];
|
const string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
if (!Util::isInteger (arg)) {
|
if (!Util::isInteger (arg)) {
|
||||||
@ -146,9 +79,9 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg;
|
ss << arg;
|
||||||
ss >> vid;
|
ss >> vid;
|
||||||
VarNode* queryVar = fg.getFgVarNode (vid);
|
VarNode* queryVar = fg.getVarNode (vid);
|
||||||
if (queryVar) {
|
if (queryVar) {
|
||||||
queryVars.push_back (queryVar);
|
queryIds.push_back (vid);
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << vid << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
@ -177,7 +110,7 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg.substr (0, pos);
|
ss << arg.substr (0, pos);
|
||||||
ss >> vid;
|
ss >> vid;
|
||||||
VarNode* var = fg.getFgVarNode (vid);
|
VarNode* var = fg.getVarNode (vid);
|
||||||
if (var) {
|
if (var) {
|
||||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||||
@ -206,14 +139,21 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
runSolver (fg, queryIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||||
|
{
|
||||||
Solver* solver = 0;
|
Solver* solver = 0;
|
||||||
switch (Globals::infAlgorithm) {
|
switch (Globals::infAlgorithm) {
|
||||||
case InfAlgorithms::VE:
|
case InfAlgorithms::VE:
|
||||||
solver = new VarElimSolver (fg);
|
solver = new VarElimSolver (fg);
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::BN_BP:
|
case InfAlgorithms::BP:
|
||||||
case InfAlgorithms::FG_BP:
|
solver = new BpSolver (fg);
|
||||||
solver = new FgBpSolver (fg);
|
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::CBP:
|
case InfAlgorithms::CBP:
|
||||||
solver = new CbpSolver (fg);
|
solver = new CbpSolver (fg);
|
||||||
@ -221,27 +161,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
default:
|
default:
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
runSolver (solver, queryVars);
|
if (queryIds.size() == 0) {
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
runSolver (Solver* solver, const VarNodes& queryVars)
|
|
||||||
{
|
|
||||||
VarIds vids;
|
|
||||||
for (unsigned i = 0; i < queryVars.size(); i++) {
|
|
||||||
vids.push_back (queryVars[i]->varId());
|
|
||||||
}
|
|
||||||
if (queryVars.size() == 0) {
|
|
||||||
solver->runSolver();
|
|
||||||
solver->printAllPosterioris();
|
solver->printAllPosterioris();
|
||||||
} else if (queryVars.size() == 1) {
|
|
||||||
solver->runSolver();
|
|
||||||
solver->printPosterioriOf (vids[0]);
|
|
||||||
} else {
|
} else {
|
||||||
solver->runSolver();
|
solver->printAnswer (queryIds);
|
||||||
solver->printJointDistributionOf (vids);
|
|
||||||
}
|
}
|
||||||
delete solver;
|
delete solver;
|
||||||
}
|
}
|
||||||
|
@ -8,14 +8,13 @@
|
|||||||
#include <YapInterface.h>
|
#include <YapInterface.h>
|
||||||
|
|
||||||
#include "ParfactorList.h"
|
#include "ParfactorList.h"
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "FoveSolver.h"
|
#include "FoveSolver.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
#include "BnBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "FgBpSolver.h"
|
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
|
#include "BayesBall.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -24,10 +23,35 @@ using namespace std;
|
|||||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||||
|
|
||||||
|
|
||||||
Params readParams (YAP_Term);
|
Params readParameters (YAP_Term);
|
||||||
|
|
||||||
|
vector<unsigned> readUnsignedList (YAP_Term);
|
||||||
|
|
||||||
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||||
|
|
||||||
Parfactor* readParfactor (YAP_Term);
|
Parfactor* readParfactor (YAP_Term);
|
||||||
|
|
||||||
|
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||||
|
vector<Params>& results);
|
||||||
|
|
||||||
|
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||||
|
vector<Params>& results);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<unsigned>
|
||||||
|
readUnsignedList (YAP_Term list)
|
||||||
|
{
|
||||||
|
vector<unsigned> vec;
|
||||||
|
while (list != YAP_TermNil()) {
|
||||||
|
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
|
||||||
|
list = YAP_TailOfTerm (list);
|
||||||
|
}
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int createLiftedNetwork (void)
|
int createLiftedNetwork (void)
|
||||||
{
|
{
|
||||||
@ -40,20 +64,17 @@ int createLiftedNetwork (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LiftedUtils::printSymbolDictionary();
|
// LiftedUtils::printSymbolDictionary();
|
||||||
if (Constants::DEBUG > 1) {
|
if (Constants::DEBUG > 2) {
|
||||||
// Util::printHeader ("INITIAL PARFACTORS");
|
// Util::printHeader ("INITIAL PARFACTORS");
|
||||||
// for (unsigned i = 0; i < parfactors.size(); i++) {
|
// for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||||
// parfactors[i]->print();
|
// parfactors[i]->print();
|
||||||
// cout << endl;
|
|
||||||
// }
|
// }
|
||||||
// parfactors[0]->countConvert (LogVar (0));
|
|
||||||
//parfactors[1]->fullExpand (LogVar (1));
|
|
||||||
Util::printHeader ("SHATTERED PARFACTORS");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ParfactorList* pfList = new ParfactorList (parfactors);
|
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||||
|
|
||||||
if (Constants::DEBUG > 1) {
|
if (Constants::DEBUG >= 2) {
|
||||||
|
Util::printHeader ("SHATTERED PARFACTORS");
|
||||||
pfList->print();
|
pfList->print();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,7 +138,7 @@ Parfactor* readParfactor (YAP_Term pfTerm)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// read the parameters
|
// read the parameters
|
||||||
const Params& params = readParams (YAP_ArgOfTerm (4, pfTerm));
|
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||||
|
|
||||||
// read the constraint
|
// read the constraint
|
||||||
Tuples tuples;
|
Tuples tuples;
|
||||||
@ -195,55 +216,46 @@ void readLiftedEvidence (
|
|||||||
int
|
int
|
||||||
createGroundNetwork (void)
|
createGroundNetwork (void)
|
||||||
{
|
{
|
||||||
Statistics::incrementPrimaryNetworksCounting();
|
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||||
// cout << "creating network number " ;
|
bool fromBayesNet = factorsType == "bayes";
|
||||||
// cout << Statistics::getPrimaryNetworksCounting() << endl;
|
FactorGraph* fg = new FactorGraph (fromBayesNet);
|
||||||
// if (Statistics::getPrimaryNetworksCounting() > 98) {
|
YAP_Term factorList = YAP_ARG2;
|
||||||
// Statistics::writeStatisticsToFile ("../../compressing.stats");
|
while (factorList != YAP_TermNil()) {
|
||||||
// }
|
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||||
BayesNet* bn = new BayesNet();
|
// read the var ids
|
||||||
YAP_Term varList = YAP_ARG1;
|
VarIds varIds = readUnsignedList (YAP_ArgOfTerm (1, factor));
|
||||||
vector<VarIds> parents;
|
// read the ranges
|
||||||
while (varList != YAP_TermNil()) {
|
Ranges ranges = readUnsignedList (YAP_ArgOfTerm (2, factor));
|
||||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
// read the parameters
|
||||||
VarId vid = (VarId) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
Params params = readParameters (YAP_ArgOfTerm (3, factor));
|
||||||
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
// read dist id
|
||||||
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (4, factor));
|
||||||
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
fg->addFactor (Factor (varIds, ranges, params, distId));
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
factorList = YAP_TailOfTerm (factorList);
|
||||||
parents.push_back (VarIds());
|
|
||||||
while (parentL != YAP_TermNil()) {
|
|
||||||
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
|
||||||
parents.back().push_back (parentId);
|
|
||||||
parentL = YAP_TailOfTerm (parentL);
|
|
||||||
}
|
|
||||||
assert (bn->getBayesNode (vid) == 0);
|
|
||||||
BayesNode* newNode = new BayesNode (
|
|
||||||
vid, dsize, evidence, Params(), distId);
|
|
||||||
bn->addNode (newNode);
|
|
||||||
varList = YAP_TailOfTerm (varList);
|
|
||||||
}
|
}
|
||||||
const BnNodeSet& nodes = bn->getBayesNodes();
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
YAP_Term evidenceList = YAP_ARG3;
|
||||||
BnNodeSet ps;
|
while (evidenceList != YAP_TermNil()) {
|
||||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
|
||||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||||
}
|
assert (fg->getVarNode (vid));
|
||||||
nodes[i]->setParents (ps);
|
fg->getVarNode (vid)->setEvidence (ev);
|
||||||
|
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||||
}
|
}
|
||||||
bn->setIndexes();
|
|
||||||
YAP_Int p = (YAP_Int) (bn);
|
YAP_Int p = (YAP_Int) (fg);
|
||||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
readParams (YAP_Term paramL)
|
readParameters (YAP_Term paramL)
|
||||||
{
|
{
|
||||||
Params params;
|
Params params;
|
||||||
while (paramL!= YAP_TermNil()) {
|
assert (YAP_IsPairTerm (paramL));
|
||||||
|
while (paramL != YAP_TermNil()) {
|
||||||
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||||
paramL = YAP_TailOfTerm (paramL);
|
paramL = YAP_TailOfTerm (paramL);
|
||||||
}
|
}
|
||||||
@ -319,68 +331,21 @@ runLiftedSolver (void)
|
|||||||
int
|
int
|
||||||
runGroundSolver (void)
|
runGroundSolver (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term taskList = YAP_ARG2;
|
|
||||||
vector<VarIds> tasks;
|
vector<VarIds> tasks;
|
||||||
std::set<VarId> vids;
|
YAP_Term taskList = YAP_ARG2;
|
||||||
while (taskList != YAP_TermNil()) {
|
while (taskList != YAP_TermNil()) {
|
||||||
VarIds queryVars;
|
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
|
||||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
|
||||||
while (jointList != YAP_TermNil()) {
|
|
||||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
|
||||||
assert (bn->getBayesNode (vid));
|
|
||||||
queryVars.push_back (vid);
|
|
||||||
vids.insert (vid);
|
|
||||||
jointList = YAP_TailOfTerm (jointList);
|
|
||||||
}
|
|
||||||
tasks.push_back (queryVars);
|
|
||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
|
|
||||||
Solver* bpSolver = 0;
|
|
||||||
GraphicalModel* graphicalModel = 0;
|
|
||||||
CFactorGraph::checkForIdenticalFactors = false;
|
|
||||||
if (Globals::infAlgorithm != InfAlgorithms::VE) {
|
|
||||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
|
|
||||||
VarIds (vids.begin(), vids.end()));
|
|
||||||
if (Globals::infAlgorithm == InfAlgorithms::BN_BP) {
|
|
||||||
graphicalModel = mrn;
|
|
||||||
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
|
|
||||||
} else if (Globals::infAlgorithm == InfAlgorithms::FG_BP) {
|
|
||||||
graphicalModel = new FactorGraph (*mrn);
|
|
||||||
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
|
||||||
delete mrn;
|
|
||||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
|
||||||
graphicalModel = new FactorGraph (*mrn);
|
|
||||||
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
|
||||||
delete mrn;
|
|
||||||
}
|
|
||||||
bpSolver->runSolver();
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<Params> results;
|
vector<Params> results;
|
||||||
results.reserve (tasks.size());
|
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
runVeSolver (fg, tasks, results);
|
||||||
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
} else {
|
||||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
|
runBpSolver (fg, tasks, results);
|
||||||
VarElimSolver* veSolver = new VarElimSolver (*mrn);
|
|
||||||
if (tasks[i].size() == 1) {
|
|
||||||
results.push_back (veSolver->getPosterioriOf (tasks[i][0]));
|
|
||||||
} else {
|
|
||||||
results.push_back (veSolver->getJointDistributionOf (tasks[i]));
|
|
||||||
}
|
|
||||||
delete mrn;
|
|
||||||
delete veSolver;
|
|
||||||
} else {
|
|
||||||
if (tasks[i].size() == 1) {
|
|
||||||
results.push_back (bpSolver->getPosterioriOf (tasks[i][0]));
|
|
||||||
} else {
|
|
||||||
results.push_back (bpSolver->getJointDistributionOf (tasks[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
delete bpSolver;
|
|
||||||
delete graphicalModel;
|
|
||||||
|
|
||||||
YAP_Term list = YAP_TermNil();
|
YAP_Term list = YAP_TermNil();
|
||||||
for (int i = results.size() - 1; i >= 0; i--) {
|
for (int i = results.size() - 1; i >= 0; i--) {
|
||||||
@ -395,12 +360,68 @@ runGroundSolver (void)
|
|||||||
}
|
}
|
||||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||||
}
|
}
|
||||||
|
|
||||||
return YAP_Unify (list, YAP_ARG3);
|
return YAP_Unify (list, YAP_ARG3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void runVeSolver (
|
||||||
|
FactorGraph* fg,
|
||||||
|
const vector<VarIds>& tasks,
|
||||||
|
vector<Params>& results)
|
||||||
|
{
|
||||||
|
results.reserve (tasks.size());
|
||||||
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
|
FactorGraph* mfg = fg;
|
||||||
|
if (fg->isFromBayesNetwork()) {
|
||||||
|
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||||
|
}
|
||||||
|
VarElimSolver solver (*mfg);
|
||||||
|
results.push_back (solver.solveQuery (tasks[i]));
|
||||||
|
if (fg->isFromBayesNetwork()) {
|
||||||
|
delete mfg;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void runBpSolver (
|
||||||
|
FactorGraph* fg,
|
||||||
|
const vector<VarIds>& tasks,
|
||||||
|
vector<Params>& results)
|
||||||
|
{
|
||||||
|
std::set<VarId> vids;
|
||||||
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
|
Util::addToSet (vids, tasks[i]);
|
||||||
|
}
|
||||||
|
Solver* solver = 0;
|
||||||
|
FactorGraph* mfg = fg;
|
||||||
|
if (fg->isFromBayesNetwork()) {
|
||||||
|
mfg = BayesBall::getMinimalFactorGraph (
|
||||||
|
*fg, VarIds (vids.begin(),vids.end()));
|
||||||
|
}
|
||||||
|
if (Globals::infAlgorithm == InfAlgorithms::BP) {
|
||||||
|
solver = new BpSolver (*mfg);
|
||||||
|
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||||
|
CFactorGraph::checkForIdenticalFactors = false;
|
||||||
|
solver = new CbpSolver (*mfg);
|
||||||
|
} else {
|
||||||
|
cerr << "error: unknow solver" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
results.reserve (tasks.size());
|
||||||
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
|
results.push_back (solver->solveQuery (tasks[i]));
|
||||||
|
}
|
||||||
|
if (fg->isFromBayesNetwork()) {
|
||||||
|
delete mfg;
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
setParfactorsParams (void)
|
setParfactorsParams (void)
|
||||||
{
|
{
|
||||||
@ -409,10 +430,10 @@ setParfactorsParams (void)
|
|||||||
YAP_Term distList = YAP_ARG2;
|
YAP_Term distList = YAP_ARG2;
|
||||||
unordered_map<unsigned, Params> paramsMap;
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
while (distList != YAP_TermNil()) {
|
while (distList != YAP_TermNil()) {
|
||||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||||
assert (Util::contains (paramsMap, distId) == false);
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
|
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||||
distList = YAP_TailOfTerm (distList);
|
distList = YAP_TailOfTerm (distList);
|
||||||
}
|
}
|
||||||
ParfactorList::iterator it = pfList->begin();
|
ParfactorList::iterator it = pfList->begin();
|
||||||
@ -427,22 +448,24 @@ setParfactorsParams (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
setBayesNetParams (void)
|
setFactorsParams (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
return TRUE; // TODO
|
||||||
|
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term distList = YAP_ARG2;
|
YAP_Term distList = YAP_ARG2;
|
||||||
unordered_map<unsigned, Params> paramsMap;
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
while (distList != YAP_TermNil()) {
|
while (distList != YAP_TermNil()) {
|
||||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||||
assert (Util::contains (paramsMap, distId) == false);
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
paramsMap[distId] = readParams (YAP_ArgOfTerm (2, dist));
|
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||||
distList = YAP_TailOfTerm (distList);
|
distList = YAP_TailOfTerm (distList);
|
||||||
}
|
}
|
||||||
const BnNodeSet& nodes = bn->getBayesNodes();
|
const FacNodes& facNodes = fg->facNodes();
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
assert (Util::contains (paramsMap, nodes[i]->distId()));
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
nodes[i]->setParams (paramsMap[nodes[i]->distId()]);
|
assert (Util::contains (paramsMap, distId));
|
||||||
|
facNodes[i]->factor().setParams (paramsMap[distId]);
|
||||||
}
|
}
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
@ -450,24 +473,29 @@ setBayesNetParams (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
setExtraVarsInfo (void)
|
setVarsInformation (void)
|
||||||
{
|
{
|
||||||
GraphicalModel::clearVariablesInformation();
|
Var::clearVarsInfo();
|
||||||
YAP_Term varsInfoL = YAP_ARG2;
|
YAP_Term labelsL = YAP_ARG1;
|
||||||
while (varsInfoL != YAP_TermNil()) {
|
vector<string> labels;
|
||||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
while (labelsL != YAP_TermNil()) {
|
||||||
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
|
||||||
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
labels.push_back ((char*) YAP_AtomName (atom));
|
||||||
YAP_Term statesL = YAP_ArgOfTerm (3, head);
|
labelsL = YAP_TailOfTerm (labelsL);
|
||||||
|
}
|
||||||
|
unsigned count = 0;
|
||||||
|
YAP_Term stateNamesL = YAP_ARG2;
|
||||||
|
while (stateNamesL != YAP_TermNil()) {
|
||||||
States states;
|
States states;
|
||||||
while (statesL != YAP_TermNil()) {
|
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
|
||||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
|
while (namesL != YAP_TermNil()) {
|
||||||
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
|
||||||
states.push_back ((char*) YAP_AtomName (atom));
|
states.push_back ((char*) YAP_AtomName (atom));
|
||||||
statesL = YAP_TailOfTerm (statesL);
|
namesL = YAP_TailOfTerm (namesL);
|
||||||
}
|
}
|
||||||
GraphicalModel::addVariableInformation (vid,
|
Var::addVarInfo (count, labels[count], states);
|
||||||
(char*) YAP_AtomName (label), states);
|
count ++;
|
||||||
varsInfoL = YAP_TailOfTerm (varsInfoL);
|
stateNamesL = YAP_TailOfTerm (stateNamesL);
|
||||||
}
|
}
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
@ -482,10 +510,8 @@ setHorusFlag (void)
|
|||||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||||
if ( value == "ve") {
|
if ( value == "ve") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||||
} else if (value == "bn_bp") {
|
} else if (value == "bp") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::BN_BP;
|
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||||
} else if (value == "fg_bp") {
|
|
||||||
Globals::infAlgorithm = InfAlgorithms::FG_BP;
|
|
||||||
} else if (value == "cbp") {
|
} else if (value == "cbp") {
|
||||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||||
} else {
|
} else {
|
||||||
@ -559,9 +585,9 @@ setHorusFlag (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
freeBayesNetwork (void)
|
freeGroundNetwork (void)
|
||||||
{
|
{
|
||||||
delete (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -583,14 +609,14 @@ extern "C" void
|
|||||||
init_predicates (void)
|
init_predicates (void)
|
||||||
{
|
{
|
||||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
|
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 4);
|
||||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
|
||||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
|
||||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||||
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
#include "VarNode.h"
|
#include "Var.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
@ -31,14 +31,14 @@ class StatesIndexer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StatesIndexer (const VarNodes& vars, bool calcOffsets = true)
|
StatesIndexer (const Vars& vars, bool calcOffsets = true)
|
||||||
{
|
{
|
||||||
size_ = 1;
|
size_ = 1;
|
||||||
indices_.resize (vars.size(), 0);
|
indices_.resize (vars.size(), 0);
|
||||||
ranges_.reserve (vars.size());
|
ranges_.reserve (vars.size());
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
ranges_.push_back (vars[i]->range());
|
||||||
size_ *= vars[i]->nrStates();
|
size_ *= vars[i]->range();
|
||||||
}
|
}
|
||||||
li_ = 0;
|
li_ = 0;
|
||||||
if (calcOffsets) {
|
if (calcOffsets) {
|
||||||
|
@ -45,9 +45,8 @@ CWD=$(PWD)
|
|||||||
|
|
||||||
|
|
||||||
HEADERS = \
|
HEADERS = \
|
||||||
$(srcdir)/GraphicalModel.h \
|
|
||||||
$(srcdir)/BayesNet.h \
|
$(srcdir)/BayesNet.h \
|
||||||
$(srcdir)/BayesNode.h \
|
$(srcdir)/BayesBall.h \
|
||||||
$(srcdir)/ElimGraph.h \
|
$(srcdir)/ElimGraph.h \
|
||||||
$(srcdir)/FactorGraph.h \
|
$(srcdir)/FactorGraph.h \
|
||||||
$(srcdir)/Factor.h \
|
$(srcdir)/Factor.h \
|
||||||
@ -55,11 +54,10 @@ HEADERS = \
|
|||||||
$(srcdir)/ConstraintTree.h \
|
$(srcdir)/ConstraintTree.h \
|
||||||
$(srcdir)/Solver.h \
|
$(srcdir)/Solver.h \
|
||||||
$(srcdir)/VarElimSolver.h \
|
$(srcdir)/VarElimSolver.h \
|
||||||
$(srcdir)/BnBpSolver.h \
|
$(srcdir)/BpSolver.h \
|
||||||
$(srcdir)/FgBpSolver.h \
|
|
||||||
$(srcdir)/CbpSolver.h \
|
$(srcdir)/CbpSolver.h \
|
||||||
$(srcdir)/FoveSolver.h \
|
$(srcdir)/FoveSolver.h \
|
||||||
$(srcdir)/VarNode.h \
|
$(srcdir)/Var.h \
|
||||||
$(srcdir)/Indexer.h \
|
$(srcdir)/Indexer.h \
|
||||||
$(srcdir)/Parfactor.h \
|
$(srcdir)/Parfactor.h \
|
||||||
$(srcdir)/ProbFormula.h \
|
$(srcdir)/ProbFormula.h \
|
||||||
@ -68,22 +66,20 @@ HEADERS = \
|
|||||||
$(srcdir)/LiftedUtils.h \
|
$(srcdir)/LiftedUtils.h \
|
||||||
$(srcdir)/TinySet.h \
|
$(srcdir)/TinySet.h \
|
||||||
$(srcdir)/Util.h \
|
$(srcdir)/Util.h \
|
||||||
$(srcdir)/Horus.h \
|
$(srcdir)/Horus.h
|
||||||
$(srcdir)/xmlParser/xmlParser.h
|
|
||||||
|
|
||||||
CPP_SOURCES = \
|
CPP_SOURCES = \
|
||||||
$(srcdir)/BayesNet.cpp \
|
$(srcdir)/BayesNet.cpp \
|
||||||
$(srcdir)/BayesNode.cpp \
|
$(srcdir)/BayesBall.cpp \
|
||||||
$(srcdir)/ElimGraph.cpp \
|
$(srcdir)/ElimGraph.cpp \
|
||||||
$(srcdir)/FactorGraph.cpp \
|
$(srcdir)/FactorGraph.cpp \
|
||||||
$(srcdir)/Factor.cpp \
|
$(srcdir)/Factor.cpp \
|
||||||
$(srcdir)/CFactorGraph.cpp \
|
$(srcdir)/CFactorGraph.cpp \
|
||||||
$(srcdir)/ConstraintTree.cpp \
|
$(srcdir)/ConstraintTree.cpp \
|
||||||
$(srcdir)/VarNode.cpp \
|
$(srcdir)/Var.cpp \
|
||||||
$(srcdir)/Solver.cpp \
|
$(srcdir)/Solver.cpp \
|
||||||
$(srcdir)/VarElimSolver.cpp \
|
$(srcdir)/VarElimSolver.cpp \
|
||||||
$(srcdir)/BnBpSolver.cpp \
|
$(srcdir)/BpSolver.cpp \
|
||||||
$(srcdir)/FgBpSolver.cpp \
|
|
||||||
$(srcdir)/CbpSolver.cpp \
|
$(srcdir)/CbpSolver.cpp \
|
||||||
$(srcdir)/FoveSolver.cpp \
|
$(srcdir)/FoveSolver.cpp \
|
||||||
$(srcdir)/Parfactor.cpp \
|
$(srcdir)/Parfactor.cpp \
|
||||||
@ -93,22 +89,20 @@ CPP_SOURCES = \
|
|||||||
$(srcdir)/LiftedUtils.cpp \
|
$(srcdir)/LiftedUtils.cpp \
|
||||||
$(srcdir)/Util.cpp \
|
$(srcdir)/Util.cpp \
|
||||||
$(srcdir)/HorusYap.cpp \
|
$(srcdir)/HorusYap.cpp \
|
||||||
$(srcdir)/HorusCli.cpp \
|
$(srcdir)/HorusCli.cpp
|
||||||
$(srcdir)/xmlParser/xmlParser.cpp
|
|
||||||
|
|
||||||
OBJS = \
|
OBJS = \
|
||||||
BayesNet.o \
|
BayesNet.o \
|
||||||
BayesNode.o \
|
BayesBall.o \
|
||||||
ElimGraph.o \
|
ElimGraph.o \
|
||||||
FactorGraph.o \
|
FactorGraph.o \
|
||||||
Factor.o \
|
Factor.o \
|
||||||
CFactorGraph.o \
|
CFactorGraph.o \
|
||||||
ConstraintTree.o \
|
ConstraintTree.o \
|
||||||
VarNode.o \
|
Var.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
VarElimSolver.o \
|
VarElimSolver.o \
|
||||||
BnBpSolver.o \
|
BpSolver.o \
|
||||||
FgBpSolver.o \
|
|
||||||
CbpSolver.o \
|
CbpSolver.o \
|
||||||
FoveSolver.o \
|
FoveSolver.o \
|
||||||
Parfactor.o \
|
Parfactor.o \
|
||||||
@ -121,17 +115,16 @@ OBJS = \
|
|||||||
|
|
||||||
HCLI_OBJS = \
|
HCLI_OBJS = \
|
||||||
BayesNet.o \
|
BayesNet.o \
|
||||||
BayesNode.o \
|
BayesBall.o \
|
||||||
ElimGraph.o \
|
ElimGraph.o \
|
||||||
FactorGraph.o \
|
FactorGraph.o \
|
||||||
Factor.o \
|
Factor.o \
|
||||||
CFactorGraph.o \
|
CFactorGraph.o \
|
||||||
ConstraintTree.o \
|
ConstraintTree.o \
|
||||||
VarNode.o \
|
Var.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
VarElimSolver.o \
|
VarElimSolver.o \
|
||||||
BnBpSolver.o \
|
BpSolver.o \
|
||||||
FgBpSolver.o \
|
|
||||||
CbpSolver.o \
|
CbpSolver.o \
|
||||||
FoveSolver.o \
|
FoveSolver.o \
|
||||||
Parfactor.o \
|
Parfactor.o \
|
||||||
@ -140,7 +133,6 @@ HCLI_OBJS = \
|
|||||||
ParfactorList.o \
|
ParfactorList.o \
|
||||||
LiftedUtils.o \
|
LiftedUtils.o \
|
||||||
Util.o \
|
Util.o \
|
||||||
xmlParser/xmlParser.o \
|
|
||||||
HorusCli.o
|
HorusCli.o
|
||||||
|
|
||||||
SOBJS=horus.@SO@
|
SOBJS=horus.@SO@
|
||||||
@ -153,10 +145,6 @@ all: $(SOBJS) hcli
|
|||||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||||
|
|
||||||
|
|
||||||
xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
|
||||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
|
||||||
|
|
||||||
|
|
||||||
@DO_SECOND_LD@horus.@SO@: $(OBJS)
|
@DO_SECOND_LD@horus.@SO@: $(OBJS)
|
||||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||||
|
|
||||||
@ -170,7 +158,7 @@ install: all
|
|||||||
|
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli xmlParser/*.o
|
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli
|
||||||
|
|
||||||
|
|
||||||
erase_dots:
|
erase_dots:
|
||||||
|
@ -528,11 +528,13 @@ Parfactor::print (bool printParams) const
|
|||||||
cout << args_[i];
|
cout << args_[i];
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
vector<string> groups;
|
if (args_[0].group() != Util::maxUnsigned()) {
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
vector<string> groups;
|
||||||
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
|
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
||||||
|
}
|
||||||
|
cout << "Groups: " << groups << endl;
|
||||||
}
|
}
|
||||||
cout << "Groups: " << groups << endl;
|
|
||||||
cout << "LogVars: " << constr_->logVarSet() << endl;
|
cout << "LogVars: " << constr_->logVarSet() << endl;
|
||||||
cout << "Ranges: " << ranges_ << endl;
|
cout << "Ranges: " << ranges_ << endl;
|
||||||
if (printParams == false) {
|
if (printParams == false) {
|
||||||
@ -570,6 +572,7 @@ Parfactor::print (bool printParams) const
|
|||||||
cout << " = " << params_[i] << endl;
|
cout << " = " << params_[i] << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,7 +116,6 @@ ParfactorList::print (void) const
|
|||||||
list<Parfactor*>::const_iterator it;
|
list<Parfactor*>::const_iterator it;
|
||||||
for (it = pfList_.begin(); it != pfList_.end(); ++it) {
|
for (it = pfList_.begin(); it != pfList_.end(); ++it) {
|
||||||
(*it)->print();
|
(*it)->print();
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,7 +218,7 @@ ParfactorList::shatter (
|
|||||||
ProbFormula& f1 = g1->argument (fIdx1);
|
ProbFormula& f1 = g1->argument (fIdx1);
|
||||||
ProbFormula& f2 = g2->argument (fIdx2);
|
ProbFormula& f2 = g2->argument (fIdx2);
|
||||||
// cout << endl;
|
// cout << endl;
|
||||||
// Util::printDashLine();
|
// Util::printDashedLine();
|
||||||
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
||||||
// g1->print();
|
// g1->print();
|
||||||
// cout << "-> WITH" << endl;
|
// cout << "-> WITH" << endl;
|
||||||
@ -228,7 +227,7 @@ ParfactorList::shatter (
|
|||||||
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
||||||
// cout << "-> ON: " << f2 << "|" ;
|
// cout << "-> ON: " << f2 << "|" ;
|
||||||
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
||||||
// Util::printDashLine();
|
// Util::printDashedLine();
|
||||||
if (f1.isAtom()) {
|
if (f1.isAtom()) {
|
||||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
||||||
f1.setGroup (group);
|
f1.setGroup (group);
|
||||||
@ -265,18 +264,19 @@ ParfactorList::shatter (
|
|||||||
assert (commCt1->tupleSet (f1.arity()) ==
|
assert (commCt1->tupleSet (f1.arity()) ==
|
||||||
commCt2->tupleSet (f2.arity()));
|
commCt2->tupleSet (f2.arity()));
|
||||||
|
|
||||||
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
// unsigned static count = 0; count ++;
|
||||||
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||||
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||||
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||||
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||||
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||||
// ct1->exportToGraphViz (ss1.str().c_str(), true);
|
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||||
// ct2->exportToGraphViz (ss2.str().c_str(), true);
|
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
|
||||||
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
|
||||||
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
|
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
||||||
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
|
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
|
||||||
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
|
||||||
|
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
||||||
|
|
||||||
if (exclCt1->empty() && exclCt2->empty()) {
|
if (exclCt1->empty() && exclCt2->empty()) {
|
||||||
unsigned group = (f1.group() < f2.group())
|
unsigned group = (f1.group() < f2.group())
|
||||||
|
@ -3,51 +3,35 @@
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Solver::printAllPosterioris (void)
|
Solver::printAnswer (const VarIds& vids)
|
||||||
{
|
{
|
||||||
const VarNodes& vars = gm_->getVariableNodes();
|
Vars unobservedVars;
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
VarIds unobservedVids;
|
||||||
printPosterioriOf (vars[i]->varId());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Solver::printPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
VarNode* var = gm_->getVariableNode (vid);
|
|
||||||
const Params& posterioriDist = getPosterioriOf (vid);
|
|
||||||
const States& states = var->states();
|
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
|
||||||
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
|
|
||||||
cout << setprecision (Constants::PRECISION) << posterioriDist[i];
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Solver::printJointDistributionOf (const VarIds& vids)
|
|
||||||
{
|
|
||||||
VarNodes vars;
|
|
||||||
VarIds vidsWithoutEvidence;
|
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
VarNode* var = gm_->getVariableNode (vids[i]);
|
VarNode* vn = fg.getVarNode (vids[i]);
|
||||||
if (var->hasEvidence() == false) {
|
if (vn->hasEvidence() == false) {
|
||||||
vars.push_back (var);
|
unobservedVars.push_back (vn);
|
||||||
vidsWithoutEvidence.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence);
|
Params res = solveQuery (unobservedVids);
|
||||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
for (unsigned i = 0; i < res.size(); i++) {
|
||||||
cout << "P(" << jointStrings[i] << ") = " ;
|
cout << "P(" << stateLines[i] << ") = " ;
|
||||||
cout << setprecision (Constants::PRECISION) << jointDist[i];
|
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Solver::printAllPosterioris (void)
|
||||||
|
{
|
||||||
|
const VarNodes& vars = fg.varNodes();
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
printAnswer ({vars[i]->varId()});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -3,32 +3,27 @@
|
|||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "Var.h"
|
||||||
#include "VarNode.h"
|
#include "FactorGraph.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class Solver
|
class Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Solver (const GraphicalModel* gm) : gm_(gm) { }
|
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||||
|
|
||||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||||
|
|
||||||
virtual void runSolver (void) = 0;
|
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||||
|
|
||||||
virtual Params getPosterioriOf (VarId) = 0;
|
void printAnswer (const VarIds& vids);
|
||||||
|
|
||||||
virtual Params getJointDistributionOf (const VarIds&) = 0;
|
|
||||||
|
|
||||||
void printAllPosterioris (void);
|
void printAllPosterioris (void);
|
||||||
|
|
||||||
void printPosterioriOf (VarId vid);
|
protected:
|
||||||
|
const FactorGraph& fg;
|
||||||
void printJointDistributionOf (const VarIds& vids);
|
|
||||||
|
|
||||||
private:
|
|
||||||
const GraphicalModel* gm_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_SOLVER_H
|
#endif // HORUS_SOLVER_H
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
TODO
|
TODO
|
||||||
|
- add a way to calculate combinations and factorials with large numbers
|
||||||
- add way to calculate combinations and factorials with large numbers
|
|
||||||
- refactor sumOut in parfactor -> is really ugly code
|
- refactor sumOut in parfactor -> is really ugly code
|
||||||
- Indexer: start receiving ranges as constant reference
|
- Indexer: start receiving ranges as constant reference
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "GraphicalModel.h"
|
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
namespace Globals {
|
||||||
@ -25,12 +24,11 @@ Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
|||||||
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||||
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
||||||
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||||
double accuracy = 0.0001;
|
double accuracy = 0.0001;
|
||||||
unsigned maxIter = 1000;
|
unsigned maxIter = 1000;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
unordered_map<VarId, VarInfo> GraphicalModel::varsInfo_;
|
|
||||||
|
|
||||||
vector<NetInfo> Statistics::netInfo_;
|
vector<NetInfo> Statistics::netInfo_;
|
||||||
vector<CompressInfo> Statistics::compressInfo_;
|
vector<CompressInfo> Statistics::compressInfo_;
|
||||||
@ -139,7 +137,7 @@ parametersToString (const Params& v, unsigned precision)
|
|||||||
|
|
||||||
|
|
||||||
vector<string>
|
vector<string>
|
||||||
getJointStateStrings (const VarNodes& vars)
|
getStateLines (const Vars& vars)
|
||||||
{
|
{
|
||||||
StatesIndexer idx (vars);
|
StatesIndexer idx (vars);
|
||||||
vector<string> jointStrings;
|
vector<string> jointStrings;
|
||||||
@ -157,7 +155,8 @@ getJointStateStrings (const VarNodes& vars)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void printHeader (string header, std::ostream& os)
|
void
|
||||||
|
printHeader (string header, std::ostream& os)
|
||||||
{
|
{
|
||||||
printAsteriskLine (os);
|
printAsteriskLine (os);
|
||||||
os << header << endl;
|
os << header << endl;
|
||||||
@ -166,7 +165,8 @@ void printHeader (string header, std::ostream& os)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void printSubHeader (string header, std::ostream& os)
|
void
|
||||||
|
printSubHeader (string header, std::ostream& os)
|
||||||
{
|
{
|
||||||
printDashedLine (os);
|
printDashedLine (os);
|
||||||
os << header << endl;
|
os << header << endl;
|
||||||
@ -175,7 +175,8 @@ void printSubHeader (string header, std::ostream& os)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void printAsteriskLine (std::ostream& os)
|
void
|
||||||
|
printAsteriskLine (std::ostream& os)
|
||||||
{
|
{
|
||||||
os << "********************************" ;
|
os << "********************************" ;
|
||||||
os << "********************************" ;
|
os << "********************************" ;
|
||||||
@ -184,7 +185,8 @@ void printAsteriskLine (std::ostream& os)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void printDashedLine (std::ostream& os)
|
void
|
||||||
|
printDashedLine (std::ostream& os)
|
||||||
{
|
{
|
||||||
os << "--------------------------------" ;
|
os << "--------------------------------" ;
|
||||||
os << "--------------------------------" ;
|
os << "--------------------------------" ;
|
||||||
@ -368,12 +370,12 @@ Statistics::printStatistics (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Statistics::writeStatisticsToFile (const char* fileName)
|
Statistics::writeStatistics (const char* fileName)
|
||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "Statistics::writeStatisticsToFile()" << endl;
|
cerr << "Statistics::writeStats()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
out << getStatisticString();
|
out << getStatisticString();
|
||||||
@ -384,13 +386,13 @@ Statistics::writeStatisticsToFile (const char* fileName)
|
|||||||
|
|
||||||
void
|
void
|
||||||
Statistics::updateCompressingStatistics (
|
Statistics::updateCompressingStatistics (
|
||||||
unsigned nGroundVars,
|
unsigned nrGroundVars,
|
||||||
unsigned nGroundFactors,
|
unsigned nrGroundFactors,
|
||||||
unsigned nClusterVars,
|
unsigned nrClusterVars,
|
||||||
unsigned nClusterFactors,
|
unsigned nrClusterFactors,
|
||||||
unsigned nWithoutNeighs) {
|
unsigned nrNeighborless) {
|
||||||
compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors,
|
compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
|
||||||
nClusterVars, nClusterFactors, nWithoutNeighs));
|
nrClusterVars, nrClusterFactors, nrNeighborless));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -401,10 +403,9 @@ Statistics::getStatisticString (void)
|
|||||||
stringstream ss2, ss3, ss4, ss1;
|
stringstream ss2, ss3, ss4, ss1;
|
||||||
ss1 << "running mode: " ;
|
ss1 << "running mode: " ;
|
||||||
switch (Globals::infAlgorithm) {
|
switch (Globals::infAlgorithm) {
|
||||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||||
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
|
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
|
||||||
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
|
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
||||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
|
||||||
}
|
}
|
||||||
ss1 << "message schedule: " ;
|
ss1 << "message schedule: " ;
|
||||||
switch (BpOptions::schedule) {
|
switch (BpOptions::schedule) {
|
||||||
@ -463,17 +464,17 @@ Statistics::getStatisticString (void)
|
|||||||
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
||||||
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
||||||
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
||||||
ss3 << setw (9) << compressInfo_[i].nGroundVars;
|
ss3 << setw (9) << compressInfo_[i].nrGroundVars;
|
||||||
ss3 << setw (10) << compressInfo_[i].nClusterVars;
|
ss3 << setw (10) << compressInfo_[i].nrClusterVars;
|
||||||
ss3 << setw (10) << compressInfo_[i].nGroundFactors;
|
ss3 << setw (10) << compressInfo_[i].nrGroundFactors;
|
||||||
ss3 << setw (10) << compressInfo_[i].nClusterFactors;
|
ss3 << setw (10) << compressInfo_[i].nrClusterFactors;
|
||||||
ss3 << setw (10) << compressInfo_[i].nWithoutNeighs;
|
ss3 << setw (10) << compressInfo_[i].nrNeighborless;
|
||||||
ss3 << endl;
|
ss3 << endl;
|
||||||
c1 += compressInfo_[i].nGroundVars - compressInfo_[i].nWithoutNeighs;
|
c1 += compressInfo_[i].nrGroundVars - compressInfo_[i].nrNeighborless;
|
||||||
c2 += compressInfo_[i].nClusterVars;
|
c2 += compressInfo_[i].nrClusterVars;
|
||||||
c3 += compressInfo_[i].nGroundFactors - compressInfo_[i].nWithoutNeighs;
|
c3 += compressInfo_[i].nrGroundFactors - compressInfo_[i].nrNeighborless;
|
||||||
c4 += compressInfo_[i].nClusterFactors;
|
c4 += compressInfo_[i].nrClusterFactors;
|
||||||
if (compressInfo_[i].nWithoutNeighs != 0) {
|
if (compressInfo_[i].nrNeighborless != 0) {
|
||||||
c2 --;
|
c2 --;
|
||||||
c4 --;
|
c4 --;
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
@ -22,7 +23,9 @@ namespace Util {
|
|||||||
|
|
||||||
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||||
|
|
||||||
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
template <typename T> void addToSet (set<T>&, const vector<T>&);
|
||||||
|
|
||||||
|
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||||
|
|
||||||
template <typename T> bool contains (const vector<T>&, const T&);
|
template <typename T> bool contains (const vector<T>&, const T&);
|
||||||
|
|
||||||
@ -59,7 +62,7 @@ bool isInteger (const string&);
|
|||||||
|
|
||||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||||
|
|
||||||
vector<string> getJointStateStrings (const VarNodes&);
|
vector<string> getStateLines (const Vars&);
|
||||||
|
|
||||||
void printHeader (string, std::ostream& os = std::cout);
|
void printHeader (string, std::ostream& os = std::cout);
|
||||||
|
|
||||||
@ -83,6 +86,14 @@ Util::addToVector (vector<T>& v, const vector<T>& elements)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void
|
||||||
|
Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||||
|
{
|
||||||
|
s.insert (elements.begin(), elements.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
template <typename T> void
|
||||||
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||||
{
|
{
|
||||||
@ -110,8 +121,7 @@ Util::contains (const set<T>& s, const T& e)
|
|||||||
|
|
||||||
|
|
||||||
template <typename K, typename V> bool
|
template <typename K, typename V> bool
|
||||||
Util::contains (
|
Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||||
const unordered_map<K, V>& m, const K& k)
|
|
||||||
{
|
{
|
||||||
return m.find (k) != m.end();
|
return m.find (k) != m.end();
|
||||||
}
|
}
|
||||||
@ -322,17 +332,17 @@ struct CompressInfo
|
|||||||
{
|
{
|
||||||
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
|
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
|
||||||
{
|
{
|
||||||
nGroundVars = a;
|
nrGroundVars = a;
|
||||||
nGroundFactors = b;
|
nrGroundFactors = b;
|
||||||
nClusterVars = c;
|
nrClusterVars = c;
|
||||||
nClusterFactors = d;
|
nrClusterFactors = d;
|
||||||
nWithoutNeighs = e;
|
nrNeighborless = e;
|
||||||
}
|
}
|
||||||
unsigned nGroundVars;
|
unsigned nrGroundVars;
|
||||||
unsigned nGroundFactors;
|
unsigned nrGroundFactors;
|
||||||
unsigned nClusterVars;
|
unsigned nrClusterVars;
|
||||||
unsigned nClusterFactors;
|
unsigned nrClusterFactors;
|
||||||
unsigned nWithoutNeighs;
|
unsigned nrNeighborless;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -349,7 +359,7 @@ class Statistics
|
|||||||
|
|
||||||
static void printStatistics (void);
|
static void printStatistics (void);
|
||||||
|
|
||||||
static void writeStatisticsToFile (const char*);
|
static void writeStatistics (const char*);
|
||||||
|
|
||||||
static void updateCompressingStatistics (
|
static void updateCompressingStatistics (
|
||||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
unsigned, unsigned, unsigned, unsigned, unsigned);
|
||||||
|
102
packages/CLPBN/clpbn/bp/Var.cpp
Normal file
102
packages/CLPBN/clpbn/bp/Var.cpp
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "Var.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
unordered_map<VarId, VarInfo> Var::varsInfo_;
|
||||||
|
|
||||||
|
|
||||||
|
Var::Var (const Var* v)
|
||||||
|
{
|
||||||
|
varId_ = v->varId();
|
||||||
|
range_ = v->range();
|
||||||
|
evidence_ = v->getEvidence();
|
||||||
|
index_ = std::numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Var::Var (VarId varId, unsigned range, int evidence)
|
||||||
|
{
|
||||||
|
assert (range != 0);
|
||||||
|
assert (evidence < (int) range);
|
||||||
|
varId_ = varId;
|
||||||
|
range_ = range;
|
||||||
|
evidence_ = evidence;
|
||||||
|
index_ = std::numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Var::isValidState (int stateIndex)
|
||||||
|
{
|
||||||
|
return stateIndex >= 0 && stateIndex < (int) range_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Var::isValidState (const string& stateName)
|
||||||
|
{
|
||||||
|
States states = Var::getVarInfo (varId_).states;
|
||||||
|
return Util::contains (states, stateName);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Var::setEvidence (int ev)
|
||||||
|
{
|
||||||
|
assert (ev < (int) range_);
|
||||||
|
evidence_ = ev;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Var::setEvidence (const string& ev)
|
||||||
|
{
|
||||||
|
States states = Var::getVarInfo (varId_).states;
|
||||||
|
for (unsigned i = 0; i < states.size(); i++) {
|
||||||
|
if (states[i] == ev) {
|
||||||
|
evidence_ = i;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
Var::label (void) const
|
||||||
|
{
|
||||||
|
if (Var::varsHaveInfo()) {
|
||||||
|
return Var::getVarInfo (varId_).label;
|
||||||
|
}
|
||||||
|
stringstream ss;
|
||||||
|
ss << "x" << varId_;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
States
|
||||||
|
Var::states (void) const
|
||||||
|
{
|
||||||
|
if (Var::varsHaveInfo()) {
|
||||||
|
return Var::getVarInfo (varId_).states;
|
||||||
|
}
|
||||||
|
States states;
|
||||||
|
for (unsigned i = 0; i < range_; i++) {
|
||||||
|
stringstream ss;
|
||||||
|
ss << i ;
|
||||||
|
states.push_back (ss.str());
|
||||||
|
}
|
||||||
|
return states;
|
||||||
|
}
|
||||||
|
|
108
packages/CLPBN/clpbn/bp/Var.h
Normal file
108
packages/CLPBN/clpbn/bp/Var.h
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
#ifndef HORUS_Var_H
|
||||||
|
#define HORUS_Var_H
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "Util.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
struct VarInfo
|
||||||
|
{
|
||||||
|
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
||||||
|
string label;
|
||||||
|
States states;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Var
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Var (const Var*);
|
||||||
|
|
||||||
|
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||||
|
|
||||||
|
virtual ~Var (void) { };
|
||||||
|
|
||||||
|
unsigned varId (void) const { return varId_; }
|
||||||
|
|
||||||
|
unsigned range (void) const { return range_; }
|
||||||
|
|
||||||
|
int getEvidence (void) const { return evidence_; }
|
||||||
|
|
||||||
|
unsigned getIndex (void) const { return index_; }
|
||||||
|
|
||||||
|
void setIndex (unsigned idx) { index_ = idx; }
|
||||||
|
|
||||||
|
operator unsigned () const { return index_; }
|
||||||
|
|
||||||
|
bool hasEvidence (void) const
|
||||||
|
{
|
||||||
|
return evidence_ != Constants::NO_EVIDENCE;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator== (const Var& var) const
|
||||||
|
{
|
||||||
|
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||||
|
return varId_ == var.varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!= (const Var& var) const
|
||||||
|
{
|
||||||
|
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||||
|
return varId_ != var.varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isValidState (int);
|
||||||
|
|
||||||
|
bool isValidState (const string&);
|
||||||
|
|
||||||
|
void setEvidence (int);
|
||||||
|
|
||||||
|
void setEvidence (const string&);
|
||||||
|
|
||||||
|
string label (void) const;
|
||||||
|
|
||||||
|
States states (void) const;
|
||||||
|
|
||||||
|
static void addVarInfo (
|
||||||
|
VarId vid, string label, const States& states)
|
||||||
|
{
|
||||||
|
assert (Util::contains (varsInfo_, vid) == false);
|
||||||
|
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||||
|
}
|
||||||
|
|
||||||
|
static VarInfo getVarInfo (VarId vid)
|
||||||
|
{
|
||||||
|
assert (Util::contains (varsInfo_, vid));
|
||||||
|
return varsInfo_.find (vid)->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool varsHaveInfo (void)
|
||||||
|
{
|
||||||
|
return varsInfo_.size() != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void clearVarsInfo (void)
|
||||||
|
{
|
||||||
|
varsInfo_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
VarId varId_;
|
||||||
|
unsigned range_;
|
||||||
|
int evidence_;
|
||||||
|
unsigned index_;
|
||||||
|
|
||||||
|
static unordered_map<VarId, VarInfo> varsInfo_;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // BP_Var_H
|
||||||
|
|
@ -6,61 +6,27 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
|
|
||||||
{
|
|
||||||
bayesNet_ = &bn;
|
|
||||||
factorGraph_ = new FactorGraph (bn);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
|
|
||||||
{
|
|
||||||
bayesNet_ = 0;
|
|
||||||
factorGraph_ = &fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::~VarElimSolver (void)
|
VarElimSolver::~VarElimSolver (void)
|
||||||
{
|
{
|
||||||
if (bayesNet_) {
|
delete factorList_.back();
|
||||||
delete factorGraph_;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
VarElimSolver::getPosterioriOf (VarId vid)
|
VarElimSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
|
||||||
assert (factorGraph_->getFgVarNode (vid));
|
|
||||||
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
|
|
||||||
if (vn->hasEvidence()) {
|
|
||||||
Params params (vn->nrStates(), 0.0);
|
|
||||||
params[vn->getEvidence()] = 1.0;
|
|
||||||
return params;
|
|
||||||
}
|
|
||||||
return getJointDistributionOf (VarIds() = {vid});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
|
||||||
{
|
{
|
||||||
factorList_.clear();
|
factorList_.clear();
|
||||||
varFactors_.clear();
|
varFactors_.clear();
|
||||||
elimOrder_.clear();
|
elimOrder_.clear();
|
||||||
createFactorList();
|
createFactorList();
|
||||||
introduceEvidence();
|
absorveEvidence();
|
||||||
chooseEliminationOrder (vids);
|
findEliminationOrder (queryVids);
|
||||||
processFactorList (vids);
|
processFactorList (queryVids);
|
||||||
Params params = factorList_.back()->params();
|
Params params = factorList_.back()->params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
delete factorList_.back();
|
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,11 +35,11 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
|||||||
void
|
void
|
||||||
VarElimSolver::createFactorList (void)
|
VarElimSolver::createFactorList (void)
|
||||||
{
|
{
|
||||||
const FgFacSet& factorNodes = factorGraph_->getFactorNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
factorList_.reserve (factorNodes.size() * 2);
|
factorList_.reserve (facNodes.size() * 2);
|
||||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
|
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||||
const FgVarSet& neighs = factorNodes[i]->neighbors();
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||||
= varFactors_.find (neighs[j]->varId());
|
= varFactors_.find (neighs[j]->varId());
|
||||||
@ -89,9 +55,9 @@ VarElimSolver::createFactorList (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::introduceEvidence (void)
|
VarElimSolver::absorveEvidence (void)
|
||||||
{
|
{
|
||||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
if (varNodes[i]->hasEvidence()) {
|
if (varNodes[i]->hasEvidence()) {
|
||||||
const vector<unsigned>& idxs =
|
const vector<unsigned>& idxs =
|
||||||
@ -112,21 +78,9 @@ VarElimSolver::introduceEvidence (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::chooseEliminationOrder (const VarIds& vids)
|
VarElimSolver::findEliminationOrder (const VarIds& vids)
|
||||||
{
|
{
|
||||||
if (bayesNet_) {
|
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
||||||
ElimGraph graph (*bayesNet_);
|
|
||||||
elimOrder_ = graph.getEliminatingOrder (vids);
|
|
||||||
} else {
|
|
||||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
VarId vid = varNodes[i]->varId();
|
|
||||||
if (Util::contains (vids, vid) == false &&
|
|
||||||
varNodes[i]->hasEvidence() == false) {
|
|
||||||
elimOrder_.push_back (vid);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -149,7 +103,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
|||||||
|
|
||||||
VarIds unobservedVids;
|
VarIds unobservedVids;
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
|
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
|
||||||
unobservedVids.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -165,13 +119,12 @@ void
|
|||||||
VarElimSolver::eliminate (VarId elimVar)
|
VarElimSolver::eliminate (VarId elimVar)
|
||||||
{
|
{
|
||||||
Factor* result = 0;
|
Factor* result = 0;
|
||||||
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
|
|
||||||
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
|
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
|
||||||
for (unsigned i = 0; i < idxs.size(); i++) {
|
for (unsigned i = 0; i < idxs.size(); i++) {
|
||||||
unsigned idx = idxs[i];
|
unsigned idx = idxs[i];
|
||||||
if (factorList_[idx]) {
|
if (factorList_[idx]) {
|
||||||
if (result == 0) {
|
if (result == 0) {
|
||||||
result = new Factor(*factorList_[idx]);
|
result = new Factor (*factorList_[idx]);
|
||||||
} else {
|
} else {
|
||||||
result->multiply (*factorList_[idx]);
|
result->multiply (*factorList_[idx]);
|
||||||
}
|
}
|
||||||
@ -180,7 +133,7 @@ VarElimSolver::eliminate (VarId elimVar)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (result != 0 && result->nrArguments() != 1) {
|
if (result != 0 && result->nrArguments() != 1) {
|
||||||
result->sumOut (vn->varId());
|
result->sumOut (elimVar);
|
||||||
factorList_.push_back (result);
|
factorList_.push_back (result);
|
||||||
const VarIds& resultVarIds = result->arguments();
|
const VarIds& resultVarIds = result->arguments();
|
||||||
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
||||||
@ -199,7 +152,6 @@ VarElimSolver::printActiveFactors (void)
|
|||||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||||
if (factorList_[i] != 0) {
|
if (factorList_[i] != 0) {
|
||||||
factorList_[i]->print();
|
factorList_[i]->print();
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
#include "Solver.h"
|
#include "Solver.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
@ -15,24 +14,18 @@ using namespace std;
|
|||||||
class VarElimSolver : public Solver
|
class VarElimSolver : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarElimSolver (const BayesNet&);
|
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
|
||||||
|
|
||||||
VarElimSolver (const FactorGraph&);
|
|
||||||
|
|
||||||
~VarElimSolver (void);
|
~VarElimSolver (void);
|
||||||
|
|
||||||
void runSolver (void) { }
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
|
||||||
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void createFactorList (void);
|
void createFactorList (void);
|
||||||
|
|
||||||
void introduceEvidence (void);
|
void absorveEvidence (void);
|
||||||
|
|
||||||
void chooseEliminationOrder (const VarIds&);
|
void findEliminationOrder (const VarIds&);
|
||||||
|
|
||||||
void processFactorList (const VarIds&);
|
void processFactorList (const VarIds&);
|
||||||
|
|
||||||
@ -40,8 +33,6 @@ class VarElimSolver : public Solver
|
|||||||
|
|
||||||
void printActiveFactors (void);
|
void printActiveFactors (void);
|
||||||
|
|
||||||
const BayesNet* bayesNet_;
|
|
||||||
const FactorGraph* factorGraph_;
|
|
||||||
vector<Factor*> factorList_;
|
vector<Factor*> factorList_;
|
||||||
VarIds elimOrder_;
|
VarIds elimOrder_;
|
||||||
unordered_map<VarId, vector<unsigned>> varFactors_;
|
unordered_map<VarId, vector<unsigned>> varFactors_;
|
||||||
|
@ -1,100 +0,0 @@
|
|||||||
#include <algorithm>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "VarNode.h"
|
|
||||||
#include "GraphicalModel.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
VarNode::VarNode (const VarNode* v)
|
|
||||||
{
|
|
||||||
varId_ = v->varId();
|
|
||||||
nrStates_ = v->nrStates();
|
|
||||||
evidence_ = v->getEvidence();
|
|
||||||
index_ = std::numeric_limits<unsigned>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNode::VarNode (VarId varId, unsigned nrStates, int evidence)
|
|
||||||
{
|
|
||||||
assert (nrStates != 0);
|
|
||||||
assert (evidence < (int) nrStates);
|
|
||||||
varId_ = varId;
|
|
||||||
nrStates_ = nrStates;
|
|
||||||
evidence_ = evidence;
|
|
||||||
index_ = std::numeric_limits<unsigned>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
VarNode::isValidState (int stateIndex)
|
|
||||||
{
|
|
||||||
return stateIndex >= 0 && stateIndex < (int) nrStates_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
VarNode::isValidState (const string& stateName)
|
|
||||||
{
|
|
||||||
States states = GraphicalModel::getVarInformation (varId_).states;
|
|
||||||
return Util::contains (states, stateName);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
VarNode::setEvidence (int ev)
|
|
||||||
{
|
|
||||||
assert (ev < (int) nrStates_);
|
|
||||||
evidence_ = ev;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
VarNode::setEvidence (const string& ev)
|
|
||||||
{
|
|
||||||
States states = GraphicalModel::getVarInformation (varId_).states;
|
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
|
||||||
if (states[i] == ev) {
|
|
||||||
evidence_ = i;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (false);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
VarNode::label (void) const
|
|
||||||
{
|
|
||||||
if (GraphicalModel::variablesHaveInformation()) {
|
|
||||||
return GraphicalModel::getVarInformation (varId_).label;
|
|
||||||
}
|
|
||||||
stringstream ss;
|
|
||||||
ss << "x" << varId_;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
States
|
|
||||||
VarNode::states (void) const
|
|
||||||
{
|
|
||||||
if (GraphicalModel::variablesHaveInformation()) {
|
|
||||||
return GraphicalModel::getVarInformation (varId_).states;
|
|
||||||
}
|
|
||||||
States states;
|
|
||||||
for (unsigned i = 0; i < nrStates_; i++) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << i ;
|
|
||||||
states.push_back (ss.str());
|
|
||||||
}
|
|
||||||
return states;
|
|
||||||
}
|
|
||||||
|
|
@ -1,73 +0,0 @@
|
|||||||
#ifndef HORUS_VARNODE_H
|
|
||||||
#define HORUS_VARNODE_H
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
class VarNode
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
VarNode (const VarNode*);
|
|
||||||
|
|
||||||
VarNode (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
|
||||||
|
|
||||||
virtual ~VarNode (void) { };
|
|
||||||
|
|
||||||
unsigned varId (void) const { return varId_; }
|
|
||||||
|
|
||||||
unsigned nrStates (void) const { return nrStates_; }
|
|
||||||
|
|
||||||
int getEvidence (void) const { return evidence_; }
|
|
||||||
|
|
||||||
unsigned getIndex (void) const { return index_; }
|
|
||||||
|
|
||||||
void setIndex (unsigned idx) { index_ = idx; }
|
|
||||||
|
|
||||||
operator unsigned () const { return index_; }
|
|
||||||
|
|
||||||
bool hasEvidence (void) const
|
|
||||||
{
|
|
||||||
return evidence_ != Constants::NO_EVIDENCE;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator== (const VarNode& var) const
|
|
||||||
{
|
|
||||||
cout << "equal operator called" << endl;
|
|
||||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
|
||||||
return varId_ == var.varId();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator!= (const VarNode& var) const
|
|
||||||
{
|
|
||||||
cout << "diff operator called" << endl;
|
|
||||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
|
||||||
return varId_ != var.varId();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isValidState (int);
|
|
||||||
|
|
||||||
bool isValidState (const string&);
|
|
||||||
|
|
||||||
void setEvidence (int);
|
|
||||||
|
|
||||||
void setEvidence (const string&);
|
|
||||||
|
|
||||||
string label (void) const;
|
|
||||||
|
|
||||||
States states (void) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
VarId varId_;
|
|
||||||
unsigned nrStates_;
|
|
||||||
int evidence_;
|
|
||||||
unsigned index_;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // BP_VARNODE_H
|
|
||||||
|
|
35
packages/CLPBN/clpbn/bp/benchmarks/benchs.sh
Executable file
35
packages/CLPBN/clpbn/bp/benchmarks/benchs.sh
Executable file
@ -0,0 +1,35 @@
|
|||||||
|
|
||||||
|
if [ $1 ] && [ $1 == "clear" ]; then
|
||||||
|
rm *~
|
||||||
|
rm -f school/*.log school/*~
|
||||||
|
rm -f city/*.log city/*~
|
||||||
|
rm -f workshop_attrs/*.log workshop_attrs/*~
|
||||||
|
fi
|
||||||
|
|
||||||
|
function run_solver
|
||||||
|
{
|
||||||
|
constraint=$1
|
||||||
|
solver_flag=true
|
||||||
|
if [ -n "$2" ]; then
|
||||||
|
if [ $SOLVER = hve ]; then
|
||||||
|
extra_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
|
||||||
|
elif [ $SOLVER = bp ]; then
|
||||||
|
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||||
|
elif [ $SOLVER = cbp ]; then
|
||||||
|
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||||
|
else
|
||||||
|
echo "unknow flag $2"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
/usr/bin/time -o $LOG_FILE -a -f "real:%E\tuser:%U\tsys:%S" \
|
||||||
|
$YAP << EOF >> $LOG_FILE 2>> ignore.$LOG_FILE
|
||||||
|
[$NETWORK].
|
||||||
|
[$constraint].
|
||||||
|
clpbn_horus:set_solver($SOLVER).
|
||||||
|
clpbn_horus:set_horus_flag(use_logarithms, true).
|
||||||
|
$solver_flag.
|
||||||
|
$QUERY.
|
||||||
|
open("$LOG_FILE", 'append', S), format(S, '$constraint: ~15+ ', []), close(S).
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_bnbp
|
|
||||||
YAP=~/bin/town_bnbp
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=bnbp.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_horus:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
|
||||||
|
|
17
packages/CLPBN/clpbn/bp/benchmarks/city/bp_tests.sh
Executable file
17
packages/CLPBN/clpbn/bp/benchmarks/city/bp_tests.sh
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="bp"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
LOG_FILE=$SOLVER.log
|
||||||
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILE
|
||||||
|
|
||||||
|
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
@ -1,56 +1,17 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_cbp
|
source city.sh
|
||||||
YAP=~/bin/town_cbp
|
source ../benchs.sh
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
SOLVER="cbp"
|
||||||
OUT_FILE_NAME=cbp.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
function run_solver
|
LOG_FILE=$SOLVER.log
|
||||||
{
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_horus:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILE
|
||||||
|
|
||||||
function run_all_graphs
|
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_3 $1 town_3 $3 $4 $5
|
|
||||||
return
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
run_solver town_2500000 $1 town_2500000 $3 $4 $5
|
|
||||||
run_solver town_5000000 $1 town_5000000 $3 $4 $5
|
|
||||||
run_solver town_7500000 $1 town_7500000 $3 $4 $5
|
|
||||||
run_solver town_10000000 $1 town_10000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
|
||||||
|
|
||||||
|
25
packages/CLPBN/clpbn/bp/benchmarks/city/city.sh
Executable file
25
packages/CLPBN/clpbn/bp/benchmarks/city/city.sh
Executable file
@ -0,0 +1,25 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
NETWORK="'../../examples/city'"
|
||||||
|
SHORTNAME="city"
|
||||||
|
QUERY="is_joe_guilty(X)"
|
||||||
|
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
cp ~/bin/yap $YAP
|
||||||
|
echo -n "**********************************" >> $LOG_FILE
|
||||||
|
echo "**********************************" >> $LOG_FILE
|
||||||
|
echo "results for solver $1" >> $LOG_FILE
|
||||||
|
echo -n "**********************************" >> $LOG_FILE
|
||||||
|
echo "**********************************" >> $LOG_FILE
|
||||||
|
run_solver city_5 $2
|
||||||
|
#run_solver city_1000 $2
|
||||||
|
#run_solver city_5000 $2
|
||||||
|
#run_solver city_10000 $2
|
||||||
|
#run_solver city_50000 $2
|
||||||
|
#run_solver city_100000 $2
|
||||||
|
#run_solver city_500000 $2
|
||||||
|
#run_solver city_1000000 $2
|
||||||
|
}
|
||||||
|
|
37
packages/CLPBN/clpbn/bp/benchmarks/city/city_generator.sh
Executable file
37
packages/CLPBN/clpbn/bp/benchmarks/city/city_generator.sh
Executable file
@ -0,0 +1,37 @@
|
|||||||
|
#!/home/tiago/bin/yap -L --
|
||||||
|
|
||||||
|
|
||||||
|
:- initialization(main).
|
||||||
|
|
||||||
|
|
||||||
|
main :-
|
||||||
|
unix(argv([H])),
|
||||||
|
generate_town(H).
|
||||||
|
|
||||||
|
|
||||||
|
generate_town(N) :-
|
||||||
|
atomic_concat(['city_', N, '.yap'], FileName),
|
||||||
|
open(FileName, 'write', S),
|
||||||
|
atom_number(N, N2),
|
||||||
|
generate_people(S, N2, 4),
|
||||||
|
write(S, '\n'),
|
||||||
|
generate_query(S, N2, 4),
|
||||||
|
write(S, '\n'),
|
||||||
|
close(S).
|
||||||
|
|
||||||
|
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
Counting > N, !.
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
format(S, 'people(p~w, nyc).~n', [Counting]),
|
||||||
|
Counting1 is Counting + 1,
|
||||||
|
generate_people(S, N, Counting1).
|
||||||
|
|
||||||
|
|
||||||
|
generate_query(S, N, Counting) :-
|
||||||
|
Counting > N, !.
|
||||||
|
generate_query(S, N, Counting) :- !,
|
||||||
|
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
|
||||||
|
Counting1 is Counting + 1,
|
||||||
|
generate_query(S, N, Counting1).
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_fgbp
|
|
||||||
YAP=~/bin/town_fgbp
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=fb_bp.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_horus:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_horus:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_horus:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
#run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
#run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
#run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
#run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
#run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
|
||||||
|
|
17
packages/CLPBN/clpbn/bp/benchmarks/city/fove_tests.sh
Executable file
17
packages/CLPBN/clpbn/bp/benchmarks/city/fove_tests.sh
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="fove"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
LOG_FILE=$SOLVER.log
|
||||||
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILEE
|
||||||
|
|
||||||
|
run_all_graphs "fove "
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_gibbs
|
|
||||||
YAP=~/bin/town_gibbs
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=gibbs.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs gibbs "gibbs "
|
|
||||||
|
|
17
packages/CLPBN/clpbn/bp/benchmarks/city/hve_tests.sh
Executable file
17
packages/CLPBN/clpbn/bp/benchmarks/city/hve_tests.sh
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="hve"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
LOG_FILE=$SOLVER.log
|
||||||
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILE
|
||||||
|
|
||||||
|
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_jt
|
|
||||||
YAP=~/bin/town_jt
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=jt.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs jt "jt "
|
|
||||||
|
|
@ -1,65 +0,0 @@
|
|||||||
|
|
||||||
conservative_city(City, Cons) :-
|
|
||||||
cons_table(City, ConsDist),
|
|
||||||
{ Cons = conservative_city(City) with p([y,n], ConsDist) }.
|
|
||||||
|
|
||||||
|
|
||||||
gender(X, Gender) :-
|
|
||||||
gender_table(X, GenderDist),
|
|
||||||
{ Gender = gender(X) with p([m,f], GenderDist) }.
|
|
||||||
|
|
||||||
|
|
||||||
hair_color(X, Color) :-
|
|
||||||
lives(X, City),
|
|
||||||
conservative_city(City, Cons),
|
|
||||||
hair_color_table(X,ColorTable),
|
|
||||||
{ Color = hair_color(X) with
|
|
||||||
p([t,f], ColorTable,[Cons]) }.
|
|
||||||
|
|
||||||
|
|
||||||
car_color(X, Color) :-
|
|
||||||
hair_color(X, HColor),
|
|
||||||
car_color_table(X,CColorTable),
|
|
||||||
{ Color = car_color(X) with
|
|
||||||
p([t,f], CColorTable,[HColor]) }.
|
|
||||||
|
|
||||||
|
|
||||||
height(X, Height) :-
|
|
||||||
gender(X, Gender),
|
|
||||||
height_table(X,HeightTable),
|
|
||||||
{ Height = height(X) with
|
|
||||||
p([t,f], HeightTable,[Gender]) }.
|
|
||||||
|
|
||||||
|
|
||||||
shoe_size(X, Shoesize) :-
|
|
||||||
height(X, Height),
|
|
||||||
shoe_size_table(X,ShoesizeTable),
|
|
||||||
{ Shoesize = shoe_size(X) with
|
|
||||||
p([t,f], ShoesizeTable,[Height]) }.
|
|
||||||
|
|
||||||
|
|
||||||
guilty(X, Guilt) :-
|
|
||||||
guilty_table(X, GuiltDist),
|
|
||||||
{ Guilt = guilty(X) with p([y,n], GuiltDist) }.
|
|
||||||
|
|
||||||
|
|
||||||
descn(X, Descn) :-
|
|
||||||
car_color(X, Car),
|
|
||||||
hair_color(X, Hair),
|
|
||||||
height(X, Height),
|
|
||||||
guilty(X, Guilt),
|
|
||||||
descn_table(X, DescTable),
|
|
||||||
{ Descn = descn(X) with
|
|
||||||
p([t,f], DescTable,[Car,Hair,Height,Guilt]) }.
|
|
||||||
|
|
||||||
|
|
||||||
witness(City, Witness) :-
|
|
||||||
descn(joe, DescnJ),
|
|
||||||
descn(p2, Descn2),
|
|
||||||
wit_table(WitTable),
|
|
||||||
{ Witness = witness(City) with
|
|
||||||
p([t,f], WitTable,[DescnJ, Descn2]) }.
|
|
||||||
|
|
||||||
|
|
||||||
:- ensure_loaded(tables).
|
|
||||||
|
|
@ -1,46 +0,0 @@
|
|||||||
|
|
||||||
cons_table(amsterdam, [0.2, 0.8]) :- !.
|
|
||||||
cons_table(_, [0.8, 0.2]).
|
|
||||||
|
|
||||||
|
|
||||||
gender_table(_, [0.55, 0.44]).
|
|
||||||
|
|
||||||
|
|
||||||
hair_color_table(_,
|
|
||||||
/* conservative_city */
|
|
||||||
/* y n */
|
|
||||||
[ 0.05, 0.1,
|
|
||||||
0.95, 0.9 ]).
|
|
||||||
|
|
||||||
|
|
||||||
car_color_table(_,
|
|
||||||
/* t f */
|
|
||||||
[ 0.9, 0.2,
|
|
||||||
0.1, 0.8 ]).
|
|
||||||
|
|
||||||
|
|
||||||
height_table(_,
|
|
||||||
/* m f */
|
|
||||||
[ 0.6, 0.4,
|
|
||||||
0.4, 0.6 ]).
|
|
||||||
|
|
||||||
|
|
||||||
shoe_size_table(_,
|
|
||||||
/* t f */
|
|
||||||
[ 0.9, 0.1,
|
|
||||||
0.1, 0.9 ]).
|
|
||||||
|
|
||||||
|
|
||||||
guilty_table(_, [0.23, 0.77]).
|
|
||||||
|
|
||||||
|
|
||||||
descn_table(_,
|
|
||||||
/* color, hair, height, guilt */
|
|
||||||
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
|
|
||||||
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
|
|
||||||
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
|
|
||||||
|
|
||||||
|
|
||||||
wit_table([0.2, 0.45, 0.24, 0.34,
|
|
||||||
0.8, 0.55, 0.76, 0.66]).
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,29 +0,0 @@
|
|||||||
:- source.
|
|
||||||
:- style_check(all).
|
|
||||||
:- yap_flag(unknown,error).
|
|
||||||
:- yap_flag(write_strings,on).
|
|
||||||
:- use_module(library(clpbn)).
|
|
||||||
:- set_clpbn_flag(solver, bp).
|
|
||||||
:- [-schema].
|
|
||||||
|
|
||||||
lives(_joe, nyc).
|
|
||||||
|
|
||||||
run_query(Guilty) :-
|
|
||||||
guilty(joe, Guilty),
|
|
||||||
witness(nyc, t),
|
|
||||||
runall(X, ev(X)).
|
|
||||||
|
|
||||||
|
|
||||||
runall(G, Wrapper) :-
|
|
||||||
findall(G, Wrapper, L),
|
|
||||||
execute_all(L).
|
|
||||||
|
|
||||||
|
|
||||||
execute_all([]).
|
|
||||||
execute_all(G.L) :-
|
|
||||||
call(G),
|
|
||||||
execute_all(L).
|
|
||||||
|
|
||||||
|
|
||||||
ev(descn(p2, t)).
|
|
||||||
ev(descn(p3, t)).
|
|
File diff suppressed because it is too large
Load Diff
@ -1,59 +0,0 @@
|
|||||||
#!/home/tiago/bin/yap -L --
|
|
||||||
|
|
||||||
/*
|
|
||||||
Steps:
|
|
||||||
1. generate N facts lives(I, nyc), 0 <= I < N.
|
|
||||||
2. generate evidence on descn for N people, *** except for 1 ***
|
|
||||||
3. Run query ?- guilty(joe, Guilty), witness(joe, t), descn(2,t), descn(3, f), descn(4, f) ...
|
|
||||||
*/
|
|
||||||
|
|
||||||
:- initialization(main).
|
|
||||||
|
|
||||||
|
|
||||||
main :-
|
|
||||||
unix(argv([H])),
|
|
||||||
generate_town(H).
|
|
||||||
|
|
||||||
|
|
||||||
generate_town(N) :-
|
|
||||||
atomic_concat(['town_', N, '.yap'], FileName),
|
|
||||||
open(FileName, 'write', S),
|
|
||||||
write(S, ':- source.\n'),
|
|
||||||
write(S, ':- style_check(all).\n'),
|
|
||||||
write(S, ':- yap_flag(unknown,error).\n'),
|
|
||||||
write(S, ':- yap_flag(write_strings,on).\n'),
|
|
||||||
write(S, ':- use_module(library(clpbn)).\n'),
|
|
||||||
write(S, ':- set_clpbn_flag(solver, bp).\n'),
|
|
||||||
write(S, ':- [-schema].\n\n'),
|
|
||||||
write(S, 'lives(_joe, nyc).\n'),
|
|
||||||
atom_number(N, N2),
|
|
||||||
generate_people(S, N2, 2),
|
|
||||||
write(S, '\nrun_query(Guilty) :- \n'),
|
|
||||||
write(S, '\tguilty(joe, Guilty),\n'),
|
|
||||||
write(S, '\twitness(nyc, t),\n'),
|
|
||||||
write(S, '\trunall(X, ev(X)).\n\n\n'),
|
|
||||||
write(S, 'runall(G, Wrapper) :-\n'),
|
|
||||||
write(S, '\tfindall(G, Wrapper, L),\n'),
|
|
||||||
write(S, '\texecute_all(L).\n\n\n'),
|
|
||||||
write(S, 'execute_all([]).\n'),
|
|
||||||
write(S, 'execute_all(G.L) :-\n'),
|
|
||||||
write(S, '\tcall(G),\n'),
|
|
||||||
write(S, '\texecute_all(L).\n\n\n'),
|
|
||||||
generate_query(S, N2, 2),
|
|
||||||
close(S).
|
|
||||||
|
|
||||||
|
|
||||||
generate_people(_, N, Counting1) :- !.
|
|
||||||
generate_people(S, N, Counting) :-
|
|
||||||
format(S, 'lives(p~w, nyc).~n', [Counting]),
|
|
||||||
Counting1 is Counting + 1,
|
|
||||||
generate_people(S, N, Counting1).
|
|
||||||
|
|
||||||
|
|
||||||
generate_query(S, N, Counting) :-
|
|
||||||
Counting > N, !.
|
|
||||||
generate_query(S, N, Counting) :- !,
|
|
||||||
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
|
|
||||||
Counting1 is Counting + 1,
|
|
||||||
generate_query(S, N, Counting1).
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_ve
|
|
||||||
YAP=~/bin/town_ve
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=ve.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
#run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
#run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
#run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
#run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
#run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs ve "ve "
|
|
||||||
|
|
@ -9,7 +9,7 @@ OUT_FILE_NAME=results.log
|
|||||||
rm -f $OUT_FILE_NAME
|
rm -f $OUT_FILE_NAME
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
rm -f ignore.$OUT_FILE_NAME
|
||||||
|
|
||||||
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg,fg_bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
# yap -g "['../../../../examples/School/sch32'], [missing5], use_module(library(clpbn/learning/em)), graph(L), clpbn:set_clpbn_flag(em_solver,bp), clpbn_horus:set_horus_flag(inf_alg, bp), statistics(runtime, _), em(L,0.01,10,_,Lik), statistics(runtime, [T,_])."
|
||||||
|
|
||||||
function run_solver
|
function run_solver
|
||||||
{
|
{
|
||||||
@ -58,24 +58,21 @@ function run_all_graphs
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
#run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||||
#run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
#run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
|
||||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||||
#run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
|
||||||
exit
|
exit
|
||||||
|
|
||||||
|
|
||||||
run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
run_all_graphs bp "hve(min_neighbors) " ve min_neighbors
|
||||||
run_all_graphs bp "hve(min_weight) " ve min_weight
|
run_all_graphs bp "hve(min_weight) " ve min_weight
|
||||||
run_all_graphs bp "hve(min_fill) " ve min_fill
|
run_all_graphs bp "hve(min_fill) " ve min_fill
|
||||||
run_all_graphs bp "hve(w_min_fill) " ve weighted_min_fill
|
run_all_graphs bp "hve(w_min_fill) " ve weighted_min_fill
|
||||||
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
run_all_graphs bp "bp(seq_fixed) " bp seq_fixed
|
||||||
run_all_graphs bp "bn_bp(max_residual) " bn_bp max_residual
|
run_all_graphs bp "bp(max_residual) " bp max_residual
|
||||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
||||||
run_all_graphs bp "fg_bp(max_residual) " fg_bp max_residual
|
run_all_graphs bp "cbp(max_residual) " cbp max_residual
|
||||||
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
run_all_graphs gibbs "gibbs "
|
||||||
run_all_graphs bp "cbp(max_residual) " cbp max_residual
|
|
||||||
run_all_graphs gibbs "gibbs "
|
|
||||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||||
echo "results for solver ve" >> "$OUT_FILE_NAME"
|
echo "results for solver ve" >> "$OUT_FILE_NAME"
|
||||||
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
echo "************************************************************************" >> "$OUT_FILE_NAME"
|
||||||
|
11
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/bp_tests.sh
Executable file
11
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/bp_tests.sh
Executable file
@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source wa.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="bp"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
11
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/cbp_tests.sh
Executable file
11
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/cbp_tests.sh
Executable file
@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source wa.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="cbp"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
12
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/fove_tests.sh
Executable file
12
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/fove_tests.sh
Executable file
@ -0,0 +1,12 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source wa.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="fove"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
run_all_graphs "fove "
|
||||||
|
|
||||||
|
|
11
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/hve_tests.sh
Executable file
11
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/hve_tests.sh
Executable file
@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source wa.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="hve"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||||
|
|
27
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/people_generator.sh
Executable file
27
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/people_generator.sh
Executable file
@ -0,0 +1,27 @@
|
|||||||
|
#!/home/tiago/bin/yap -L --
|
||||||
|
|
||||||
|
|
||||||
|
:- initialization(main).
|
||||||
|
|
||||||
|
|
||||||
|
main :-
|
||||||
|
unix(argv([H])),
|
||||||
|
generate_town(H).
|
||||||
|
|
||||||
|
|
||||||
|
generate_town(N) :-
|
||||||
|
atomic_concat(['pop_', N, '.yap'], FileName),
|
||||||
|
open(FileName, 'write', S),
|
||||||
|
atom_number(N, N2),
|
||||||
|
generate_people(S, N2, 4),
|
||||||
|
write(S, '\n'),
|
||||||
|
close(S).
|
||||||
|
|
||||||
|
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
Counting > N, !.
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
format(S, 'people(p~w).~n', [Counting]),
|
||||||
|
Counting1 is Counting + 1,
|
||||||
|
generate_people(S, N, Counting1).
|
||||||
|
|
33
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/wa.sh
Executable file
33
packages/CLPBN/clpbn/bp/benchmarks/workshop_attrs/wa.sh
Executable file
@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
NETWORK="'../../examples/workshop_attrs'"
|
||||||
|
SHORTNAME="wa"
|
||||||
|
QUERY="series(X)"
|
||||||
|
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
LOG_FILE=$SOLVER.log
|
||||||
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILE
|
||||||
|
|
||||||
|
cp ~/bin/yap $YAP
|
||||||
|
|
||||||
|
echo -n "**********************************" >> $LOG_FILE
|
||||||
|
echo "**********************************" >> $LOG_FILE
|
||||||
|
echo "results for solver $1" >> $LOG_FILE
|
||||||
|
echo -n "**********************************" >> $LOG_FILE
|
||||||
|
echo "**********************************" >> $LOG_FILE
|
||||||
|
run_solver pop_10 $2
|
||||||
|
#run_solver pop_1000 $2
|
||||||
|
#run_solver pop_5000 $2
|
||||||
|
#run_solver pop_10000 $2
|
||||||
|
#run_solver pop_50000 $2
|
||||||
|
#run_solver pop_100000 $2
|
||||||
|
#run_solver pop_500000 $2
|
||||||
|
#run_solver pop_1000000 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -1,18 +0,0 @@
|
|||||||
|
|
||||||
:- use_module(library(pfl)).
|
|
||||||
|
|
||||||
:- set_clpbn_flag(solver,fove).
|
|
||||||
|
|
||||||
|
|
||||||
c(x1,y1,z1).
|
|
||||||
c(x1,y1,z2).
|
|
||||||
c(x2,y2,z1).
|
|
||||||
c(x3,y2,z1).
|
|
||||||
|
|
||||||
bayes p(X)::[t,f] ; [0.2, 0.4] ; [c(X,_,_)].
|
|
||||||
|
|
||||||
bayes q(Y)::[t,f] ; [0.5, 0.6] ; [c(_,Y,_)].
|
|
||||||
|
|
||||||
bayes s(Z)::[t,f] , p(X) , q(Y) ; [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] ; [c(X,Y,Z)].
|
|
||||||
|
|
||||||
% bayes series::[t,f] , attends(X) ; [0.5, 0.6, 0.7, 0.8] ; [c(X,_)].
|
|
@ -1,81 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="US-ASCII"?>
|
|
||||||
|
|
||||||
<!--
|
|
||||||
|
|
||||||
B E
|
|
||||||
\ /
|
|
||||||
\ /
|
|
||||||
A
|
|
||||||
/ \
|
|
||||||
/ \
|
|
||||||
J M
|
|
||||||
|
|
||||||
-->
|
|
||||||
|
|
||||||
|
|
||||||
<BIF VERSION="0.3">
|
|
||||||
<NETWORK>
|
|
||||||
<NAME>Simple Loop</NAME>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>B</NAME>
|
|
||||||
<OUTCOME>b1</OUTCOME>
|
|
||||||
<OUTCOME>b2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>E</NAME>
|
|
||||||
<OUTCOME>e1</OUTCOME>
|
|
||||||
<OUTCOME>e2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>A</NAME>
|
|
||||||
<OUTCOME>a1</OUTCOME>
|
|
||||||
<OUTCOME>a2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>J</NAME>
|
|
||||||
<OUTCOME>j1</OUTCOME>
|
|
||||||
<OUTCOME>j2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<VARIABLE TYPE="nature">
|
|
||||||
<NAME>M</NAME>
|
|
||||||
<OUTCOME>m1</OUTCOME>
|
|
||||||
<OUTCOME>m2</OUTCOME>
|
|
||||||
</VARIABLE>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>B</FOR>
|
|
||||||
<TABLE> .001 .999 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>E</FOR>
|
|
||||||
<TABLE> .002 .998 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>A</FOR>
|
|
||||||
<GIVEN>B</GIVEN>
|
|
||||||
<GIVEN>E</GIVEN>
|
|
||||||
<TABLE> .95 .05 .94 .06 .29 .71 .001 .999 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>J</FOR>
|
|
||||||
<GIVEN>A</GIVEN>
|
|
||||||
<TABLE> .9 .1 .05 .95 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
<DEFINITION>
|
|
||||||
<FOR>M</FOR>
|
|
||||||
<GIVEN>A</GIVEN>
|
|
||||||
<TABLE> .7 .3 .01 .99 </TABLE>
|
|
||||||
</DEFINITION>
|
|
||||||
|
|
||||||
</NETWORK>
|
|
||||||
</BIF>
|
|
||||||
|
|
@ -1,29 +1,41 @@
|
|||||||
|
|
||||||
:- use_module(library(clpbn)).
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
:- set_clpbn_flag(solver, bp).
|
%:- set_pfl_flag(solver,ve).
|
||||||
|
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||||
|
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
||||||
|
%:- set_pfl_flag(solver,fove).
|
||||||
|
|
||||||
r(R) :- r_cpt(RCpt),
|
% :- yap_flag(write_strings, off).
|
||||||
{ R = r with p([r1, r2], RCpt) }.
|
|
||||||
|
|
||||||
t(T) :- t_cpt(TCpt),
|
|
||||||
{ T = t with p([t1, t2], TCpt) }.
|
|
||||||
|
|
||||||
a(A) :- r(R), t(T), a_cpt(ACpt),
|
bayes burglary::[b1,b3] ; [0.001, 0.999] ; [].
|
||||||
{ A = a with p([a1, a2], ACpt, [R, T]) }.
|
|
||||||
|
|
||||||
j(J) :- a(A), j_cpt(JCpt),
|
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
|
||||||
{ J = j with p([j1, j2], JCpt, [A]) }.
|
|
||||||
|
|
||||||
m(M) :- a(A), m_cpt(MCpt),
|
bayes alarm::[a1,a2] , burglary, earthquake ; [0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ; [].
|
||||||
{ M = m with p([m1, m2], MCpt, [A]) }.
|
|
||||||
|
bayes john_calls::[j1,j2] , alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
|
||||||
|
|
||||||
|
bayes mary_calls::[m1,m2] , alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
|
||||||
|
|
||||||
|
|
||||||
|
b_cpt([0.001, 0.999]).
|
||||||
|
|
||||||
|
e_cpt([0.002, 0.998]).
|
||||||
|
|
||||||
r_cpt([0.001, 0.999]).
|
|
||||||
t_cpt([0.002, 0.998]).
|
|
||||||
a_cpt([0.95, 0.94, 0.29, 0.001,
|
a_cpt([0.95, 0.94, 0.29, 0.001,
|
||||||
0.05, 0.06, 0.71, 0.999]).
|
0.05, 0.06, 0.71, 0.999]).
|
||||||
j_cpt([0.9, 0.05,
|
|
||||||
|
jc_cpt([0.9, 0.05,
|
||||||
0.1, 0.95]).
|
0.1, 0.95]).
|
||||||
m_cpt([0.7, 0.01,
|
|
||||||
|
mc_cpt([0.7, 0.01,
|
||||||
0.3, 0.99]).
|
0.3, 0.99]).
|
||||||
|
|
||||||
|
% ?- alarm(A).
|
||||||
|
?- john_calls(J), mary_calls(m1).
|
||||||
|
%?- john_calls(J), mary_calls(m1), alarm(a1).
|
||||||
|
%?- john_calls(J), alarm(a1).
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
# example in counting belief propagation paper
|
||||||
|
|
||||||
MARKOV
|
MARKOV
|
||||||
3
|
3
|
||||||
2 2 2
|
2 2 2
|
||||||
@ -5,7 +7,6 @@ MARKOV
|
|||||||
2 0 1
|
2 0 1
|
||||||
2 2 1
|
2 2 1
|
||||||
|
|
||||||
|
|
||||||
4
|
4
|
||||||
1.2 1.4 2.0 0.4
|
1.2 1.4 2.0 0.4
|
||||||
|
|
||||||
|
102
packages/CLPBN/clpbn/bp/examples/city.yap
Normal file
102
packages/CLPBN/clpbn/bp/examples/city.yap
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
:- clpbn_horus:set_solver(fove).
|
||||||
|
%:- clpbn_horus:set_solver(hve).
|
||||||
|
%:- clpbn_horus:set_solver(bp).
|
||||||
|
%:- clpbn_horus:set_solver(cbp).
|
||||||
|
|
||||||
|
|
||||||
|
people(joe,nyc).
|
||||||
|
people(p2, nyc).
|
||||||
|
people(p3, nyc).
|
||||||
|
|
||||||
|
|
||||||
|
ev(descn(p2, t)).
|
||||||
|
ev(descn(p3, t)).
|
||||||
|
|
||||||
|
% :- [city_7].
|
||||||
|
|
||||||
|
bayes city_conservativeness(C)::[y,n] ; cons_table(C) ; [people(_,C)].
|
||||||
|
|
||||||
|
bayes gender(P)::[m,f] ; gender_table(P) ; [people(P,_)].
|
||||||
|
|
||||||
|
bayes hair_color(P)::[t,f], city_conservativeness(C) ; hair_color_table(P) ; [people(P,C)].
|
||||||
|
|
||||||
|
bayes car_color(P)::[t,f], hair_color(P) ; car_color_table(P); [people(P,_)].
|
||||||
|
|
||||||
|
bayes height(P)::[t,f], gender(P) ; height_table(P) ; [people(P,_)].
|
||||||
|
|
||||||
|
bayes shoe_size(P):[t,f], height(P) ; shoe_size_table(P); [people(P,_)].
|
||||||
|
|
||||||
|
bayes guilty(P)::[y,n] ; guilty_table(P) ; [people(P,_)].
|
||||||
|
|
||||||
|
bayes descn(P)::[t,f], car_color(P), hair_color(P), height(P), guilty(P) ; descn_table(P) ; [people(P,_)].
|
||||||
|
|
||||||
|
bayes witness(C)::[t,f], descn(Joe), descn(P2) ; wit_table ; [people(_,C), Joe=joe, P2=p2].
|
||||||
|
|
||||||
|
|
||||||
|
cons_table(amsterdam, [0.2, 0.8]) :- !.
|
||||||
|
cons_table(_, [0.8, 0.2]).
|
||||||
|
|
||||||
|
|
||||||
|
gender_table(_, [0.55, 0.44]).
|
||||||
|
|
||||||
|
|
||||||
|
hair_color_table(_,
|
||||||
|
/* conservative_city */
|
||||||
|
/* y n */
|
||||||
|
[ 0.05, 0.1,
|
||||||
|
0.95, 0.9 ]).
|
||||||
|
|
||||||
|
|
||||||
|
car_color_table(_,
|
||||||
|
/* t f */
|
||||||
|
[ 0.9, 0.2,
|
||||||
|
0.1, 0.8 ]).
|
||||||
|
|
||||||
|
|
||||||
|
height_table(_,
|
||||||
|
/* m f */
|
||||||
|
[ 0.6, 0.4,
|
||||||
|
0.4, 0.6 ]).
|
||||||
|
|
||||||
|
|
||||||
|
shoe_size_table(_,
|
||||||
|
/* t f */
|
||||||
|
[ 0.9, 0.1,
|
||||||
|
0.1, 0.9 ]).
|
||||||
|
|
||||||
|
|
||||||
|
guilty_table(_, [0.23, 0.77]).
|
||||||
|
|
||||||
|
|
||||||
|
descn_table(_,
|
||||||
|
/* color, hair, height, guilt */
|
||||||
|
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
|
||||||
|
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
|
||||||
|
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
|
||||||
|
|
||||||
|
|
||||||
|
wit_table([0.2, 0.45, 0.24, 0.34,
|
||||||
|
0.8, 0.55, 0.76, 0.66]).
|
||||||
|
|
||||||
|
|
||||||
|
runall(G, Wrapper) :-
|
||||||
|
findall(G, Wrapper, L),
|
||||||
|
execute_all(L).
|
||||||
|
|
||||||
|
|
||||||
|
execute_all([]).
|
||||||
|
execute_all(G.L) :-
|
||||||
|
call(G),
|
||||||
|
execute_all(L).
|
||||||
|
|
||||||
|
|
||||||
|
is_joe_guilty(Guilty) :-
|
||||||
|
witness(nyc, t),
|
||||||
|
runall(X, ev(X)),
|
||||||
|
guilty(joe, Guilty).
|
||||||
|
|
||||||
|
|
||||||
|
% ?- is_joe_guilty(Guilty)
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,9 +1,11 @@
|
|||||||
|
|
||||||
:- use_module(library(pfl)).
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
%:- set_clpbn_flag(solver,ve).
|
:- clpbn_horus:set_solver(fove).
|
||||||
%:- set_clpbn_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
|
%:- clpbn_horus:set_solver(hve).
|
||||||
:- set_clpbn_flag(solver,fove).
|
%:- clpbn_horus:set_solver(bp).
|
||||||
|
%:- clpbn_horus:set_solver(cbp).
|
||||||
|
|
||||||
|
:- yap_flag(write_strings, off).
|
||||||
|
|
||||||
c(p1,w1).
|
c(p1,w1).
|
||||||
c(p1,w2).
|
c(p1,w2).
|
||||||
@ -25,8 +27,5 @@ markov attends(P)::[t,f] , hot(W)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P,W)].
|
|||||||
|
|
||||||
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
|
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
|
||||||
|
|
||||||
:- clpbn_horus:set_horus_flag(use_logarithms,true).
|
% ?- series(X).
|
||||||
|
|
||||||
?- series(X).
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,7 +2,9 @@
|
|||||||
:- use_module(library(pfl)).
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
:- set_pfl_flag(solver,fove).
|
:- set_pfl_flag(solver,fove).
|
||||||
%:- set_pfl_flag(solver,fove).
|
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||||
|
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
||||||
|
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp).
|
||||||
|
|
||||||
|
|
||||||
t(ann).
|
t(ann).
|
||||||
@ -10,10 +12,10 @@ t(dave).
|
|||||||
|
|
||||||
% p(ann,t).
|
% p(ann,t).
|
||||||
|
|
||||||
bayes p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
|
markov p(X)::[t,f] ; [0.1, 0.3] ; [t(X)].
|
||||||
|
|
||||||
% use standard Prolog queries: provide evidence first.
|
% use standard Prolog queries: provide evidence first.
|
||||||
|
|
||||||
?- p(dave,t), p(ann,X).
|
?- p(ann,t), p(ann,X).
|
||||||
% ?- p(ann,X).
|
% ?- p(ann,X).
|
||||||
|
|
||||||
|
@ -1,32 +1,24 @@
|
|||||||
|
|
||||||
:- use_module(library(pfl)).
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
%:- set_pfl_flag(solver,ve).
|
:- clpbn_horus:set_solver(fove).
|
||||||
:- set_pfl_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
|
%:- clpbn_horus:set_solver(hve).
|
||||||
% :- set_pfl_flag(solver,fove).
|
%:- clpbn_horus:set_solver(bp).
|
||||||
|
%:- clpbn_horus:set_solver(cbp).
|
||||||
|
|
||||||
:- yap_flag(write_strings, off).
|
:- yap_flag(write_strings, off).
|
||||||
|
|
||||||
friendly(P1, P2) :-
|
|
||||||
person(P1),
|
|
||||||
person(P2),
|
|
||||||
P1 @> P2.
|
|
||||||
|
|
||||||
person(john).
|
friends(P1, P2) :-
|
||||||
person(maggie).
|
people(P1),
|
||||||
person(harry).
|
people(P2),
|
||||||
person(bill).
|
P1 \= P2.
|
||||||
person(matt).
|
|
||||||
person(diana).
|
|
||||||
person(bob).
|
|
||||||
person(dick).
|
|
||||||
person(burr).
|
|
||||||
person(ann).
|
|
||||||
|
|
||||||
person @ 2.
|
people @ 3.
|
||||||
|
|
||||||
markov smokes(P)::[t,f] , cancer(P)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [person(P)].
|
markov smokes(P)::[t,f], cancer(P)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [people(P)].
|
||||||
|
|
||||||
markov friend(P1,P2)::[t,f], smokes(P1)::[t,f], smokes(P2)::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [friendly(P1, P2)].
|
markov friend(P1,P2)::[t,f], smokes(P1)::[t,f], smokes(P2)::[t,f] ;
|
||||||
|
[0.5, 0.6, 0.7, 0.8, 0.5, 0.6, 0.7, 0.8] ; [friends(P1, P2)].
|
||||||
|
|
||||||
|
% ?- smokes(p1, t), smokes(p2, f), friend(p1, p2, X).
|
||||||
|
|
||||||
?- smokes(person_0, t), smokes(person_1, t), friend(person_0, person_1, F).
|
|
||||||
|
@ -1,120 +0,0 @@
|
|||||||
14
|
|
||||||
|
|
||||||
1
|
|
||||||
6
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 9.974182
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
7
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 9.974182
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
4
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 4.055200
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
5
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 4.055200
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
0
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
2
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
1
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
3
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
2
|
|
||||||
4 6
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 4.481689
|
|
||||||
1 1.000000
|
|
||||||
2 4.481689
|
|
||||||
3 4.481689
|
|
||||||
|
|
||||||
2
|
|
||||||
5 7
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 4.481689
|
|
||||||
1 1.000000
|
|
||||||
2 4.481689
|
|
||||||
3 4.481689
|
|
||||||
|
|
||||||
2
|
|
||||||
0 4
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
2 5 4
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
1 4 5
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
2
|
|
||||||
3 5
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 3.004166
|
|
||||||
|
|
@ -1,239 +0,0 @@
|
|||||||
27
|
|
||||||
|
|
||||||
1
|
|
||||||
12
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 9.974182
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
13
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 9.974182
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
14
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 9.974182
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
9
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 4.055200
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
10
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 4.055200
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
11
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 4.055200
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
0
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
3
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
6
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
1
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
4
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
7
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
2
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
5
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
8
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
2
|
|
||||||
9 12
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 4.481689
|
|
||||||
1 1.000000
|
|
||||||
2 4.481689
|
|
||||||
3 4.481689
|
|
||||||
|
|
||||||
2
|
|
||||||
10 13
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 4.481689
|
|
||||||
1 1.000000
|
|
||||||
2 4.481689
|
|
||||||
3 4.481689
|
|
||||||
|
|
||||||
2
|
|
||||||
11 14
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 4.481689
|
|
||||||
1 1.000000
|
|
||||||
2 4.481689
|
|
||||||
3 4.481689
|
|
||||||
|
|
||||||
2
|
|
||||||
0 9
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
3 10 9
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
6 11 9
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
1 9 10
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
2
|
|
||||||
4 10
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
7 11 10
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
2 9 11
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
5 10 11
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
2
|
|
||||||
8 11
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 3.004166
|
|
||||||
|
|
@ -1,398 +0,0 @@
|
|||||||
44
|
|
||||||
|
|
||||||
1
|
|
||||||
20
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 9.974182
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
21
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 9.974182
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
22
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 9.974182
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
23
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 9.974182
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
16
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 4.055200
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
17
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 4.055200
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
18
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 4.055200
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
19
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 4.055200
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
0
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
4
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
8
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
12
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
1
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
5
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
9
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
13
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
2
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
6
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
10
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
14
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
3
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
7
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
11
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
1
|
|
||||||
15
|
|
||||||
2
|
|
||||||
2
|
|
||||||
0 7.389056
|
|
||||||
1 1.000000
|
|
||||||
|
|
||||||
2
|
|
||||||
16 20
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 4.481689
|
|
||||||
1 1.000000
|
|
||||||
2 4.481689
|
|
||||||
3 4.481689
|
|
||||||
|
|
||||||
2
|
|
||||||
17 21
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 4.481689
|
|
||||||
1 1.000000
|
|
||||||
2 4.481689
|
|
||||||
3 4.481689
|
|
||||||
|
|
||||||
2
|
|
||||||
18 22
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 4.481689
|
|
||||||
1 1.000000
|
|
||||||
2 4.481689
|
|
||||||
3 4.481689
|
|
||||||
|
|
||||||
2
|
|
||||||
19 23
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 4.481689
|
|
||||||
1 1.000000
|
|
||||||
2 4.481689
|
|
||||||
3 4.481689
|
|
||||||
|
|
||||||
2
|
|
||||||
0 16
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
4 17 16
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
8 18 16
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
12 19 16
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
1 16 17
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
2
|
|
||||||
5 17
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
9 18 17
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
13 19 17
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
2 16 18
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
6 17 18
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
2
|
|
||||||
10 18
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
14 19 18
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
3 16 19
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
7 17 19
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
3
|
|
||||||
11 18 19
|
|
||||||
2 2 2
|
|
||||||
8
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 1.000000
|
|
||||||
4 3.004166
|
|
||||||
5 1.000000
|
|
||||||
6 3.004166
|
|
||||||
7 3.004166
|
|
||||||
|
|
||||||
2
|
|
||||||
15 19
|
|
||||||
2 2
|
|
||||||
4
|
|
||||||
0 3.004166
|
|
||||||
1 3.004166
|
|
||||||
2 3.004166
|
|
||||||
3 3.004166
|
|
||||||
|
|
@ -1,32 +1,27 @@
|
|||||||
|
|
||||||
:- use_module(library(pfl)).
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
%:- set_clpbn_flag(solver,ve).
|
%:- clpbn_horus:set_solver(fove).
|
||||||
%:- set_clpbn_flag(solver,bp), clpbn_bp:set_horus_flag(inf_alg,ve).
|
%:- clpbn_horus:set_solver(hve).
|
||||||
:- set_clpbn_flag(solver,fove).
|
:- clpbn_horus:set_solver(bp).
|
||||||
|
%:- clpbn_horus:set_solver(cbp).
|
||||||
|
|
||||||
c(p1).
|
:- yap_flag(write_strings, off).
|
||||||
c(p2).
|
|
||||||
c(p3).
|
|
||||||
c(p4).
|
|
||||||
c(p5).
|
|
||||||
|
|
||||||
|
people @ 3.
|
||||||
|
|
||||||
markov attends(P)::[t,f] , attr1::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
markov attends(P)::[t,f], attr1::[t,f] ; [0.11, 0.2, 0.3, 0.4] ; [people(P)].
|
||||||
|
|
||||||
markov attends(P)::[t,f] , attr2::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
markov attends(P)::[t,f], attr2::[t,f] ; [0.1, 0.22, 0.3, 0.4] ; [people(P)].
|
||||||
|
|
||||||
markov attends(P)::[t,f] , attr3::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
markov attends(P)::[t,f], attr3::[t,f] ; [0.1, 0.2, 0.33, 0.4] ; [people(P)].
|
||||||
|
|
||||||
markov attends(P)::[t,f] , attr4::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
markov attends(P)::[t,f], attr4::[t,f] ; [0.1, 0.2, 0.3, 0.44] ; [people(P)].
|
||||||
|
|
||||||
markov attends(P)::[t,f] , attr5::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
markov attends(P)::[t,f], attr5::[t,f] ; [0.1, 0.2, 0.3, 0.45] ; [people(P)].
|
||||||
|
|
||||||
markov attends(P)::[t,f] , attr6::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P)].
|
markov attends(P)::[t,f], attr6::[t,f] ; [0.1, 0.2, 0.3, 0.46] ; [people(P)].
|
||||||
|
|
||||||
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P)].
|
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.87] ; [people(P)].
|
||||||
|
|
||||||
%:- clpbn_horus:set_horus_flag(use_logarithms,true).
|
% ?- series(X).
|
||||||
|
|
||||||
?- series(X).
|
|
||||||
|
|
||||||
|
@ -1,241 +0,0 @@
|
|||||||
Aladdin Free Public License
|
|
||||||
(Version 8, November 18, 1999)
|
|
||||||
|
|
||||||
Copyright (C) 1994, 1995, 1997, 1998, 1999 Aladdin Enterprises,
|
|
||||||
Menlo Park, California, U.S.A. All rights reserved.
|
|
||||||
|
|
||||||
*NOTE:* This License is not the same as any of the GNU Licenses
|
|
||||||
<http://www.gnu.org/copyleft/gpl.html> published by the Free
|
|
||||||
Software Foundation <http://www.gnu.org/>. Its terms are
|
|
||||||
substantially different from those of the GNU Licenses. If you are
|
|
||||||
familiar with the GNU Licenses, please read this license with extra
|
|
||||||
care.
|
|
||||||
|
|
||||||
Aladdin Enterprises hereby grants to anyone the permission to apply this
|
|
||||||
License to their own work, as long as the entire License (including the
|
|
||||||
above notices and this paragraph) is copied with no changes, additions,
|
|
||||||
or deletions except for changing the first paragraph of Section 0 to
|
|
||||||
include a suitable description of the work to which the license is being
|
|
||||||
applied and of the person or entity that holds the copyright in the
|
|
||||||
work, and, if the License is being applied to a work created in a
|
|
||||||
country other than the United States, replacing the first paragraph of
|
|
||||||
Section 6 with an appropriate reference to the laws of the appropriate
|
|
||||||
country.
|
|
||||||
|
|
||||||
|
|
||||||
0. Subject Matter
|
|
||||||
|
|
||||||
This License applies to the computer program known as "XMLParser library".
|
|
||||||
The "Program", below, refers to such program. The Program
|
|
||||||
is a copyrighted work whose copyright is held by Frank Vanden Berghen
|
|
||||||
(the "Licensor").
|
|
||||||
|
|
||||||
A "work based on the Program" means either the Program or any derivative
|
|
||||||
work of the Program, as defined in the United States Copyright Act of
|
|
||||||
1976, such as a translation or a modification.
|
|
||||||
|
|
||||||
* BY MODIFYING OR DISTRIBUTING THE PROGRAM (OR ANY WORK BASED ON THE
|
|
||||||
PROGRAM), YOU INDICATE YOUR ACCEPTANCE OF THIS LICENSE TO DO SO, AND ALL
|
|
||||||
ITS TERMS AND CONDITIONS FOR COPYING, DISTRIBUTING OR MODIFYING THE
|
|
||||||
PROGRAM OR WORKS BASED ON IT. NOTHING OTHER THAN THIS LICENSE GRANTS YOU
|
|
||||||
PERMISSION TO MODIFY OR DISTRIBUTE THE PROGRAM OR ITS DERIVATIVE WORKS.
|
|
||||||
THESE ACTIONS ARE PROHIBITED BY LAW. IF YOU DO NOT ACCEPT THESE TERMS
|
|
||||||
AND CONDITIONS, DO NOT MODIFY OR DISTRIBUTE THE PROGRAM. *
|
|
||||||
|
|
||||||
|
|
||||||
1. Licenses.
|
|
||||||
|
|
||||||
Licensor hereby grants you the following rights, provided that you
|
|
||||||
comply with all of the restrictions set forth in this License and
|
|
||||||
provided, further, that you distribute an unmodified copy of this
|
|
||||||
License with the Program:
|
|
||||||
|
|
||||||
(a)
|
|
||||||
You may copy and distribute literal (i.e., verbatim) copies of the
|
|
||||||
Program's source code as you receive it throughout the world, in any
|
|
||||||
medium.
|
|
||||||
(b)
|
|
||||||
You may modify the Program, create works based on the Program and
|
|
||||||
distribute copies of such throughout the world, in any medium.
|
|
||||||
|
|
||||||
|
|
||||||
2. Restrictions.
|
|
||||||
|
|
||||||
This license is subject to the following restrictions:
|
|
||||||
|
|
||||||
(a)
|
|
||||||
Distribution of the Program or any work based on the Program by a
|
|
||||||
commercial organization to any third party is prohibited if any
|
|
||||||
payment is made in connection with such distribution, whether
|
|
||||||
directly (as in payment for a copy of the Program) or indirectly (as
|
|
||||||
in payment for some service related to the Program, or payment for
|
|
||||||
some product or service that includes a copy of the Program "without
|
|
||||||
charge"; these are only examples, and not an exhaustive enumeration
|
|
||||||
of prohibited activities). The following methods of distribution
|
|
||||||
involving payment shall not in and of themselves be a violation of
|
|
||||||
this restriction:
|
|
||||||
|
|
||||||
(i)
|
|
||||||
Posting the Program on a public access information storage and
|
|
||||||
retrieval service for which a fee is received for retrieving
|
|
||||||
information (such as an on-line service), provided that the fee
|
|
||||||
is not content-dependent (i.e., the fee would be the same for
|
|
||||||
retrieving the same volume of information consisting of random
|
|
||||||
data) and that access to the service and to the Program is
|
|
||||||
available independent of any other product or service. An
|
|
||||||
example of a service that does not fall under this section is an
|
|
||||||
on-line service that is operated by a company and that is only
|
|
||||||
available to customers of that company. (This is not an
|
|
||||||
exhaustive enumeration.)
|
|
||||||
(ii)
|
|
||||||
Distributing the Program on removable computer-readable media,
|
|
||||||
provided that the files containing the Program are reproduced
|
|
||||||
entirely and verbatim on such media, that all information on
|
|
||||||
such media be redistributable for non-commercial purposes
|
|
||||||
without charge, and that such media are distributed by
|
|
||||||
themselves (except for accompanying documentation) independent
|
|
||||||
of any other product or service. Examples of such media include
|
|
||||||
CD-ROM, magnetic tape, and optical storage media. (This is not
|
|
||||||
intended to be an exhaustive list.) An example of a distribution
|
|
||||||
that does not fall under this section is a CD-ROM included in a
|
|
||||||
book or magazine. (This is not an exhaustive enumeration.)
|
|
||||||
|
|
||||||
(b)
|
|
||||||
Activities other than copying, distribution and modification of the
|
|
||||||
Program are not subject to this License and they are outside its
|
|
||||||
scope. Functional use (running) of the Program is not restricted,
|
|
||||||
and any output produced through the use of the Program is subject to
|
|
||||||
this license only if its contents constitute a work based on the
|
|
||||||
Program (independent of having been made by running the Program).
|
|
||||||
(c)
|
|
||||||
You must meet all of the following conditions with respect to any
|
|
||||||
work that you distribute or publish that in whole or in part
|
|
||||||
contains or is derived from the Program or any part thereof ("the
|
|
||||||
Work"):
|
|
||||||
|
|
||||||
(i)
|
|
||||||
If you have modified the Program, you must cause the Work to
|
|
||||||
carry prominent notices stating that you have modified the
|
|
||||||
Program's files and the date of any change. In each source file
|
|
||||||
that you have modified, you must include a prominent notice that
|
|
||||||
you have modified the file, including your name, your e-mail
|
|
||||||
address (if any), and the date and purpose of the change;
|
|
||||||
(ii)
|
|
||||||
You must cause the Work to be licensed as a whole and at no
|
|
||||||
charge to all third parties under the terms of this License;
|
|
||||||
(iii)
|
|
||||||
If the Work normally reads commands interactively when run, you
|
|
||||||
must cause it, at each time the Work commences operation, to
|
|
||||||
print or display an announcement including an appropriate
|
|
||||||
copyright notice and a notice that there is no warranty (or
|
|
||||||
else, saying that you provide a warranty). Such notice must also
|
|
||||||
state that users may redistribute the Work only under the
|
|
||||||
conditions of this License and tell the user how to view the
|
|
||||||
copy of this License included with the Work. (Exceptions: if the
|
|
||||||
Program is interactive but normally prints or displays such an
|
|
||||||
announcement only at the request of a user, such as in an "About
|
|
||||||
box", the Work is required to print or display the notice only
|
|
||||||
under the same circumstances; if the Program itself is
|
|
||||||
interactive but does not normally print such an announcement,
|
|
||||||
the Work is not required to print an announcement.);
|
|
||||||
(iv)
|
|
||||||
You must accompany the Work with the complete corresponding
|
|
||||||
machine-readable source code, delivered on a medium customarily
|
|
||||||
used for software interchange. The source code for a work means
|
|
||||||
the preferred form of the work for making modifications to it.
|
|
||||||
For an executable work, complete source code means all the
|
|
||||||
source code for all modules it contains, plus any associated
|
|
||||||
interface definition files, plus the scripts used to control
|
|
||||||
compilation and installation of the executable code. If you
|
|
||||||
distribute with the Work any component that is normally
|
|
||||||
distributed (in either source or binary form) with the major
|
|
||||||
components (compiler, kernel, and so on) of the operating system
|
|
||||||
on which the executable runs, you must also distribute the
|
|
||||||
source code of that component if you have it and are allowed to
|
|
||||||
do so;
|
|
||||||
(v)
|
|
||||||
If you distribute any written or printed material at all with
|
|
||||||
the Work, such material must include either a written copy of
|
|
||||||
this License, or a prominent written indication that the Work is
|
|
||||||
covered by this License and written instructions for printing
|
|
||||||
and/or displaying the copy of the License on the distribution
|
|
||||||
medium;
|
|
||||||
(vi)
|
|
||||||
You may not impose any further restrictions on the recipient's
|
|
||||||
exercise of the rights granted herein.
|
|
||||||
|
|
||||||
If distribution of executable or object code is made by offering the
|
|
||||||
equivalent ability to copy from a designated place, then offering
|
|
||||||
equivalent ability to copy the source code from the same place counts as
|
|
||||||
distribution of the source code, even though third parties are not
|
|
||||||
compelled to copy the source code along with the object code.
|
|
||||||
|
|
||||||
|
|
||||||
3. Reservation of Rights.
|
|
||||||
|
|
||||||
No rights are granted to the Program except as expressly set forth
|
|
||||||
herein. You may not copy, modify, sublicense, or distribute the Program
|
|
||||||
except as expressly provided under this License. Any attempt otherwise
|
|
||||||
to copy, modify, sublicense or distribute the Program is void, and will
|
|
||||||
automatically terminate your rights under this License. However, parties
|
|
||||||
who have received copies, or rights, from you under this License will
|
|
||||||
not have their licenses terminated so long as such parties remain in
|
|
||||||
full compliance.
|
|
||||||
|
|
||||||
|
|
||||||
4. Other Restrictions.
|
|
||||||
|
|
||||||
If the distribution and/or use of the Program is restricted in certain
|
|
||||||
countries for any reason, Licensor may add an explicit geographical
|
|
||||||
distribution limitation excluding those countries, so that distribution
|
|
||||||
is permitted only in or among countries not thus excluded. In such case,
|
|
||||||
this License incorporates the limitation as if written in the body of
|
|
||||||
this License.
|
|
||||||
|
|
||||||
|
|
||||||
5. Limitations.
|
|
||||||
|
|
||||||
* THE PROGRAM IS PROVIDED TO YOU "AS IS," WITHOUT WARRANTY. THERE IS NO
|
|
||||||
WARRANTY FOR THE PROGRAM, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT
|
|
||||||
NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
|
|
||||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. THE
|
|
||||||
ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH
|
|
||||||
YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL
|
|
||||||
NECESSARY SERVICING, REPAIR OR CORRECTION. *
|
|
||||||
|
|
||||||
* IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
|
||||||
WILL LICENSOR, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE
|
|
||||||
PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
|
||||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
|
||||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS
|
|
||||||
OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR
|
|
||||||
THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
|
|
||||||
PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
|
|
||||||
POSSIBILITY OF SUCH DAMAGES. *
|
|
||||||
|
|
||||||
|
|
||||||
6. General.
|
|
||||||
|
|
||||||
This License is governed by the laws of Belgium., excluding choice of
|
|
||||||
law rules.
|
|
||||||
|
|
||||||
If any part of this License is found to be in conflict with the law,
|
|
||||||
that part shall be interpreted in its broadest meaning consistent with
|
|
||||||
the law, and no other parts of the License shall be affected.
|
|
||||||
|
|
||||||
For United States Government users, the Program is provided with
|
|
||||||
*RESTRICTED RIGHTS*. If you are a unit or agency of the United States
|
|
||||||
Government or are acquiring the Program for any such unit or agency, the
|
|
||||||
following apply:
|
|
||||||
|
|
||||||
If the unit or agency is the Department of Defense ("DOD"), the
|
|
||||||
Program and its documentation are classified as "commercial computer
|
|
||||||
software" and "commercial computer software documentation"
|
|
||||||
respectively and, pursuant to DFAR Section 227.7202, the Government
|
|
||||||
is acquiring the Program and its documentation in accordance with
|
|
||||||
the terms of this License. If the unit or agency is other than DOD,
|
|
||||||
the Program and its documentation are classified as "commercial
|
|
||||||
computer software" and "commercial computer software documentation"
|
|
||||||
respectively and, pursuant to FAR Section 12.212, the Government is
|
|
||||||
acquiring the Program and its documentation in accordance with the
|
|
||||||
terms of this License.
|
|
File diff suppressed because it is too large
Load Diff
@ -1,734 +0,0 @@
|
|||||||
|
|
||||||
/****************************************************************************/
|
|
||||||
/*! \mainpage XMLParser library
|
|
||||||
* \section intro_sec Introduction
|
|
||||||
*
|
|
||||||
* This is a basic XML parser written in ANSI C++ for portability.
|
|
||||||
* It works by using recursion and a node tree for breaking
|
|
||||||
* down the elements of an XML document.
|
|
||||||
*
|
|
||||||
* @version V2.42
|
|
||||||
* @author Frank Vanden Berghen
|
|
||||||
*
|
|
||||||
* Copyright (c) 2002, Business-Insight
|
|
||||||
* <a href="http://www.Business-Insight.com">Business-Insight</a>
|
|
||||||
* All rights reserved.
|
|
||||||
* See the file <a href="../../AFPL-license.txt">AFPL-license.txt</a> about the licensing terms
|
|
||||||
*
|
|
||||||
* \section tutorial First Tutorial
|
|
||||||
* You can follow a simple <a href="../../xmlParser.html">Tutorial</a> to know the basics...
|
|
||||||
*
|
|
||||||
* \section usage General usage: How to include the XMLParser library inside your project.
|
|
||||||
*
|
|
||||||
* The library is composed of two files: <a href="../../xmlParser.cpp">xmlParser.cpp</a> and
|
|
||||||
* <a href="../../xmlParser.h">xmlParser.h</a>. These are the ONLY 2 files that you need when
|
|
||||||
* using the library inside your own projects.
|
|
||||||
*
|
|
||||||
* All the functions of the library are documented inside the comments of the file
|
|
||||||
* <a href="../../xmlParser.h">xmlParser.h</a>. These comments can be transformed in
|
|
||||||
* full-fledged HTML documentation using the DOXYGEN software: simply type: "doxygen doxy.cfg"
|
|
||||||
*
|
|
||||||
* By default, the XMLParser library uses (char*) for string representation.To use the (wchar_t*)
|
|
||||||
* version of the library, you need to define the "_UNICODE" preprocessor definition variable
|
|
||||||
* (this is usually done inside your project definition file) (This is done automatically for you
|
|
||||||
* when using Visual Studio).
|
|
||||||
*
|
|
||||||
* \section example Advanced Tutorial and Many Examples of usage.
|
|
||||||
*
|
|
||||||
* Some very small introductory examples are described inside the Tutorial file
|
|
||||||
* <a href="../../xmlParser.html">xmlParser.html</a>
|
|
||||||
*
|
|
||||||
* Some additional small examples are also inside the file <a href="../../xmlTest.cpp">xmlTest.cpp</a>
|
|
||||||
* (for the "char*" version of the library) and inside the file
|
|
||||||
* <a href="../../xmlTestUnicode.cpp">xmlTestUnicode.cpp</a> (for the "wchar_t*"
|
|
||||||
* version of the library). If you have a question, please review these additionnal examples
|
|
||||||
* before sending an e-mail to the author.
|
|
||||||
*
|
|
||||||
* To build the examples:
|
|
||||||
* - linux/unix: type "make"
|
|
||||||
* - solaris: type "make -f makefile.solaris"
|
|
||||||
* - windows: Visual Studio: double-click on xmlParser.dsw
|
|
||||||
* (under Visual Studio .NET, the .dsp and .dsw files will be automatically converted to .vcproj and .sln files)
|
|
||||||
*
|
|
||||||
* In order to build the examples you need some additional files:
|
|
||||||
* - linux/unix: makefile
|
|
||||||
* - solaris: makefile.solaris
|
|
||||||
* - windows: Visual Studio: *.dsp, xmlParser.dsw and also xmlParser.lib and xmlParser.dll
|
|
||||||
*
|
|
||||||
* \section debugging Debugging with the XMLParser library
|
|
||||||
*
|
|
||||||
* \subsection debugwin Debugging under WINDOWS
|
|
||||||
*
|
|
||||||
* Inside Visual C++, the "debug versions" of the memory allocation functions are
|
|
||||||
* very slow: Do not forget to compile in "release mode" to get maximum speed.
|
|
||||||
* When I had to debug a software that was using the XMLParser Library, it was usually
|
|
||||||
* a nightmare because the library was sooOOOoooo slow in debug mode (because of the
|
|
||||||
* slow memory allocations in Debug mode). To solve this
|
|
||||||
* problem, during all the debugging session, I am now using a very fast DLL version of the
|
|
||||||
* XMLParser Library (the DLL is compiled in release mode). Using the DLL version of
|
|
||||||
* the XMLParser Library allows me to have lightening XML parsing speed even in debug!
|
|
||||||
* Other than that, the DLL version is useless: In the release version of my tool,
|
|
||||||
* I always use the normal, ".cpp"-based, XMLParser Library (I simply include the
|
|
||||||
* <a href="../../xmlParser.cpp">xmlParser.cpp</a> and
|
|
||||||
* <a href="../../xmlParser.h">xmlParser.h</a> files into the project).
|
|
||||||
*
|
|
||||||
* The file <a href="../../XMLNodeAutoexp.txt">XMLNodeAutoexp.txt</a> contains some
|
|
||||||
* "tweaks" that improve substancially the display of the content of the XMLNode objects
|
|
||||||
* inside the Visual Studio Debugger. Believe me, once you have seen inside the debugger
|
|
||||||
* the "smooth" display of the XMLNode objects, you cannot live without it anymore!
|
|
||||||
*
|
|
||||||
* \subsection debuglinux Debugging under LINUX/UNIX
|
|
||||||
*
|
|
||||||
* The speed of the debug version of the XMLParser library is tolerable so no extra
|
|
||||||
* work.has been done.
|
|
||||||
*
|
|
||||||
****************************************************************************/
|
|
||||||
|
|
||||||
#ifndef __INCLUDE_XML_NODE__
|
|
||||||
#define __INCLUDE_XML_NODE__
|
|
||||||
|
|
||||||
#include <stdlib.h>
|
|
||||||
|
|
||||||
#ifdef _UNICODE
|
|
||||||
// If you comment the next "define" line then the library will never "switch to" _UNICODE (wchar_t*) mode (16/32 bits per characters).
|
|
||||||
// This is useful when you get error messages like:
|
|
||||||
// 'XMLNode::openFileHelper' : cannot convert parameter 2 from 'const char [5]' to 'const wchar_t *'
|
|
||||||
// The _XMLWIDECHAR preprocessor variable force the XMLParser library into either utf16/32-mode (the proprocessor variable
|
|
||||||
// must be defined) or utf8-mode(the pre-processor variable must be undefined).
|
|
||||||
#define _XMLWIDECHAR
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(WIN32) || defined(UNDER_CE) || defined(_WIN32) || defined(WIN64) || defined(__BORLANDC__)
|
|
||||||
// comment the next line if you are under windows and the compiler is not Microsoft Visual Studio (6.0 or .NET) or Borland
|
|
||||||
#define _XMLWINDOWS
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef XMLDLLENTRY
|
|
||||||
#undef XMLDLLENTRY
|
|
||||||
#endif
|
|
||||||
#ifdef _USE_XMLPARSER_DLL
|
|
||||||
#ifdef _DLL_EXPORTS_
|
|
||||||
#define XMLDLLENTRY __declspec(dllexport)
|
|
||||||
#else
|
|
||||||
#define XMLDLLENTRY __declspec(dllimport)
|
|
||||||
#endif
|
|
||||||
#else
|
|
||||||
#define XMLDLLENTRY
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// uncomment the next line if you want no support for wchar_t* (no need for the <wchar.h> or <tchar.h> libraries anymore to compile)
|
|
||||||
//#define XML_NO_WIDE_CHAR
|
|
||||||
|
|
||||||
#ifdef XML_NO_WIDE_CHAR
|
|
||||||
#undef _XMLWINDOWS
|
|
||||||
#undef _XMLWIDECHAR
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef _XMLWINDOWS
|
|
||||||
#include <tchar.h>
|
|
||||||
#else
|
|
||||||
#define XMLDLLENTRY
|
|
||||||
#ifndef XML_NO_WIDE_CHAR
|
|
||||||
#include <wchar.h> // to have 'wcsrtombs' for ANSI version
|
|
||||||
// to have 'mbsrtowcs' for WIDECHAR version
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Some common types for char set portable code
|
|
||||||
#ifdef _XMLWIDECHAR
|
|
||||||
#define _CXML(c) L ## c
|
|
||||||
#define XMLCSTR const wchar_t *
|
|
||||||
#define XMLSTR wchar_t *
|
|
||||||
#define XMLCHAR wchar_t
|
|
||||||
#else
|
|
||||||
#define _CXML(c) c
|
|
||||||
#define XMLCSTR const char *
|
|
||||||
#define XMLSTR char *
|
|
||||||
#define XMLCHAR char
|
|
||||||
#endif
|
|
||||||
#ifndef FALSE
|
|
||||||
#define FALSE 0
|
|
||||||
#endif /* FALSE */
|
|
||||||
#ifndef TRUE
|
|
||||||
#define TRUE 1
|
|
||||||
#endif /* TRUE */
|
|
||||||
|
|
||||||
|
|
||||||
/// Enumeration for XML parse errors.
|
|
||||||
typedef enum XMLError
|
|
||||||
{
|
|
||||||
eXMLErrorNone = 0,
|
|
||||||
eXMLErrorMissingEndTag,
|
|
||||||
eXMLErrorNoXMLTagFound,
|
|
||||||
eXMLErrorEmpty,
|
|
||||||
eXMLErrorMissingTagName,
|
|
||||||
eXMLErrorMissingEndTagName,
|
|
||||||
eXMLErrorUnmatchedEndTag,
|
|
||||||
eXMLErrorUnmatchedEndClearTag,
|
|
||||||
eXMLErrorUnexpectedToken,
|
|
||||||
eXMLErrorNoElements,
|
|
||||||
eXMLErrorFileNotFound,
|
|
||||||
eXMLErrorFirstTagNotFound,
|
|
||||||
eXMLErrorUnknownCharacterEntity,
|
|
||||||
eXMLErrorCharacterCodeAbove255,
|
|
||||||
eXMLErrorCharConversionError,
|
|
||||||
eXMLErrorCannotOpenWriteFile,
|
|
||||||
eXMLErrorCannotWriteFile,
|
|
||||||
|
|
||||||
eXMLErrorBase64DataSizeIsNotMultipleOf4,
|
|
||||||
eXMLErrorBase64DecodeIllegalCharacter,
|
|
||||||
eXMLErrorBase64DecodeTruncatedData,
|
|
||||||
eXMLErrorBase64DecodeBufferTooSmall
|
|
||||||
} XMLError;
|
|
||||||
|
|
||||||
|
|
||||||
/// Enumeration used to manage type of data. Use in conjunction with structure XMLNodeContents
|
|
||||||
typedef enum XMLElementType
|
|
||||||
{
|
|
||||||
eNodeChild=0,
|
|
||||||
eNodeAttribute=1,
|
|
||||||
eNodeText=2,
|
|
||||||
eNodeClear=3,
|
|
||||||
eNodeNULL=4
|
|
||||||
} XMLElementType;
|
|
||||||
|
|
||||||
/// Structure used to obtain error details if the parse fails.
|
|
||||||
typedef struct XMLResults
|
|
||||||
{
|
|
||||||
enum XMLError error;
|
|
||||||
int nLine,nColumn;
|
|
||||||
} XMLResults;
|
|
||||||
|
|
||||||
/// Structure for XML clear (unformatted) node (usually comments)
|
|
||||||
typedef struct XMLClear {
|
|
||||||
XMLCSTR lpszValue; XMLCSTR lpszOpenTag; XMLCSTR lpszCloseTag;
|
|
||||||
} XMLClear;
|
|
||||||
|
|
||||||
/// Structure for XML attribute.
|
|
||||||
typedef struct XMLAttribute {
|
|
||||||
XMLCSTR lpszName; XMLCSTR lpszValue;
|
|
||||||
} XMLAttribute;
|
|
||||||
|
|
||||||
/// XMLElementPosition are not interchangeable with simple indexes
|
|
||||||
typedef int XMLElementPosition;
|
|
||||||
|
|
||||||
struct XMLNodeContents;
|
|
||||||
|
|
||||||
/** @defgroup XMLParserGeneral The XML parser */
|
|
||||||
|
|
||||||
/// Main Class representing a XML node
|
|
||||||
/**
|
|
||||||
* All operations are performed using this class.
|
|
||||||
* \note The constructors of the XMLNode class are protected, so use instead one of these four methods to get your first instance of XMLNode:
|
|
||||||
* <ul>
|
|
||||||
* <li> XMLNode::parseString </li>
|
|
||||||
* <li> XMLNode::parseFile </li>
|
|
||||||
* <li> XMLNode::openFileHelper </li>
|
|
||||||
* <li> XMLNode::createXMLTopNode (or XMLNode::createXMLTopNode_WOSD)</li>
|
|
||||||
* </ul> */
|
|
||||||
typedef struct XMLDLLENTRY XMLNode
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
|
|
||||||
struct XMLNodeDataTag;
|
|
||||||
|
|
||||||
/// Constructors are protected, so use instead one of: XMLNode::parseString, XMLNode::parseFile, XMLNode::openFileHelper, XMLNode::createXMLTopNode
|
|
||||||
XMLNode(struct XMLNodeDataTag *pParent, XMLSTR lpszName, char isDeclaration);
|
|
||||||
/// Constructors are protected, so use instead one of: XMLNode::parseString, XMLNode::parseFile, XMLNode::openFileHelper, XMLNode::createXMLTopNode
|
|
||||||
XMLNode(struct XMLNodeDataTag *p);
|
|
||||||
|
|
||||||
public:
|
|
||||||
static XMLCSTR getVersion();///< Return the XMLParser library version number
|
|
||||||
|
|
||||||
/** @defgroup conversions Parsing XML files/strings to an XMLNode structure and Rendering XMLNode's to files/string.
|
|
||||||
* @ingroup XMLParserGeneral
|
|
||||||
* @{ */
|
|
||||||
|
|
||||||
/// Parse an XML string and return the root of a XMLNode tree representing the string.
|
|
||||||
static XMLNode parseString (XMLCSTR lpXMLString, XMLCSTR tag=NULL, XMLResults *pResults=NULL);
|
|
||||||
/**< The "parseString" function parse an XML string and return the root of a XMLNode tree. The "opposite" of this function is
|
|
||||||
* the function "createXMLString" that re-creates an XML string from an XMLNode tree. If the XML document is corrupted, the
|
|
||||||
* "parseString" method will initialize the "pResults" variable with some information that can be used to trace the error.
|
|
||||||
* If you still want to parse the file, you can use the APPROXIMATE_PARSING option as explained inside the note at the
|
|
||||||
* beginning of the "xmlParser.cpp" file.
|
|
||||||
*
|
|
||||||
* @param lpXMLString the XML string to parse
|
|
||||||
* @param tag the name of the first tag inside the XML file. If the tag parameter is omitted, this function returns a node that represents the head of the xml document including the declaration term (<? ... ?>).
|
|
||||||
* @param pResults a pointer to a XMLResults variable that will contain some information that can be used to trace the XML parsing error. You can have a user-friendly explanation of the parsing error with the "getError" function.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/// Parse an XML file and return the root of a XMLNode tree representing the file.
|
|
||||||
static XMLNode parseFile (XMLCSTR filename, XMLCSTR tag=NULL, XMLResults *pResults=NULL);
|
|
||||||
/**< The "parseFile" function parse an XML file and return the root of a XMLNode tree. The "opposite" of this function is
|
|
||||||
* the function "writeToFile" that re-creates an XML file from an XMLNode tree. If the XML document is corrupted, the
|
|
||||||
* "parseFile" method will initialize the "pResults" variable with some information that can be used to trace the error.
|
|
||||||
* If you still want to parse the file, you can use the APPROXIMATE_PARSING option as explained inside the note at the
|
|
||||||
* beginning of the "xmlParser.cpp" file.
|
|
||||||
*
|
|
||||||
* @param filename the path to the XML file to parse
|
|
||||||
* @param tag the name of the first tag inside the XML file. If the tag parameter is omitted, this function returns a node that represents the head of the xml document including the declaration term (<? ... ?>).
|
|
||||||
* @param pResults a pointer to a XMLResults variable that will contain some information that can be used to trace the XML parsing error. You can have a user-friendly explanation of the parsing error with the "getError" function.
|
|
||||||
*/
|
|
||||||
|
|
||||||
/// Parse an XML file and return the root of a XMLNode tree representing the file. A very crude error checking is made. An attempt to guess the Char Encoding used in the file is made.
|
|
||||||
static XMLNode openFileHelper(XMLCSTR filename, XMLCSTR tag=NULL);
|
|
||||||
/**< The "openFileHelper" function reports to the screen all the warnings and errors that occurred during parsing of the XML file.
|
|
||||||
* This function also tries to guess char Encoding (UTF-8, ASCII or SHIT-JIS) based on the first 200 bytes of the file. Since each
|
|
||||||
* application has its own way to report and deal with errors, you should rather use the "parseFile" function to parse XML files
|
|
||||||
* and program yourself thereafter an "error reporting" tailored for your needs (instead of using the very crude "error reporting"
|
|
||||||
* mechanism included inside the "openFileHelper" function).
|
|
||||||
*
|
|
||||||
* If the XML document is corrupted, the "openFileHelper" method will:
|
|
||||||
* - display an error message on the console (or inside a messageBox for windows).
|
|
||||||
* - stop execution (exit).
|
|
||||||
*
|
|
||||||
* I strongly suggest that you write your own "openFileHelper" method tailored to your needs. If you still want to parse
|
|
||||||
* the file, you can use the APPROXIMATE_PARSING option as explained inside the note at the beginning of the "xmlParser.cpp" file.
|
|
||||||
*
|
|
||||||
* @param filename the path of the XML file to parse.
|
|
||||||
* @param tag the name of the first tag inside the XML file. If the tag parameter is omitted, this function returns a node that represents the head of the xml document including the declaration term (<? ... ?>).
|
|
||||||
*/
|
|
||||||
|
|
||||||
static XMLCSTR getError(XMLError error); ///< this gives you a user-friendly explanation of the parsing error
|
|
||||||
|
|
||||||
/// Create an XML string starting from the current XMLNode.
|
|
||||||
XMLSTR createXMLString(int nFormat=1, int *pnSize=NULL) const;
|
|
||||||
/**< The returned string should be free'd using the "freeXMLString" function.
|
|
||||||
*
|
|
||||||
* If nFormat==0, no formatting is required otherwise this returns an user friendly XML string from a given element
|
|
||||||
* with appropriate white spaces and carriage returns. if pnSize is given it returns the size in character of the string. */
|
|
||||||
|
|
||||||
/// Save the content of an xmlNode inside a file
|
|
||||||
XMLError writeToFile(XMLCSTR filename,
|
|
||||||
const char *encoding=NULL,
|
|
||||||
char nFormat=1) const;
|
|
||||||
/**< If nFormat==0, no formatting is required otherwise this returns an user friendly XML string from a given element with appropriate white spaces and carriage returns.
|
|
||||||
* If the global parameter "characterEncoding==encoding_UTF8", then the "encoding" parameter is ignored and always set to "utf-8".
|
|
||||||
* If the global parameter "characterEncoding==encoding_ShiftJIS", then the "encoding" parameter is ignored and always set to "SHIFT-JIS".
|
|
||||||
* If "_XMLWIDECHAR=1", then the "encoding" parameter is ignored and always set to "utf-16".
|
|
||||||
* If no "encoding" parameter is given the "ISO-8859-1" encoding is used. */
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
/** @defgroup navigate Navigate the XMLNode structure
|
|
||||||
* @ingroup XMLParserGeneral
|
|
||||||
* @{ */
|
|
||||||
XMLCSTR getName() const; ///< name of the node
|
|
||||||
XMLCSTR getText(int i=0) const; ///< return ith text field
|
|
||||||
int nText() const; ///< nbr of text field
|
|
||||||
XMLNode getParentNode() const; ///< return the parent node
|
|
||||||
XMLNode getChildNode(int i=0) const; ///< return ith child node
|
|
||||||
XMLNode getChildNode(XMLCSTR name, int i) const; ///< return ith child node with specific name (return an empty node if failing). If i==-1, this returns the last XMLNode with the given name.
|
|
||||||
XMLNode getChildNode(XMLCSTR name, int *i=NULL) const; ///< return next child node with specific name (return an empty node if failing)
|
|
||||||
XMLNode getChildNodeWithAttribute(XMLCSTR tagName,
|
|
||||||
XMLCSTR attributeName,
|
|
||||||
XMLCSTR attributeValue=NULL,
|
|
||||||
int *i=NULL) const; ///< return child node with specific name/attribute (return an empty node if failing)
|
|
||||||
XMLNode getChildNodeByPath(XMLCSTR path, char createNodeIfMissing=0, XMLCHAR sep='/');
|
|
||||||
///< return the first child node with specific path
|
|
||||||
XMLNode getChildNodeByPathNonConst(XMLSTR path, char createNodeIfMissing=0, XMLCHAR sep='/');
|
|
||||||
///< return the first child node with specific path.
|
|
||||||
|
|
||||||
int nChildNode(XMLCSTR name) const; ///< return the number of child node with specific name
|
|
||||||
int nChildNode() const; ///< nbr of child node
|
|
||||||
XMLAttribute getAttribute(int i=0) const; ///< return ith attribute
|
|
||||||
XMLCSTR getAttributeName(int i=0) const; ///< return ith attribute name
|
|
||||||
XMLCSTR getAttributeValue(int i=0) const; ///< return ith attribute value
|
|
||||||
char isAttributeSet(XMLCSTR name) const; ///< test if an attribute with a specific name is given
|
|
||||||
XMLCSTR getAttribute(XMLCSTR name, int i) const; ///< return ith attribute content with specific name (return a NULL if failing)
|
|
||||||
XMLCSTR getAttribute(XMLCSTR name, int *i=NULL) const; ///< return next attribute content with specific name (return a NULL if failing)
|
|
||||||
int nAttribute() const; ///< nbr of attribute
|
|
||||||
XMLClear getClear(int i=0) const; ///< return ith clear field (comments)
|
|
||||||
int nClear() const; ///< nbr of clear field
|
|
||||||
XMLNodeContents enumContents(XMLElementPosition i) const; ///< enumerate all the different contents (attribute,child,text, clear) of the current XMLNode. The order is reflecting the order of the original file/string. NOTE: 0 <= i < nElement();
|
|
||||||
int nElement() const; ///< nbr of different contents for current node
|
|
||||||
char isEmpty() const; ///< is this node Empty?
|
|
||||||
char isDeclaration() const; ///< is this node a declaration <? .... ?>
|
|
||||||
XMLNode deepCopy() const; ///< deep copy (duplicate/clone) a XMLNode
|
|
||||||
static XMLNode emptyNode(); ///< return XMLNode::emptyXMLNode;
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
~XMLNode();
|
|
||||||
XMLNode(const XMLNode &A); ///< to allow shallow/fast copy:
|
|
||||||
XMLNode& operator=( const XMLNode& A ); ///< to allow shallow/fast copy:
|
|
||||||
|
|
||||||
XMLNode(): d(NULL){};
|
|
||||||
static XMLNode emptyXMLNode;
|
|
||||||
static XMLClear emptyXMLClear;
|
|
||||||
static XMLAttribute emptyXMLAttribute;
|
|
||||||
|
|
||||||
/** @defgroup xmlModify Create or Update the XMLNode structure
|
|
||||||
* @ingroup XMLParserGeneral
|
|
||||||
* The functions in this group allows you to create from scratch (or update) a XMLNode structure. Start by creating your top
|
|
||||||
* node with the "createXMLTopNode" function and then add new nodes with the "addChild" function. The parameter 'pos' gives
|
|
||||||
* the position where the childNode, the text or the XMLClearTag will be inserted. The default value (pos=-1) inserts at the
|
|
||||||
* end. The value (pos=0) insert at the beginning (Insertion at the beginning is slower than at the end). <br>
|
|
||||||
*
|
|
||||||
* REMARK: 0 <= pos < nChild()+nText()+nClear() <br>
|
|
||||||
*/
|
|
||||||
|
|
||||||
/** @defgroup creation Creating from scratch a XMLNode structure
|
|
||||||
* @ingroup xmlModify
|
|
||||||
* @{ */
|
|
||||||
static XMLNode createXMLTopNode(XMLCSTR lpszName, char isDeclaration=FALSE); ///< Create the top node of an XMLNode structure
|
|
||||||
XMLNode addChild(XMLCSTR lpszName, char isDeclaration=FALSE, XMLElementPosition pos=-1); ///< Add a new child node
|
|
||||||
XMLNode addChild(XMLNode nodeToAdd, XMLElementPosition pos=-1); ///< If the "nodeToAdd" has some parents, it will be detached from it's parents before being attached to the current XMLNode
|
|
||||||
XMLAttribute *addAttribute(XMLCSTR lpszName, XMLCSTR lpszValuev); ///< Add a new attribute
|
|
||||||
XMLCSTR addText(XMLCSTR lpszValue, XMLElementPosition pos=-1); ///< Add a new text content
|
|
||||||
XMLClear *addClear(XMLCSTR lpszValue, XMLCSTR lpszOpen=NULL, XMLCSTR lpszClose=NULL, XMLElementPosition pos=-1);
|
|
||||||
/**< Add a new clear tag
|
|
||||||
* @param lpszOpen default value "<![CDATA["
|
|
||||||
* @param lpszClose default value "]]>"
|
|
||||||
*/
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
/** @defgroup xmlUpdate Updating Nodes
|
|
||||||
* @ingroup xmlModify
|
|
||||||
* Some update functions:
|
|
||||||
* @{
|
|
||||||
*/
|
|
||||||
XMLCSTR updateName(XMLCSTR lpszName); ///< change node's name
|
|
||||||
XMLAttribute *updateAttribute(XMLAttribute *newAttribute, XMLAttribute *oldAttribute); ///< if the attribute to update is missing, a new one will be added
|
|
||||||
XMLAttribute *updateAttribute(XMLCSTR lpszNewValue, XMLCSTR lpszNewName=NULL,int i=0); ///< if the attribute to update is missing, a new one will be added
|
|
||||||
XMLAttribute *updateAttribute(XMLCSTR lpszNewValue, XMLCSTR lpszNewName,XMLCSTR lpszOldName);///< set lpszNewName=NULL if you don't want to change the name of the attribute if the attribute to update is missing, a new one will be added
|
|
||||||
XMLCSTR updateText(XMLCSTR lpszNewValue, int i=0); ///< if the text to update is missing, a new one will be added
|
|
||||||
XMLCSTR updateText(XMLCSTR lpszNewValue, XMLCSTR lpszOldValue); ///< if the text to update is missing, a new one will be added
|
|
||||||
XMLClear *updateClear(XMLCSTR lpszNewContent, int i=0); ///< if the clearTag to update is missing, a new one will be added
|
|
||||||
XMLClear *updateClear(XMLClear *newP,XMLClear *oldP); ///< if the clearTag to update is missing, a new one will be added
|
|
||||||
XMLClear *updateClear(XMLCSTR lpszNewValue, XMLCSTR lpszOldValue); ///< if the clearTag to update is missing, a new one will be added
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
/** @defgroup xmlDelete Deleting Nodes or Attributes
|
|
||||||
* @ingroup xmlModify
|
|
||||||
* Some deletion functions:
|
|
||||||
* @{
|
|
||||||
*/
|
|
||||||
/// The "deleteNodeContent" function forces the deletion of the content of this XMLNode and the subtree.
|
|
||||||
void deleteNodeContent();
|
|
||||||
/**< \note The XMLNode instances that are referring to the part of the subtree that has been deleted CANNOT be used anymore!!. Unexpected results will occur if you continue using them. */
|
|
||||||
void deleteAttribute(int i=0); ///< Delete the ith attribute of the current XMLNode
|
|
||||||
void deleteAttribute(XMLCSTR lpszName); ///< Delete the attribute with the given name (the "strcmp" function is used to find the right attribute)
|
|
||||||
void deleteAttribute(XMLAttribute *anAttribute); ///< Delete the attribute with the name "anAttribute->lpszName" (the "strcmp" function is used to find the right attribute)
|
|
||||||
void deleteText(int i=0); ///< Delete the Ith text content of the current XMLNode
|
|
||||||
void deleteText(XMLCSTR lpszValue); ///< Delete the text content "lpszValue" inside the current XMLNode (direct "pointer-to-pointer" comparison is used to find the right text)
|
|
||||||
void deleteClear(int i=0); ///< Delete the Ith clear tag inside the current XMLNode
|
|
||||||
void deleteClear(XMLCSTR lpszValue); ///< Delete the clear tag "lpszValue" inside the current XMLNode (direct "pointer-to-pointer" comparison is used to find the clear tag)
|
|
||||||
void deleteClear(XMLClear *p); ///< Delete the clear tag "p" inside the current XMLNode (direct "pointer-to-pointer" comparison on the lpszName of the clear tag is used to find the clear tag)
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
/** @defgroup xmlWOSD ???_WOSD functions.
|
|
||||||
* @ingroup xmlModify
|
|
||||||
* The strings given as parameters for the "add" and "update" methods that have a name with
|
|
||||||
* the postfix "_WOSD" (that means "WithOut String Duplication")(for example "addText_WOSD")
|
|
||||||
* will be free'd by the XMLNode class. For example, it means that this is incorrect:
|
|
||||||
* \code
|
|
||||||
* xNode.addText_WOSD("foo");
|
|
||||||
* xNode.updateAttribute_WOSD("#newcolor" ,NULL,"color");
|
|
||||||
* \endcode
|
|
||||||
* In opposition, this is correct:
|
|
||||||
* \code
|
|
||||||
* xNode.addText("foo");
|
|
||||||
* xNode.addText_WOSD(stringDup("foo"));
|
|
||||||
* xNode.updateAttribute("#newcolor" ,NULL,"color");
|
|
||||||
* xNode.updateAttribute_WOSD(stringDup("#newcolor"),NULL,"color");
|
|
||||||
* \endcode
|
|
||||||
* Typically, you will never do:
|
|
||||||
* \code
|
|
||||||
* char *b=(char*)malloc(...);
|
|
||||||
* xNode.addText(b);
|
|
||||||
* free(b);
|
|
||||||
* \endcode
|
|
||||||
* ... but rather:
|
|
||||||
* \code
|
|
||||||
* char *b=(char*)malloc(...);
|
|
||||||
* xNode.addText_WOSD(b);
|
|
||||||
* \endcode
|
|
||||||
* ('free(b)' is performed by the XMLNode class)
|
|
||||||
* @{ */
|
|
||||||
static XMLNode createXMLTopNode_WOSD(XMLSTR lpszName, char isDeclaration=FALSE); ///< Create the top node of an XMLNode structure
|
|
||||||
XMLNode addChild_WOSD(XMLSTR lpszName, char isDeclaration=FALSE, XMLElementPosition pos=-1); ///< Add a new child node
|
|
||||||
XMLAttribute *addAttribute_WOSD(XMLSTR lpszName, XMLSTR lpszValue); ///< Add a new attribute
|
|
||||||
XMLCSTR addText_WOSD(XMLSTR lpszValue, XMLElementPosition pos=-1); ///< Add a new text content
|
|
||||||
XMLClear *addClear_WOSD(XMLSTR lpszValue, XMLCSTR lpszOpen=NULL, XMLCSTR lpszClose=NULL, XMLElementPosition pos=-1); ///< Add a new clear Tag
|
|
||||||
|
|
||||||
XMLCSTR updateName_WOSD(XMLSTR lpszName); ///< change node's name
|
|
||||||
XMLAttribute *updateAttribute_WOSD(XMLAttribute *newAttribute, XMLAttribute *oldAttribute); ///< if the attribute to update is missing, a new one will be added
|
|
||||||
XMLAttribute *updateAttribute_WOSD(XMLSTR lpszNewValue, XMLSTR lpszNewName=NULL,int i=0); ///< if the attribute to update is missing, a new one will be added
|
|
||||||
XMLAttribute *updateAttribute_WOSD(XMLSTR lpszNewValue, XMLSTR lpszNewName,XMLCSTR lpszOldName); ///< set lpszNewName=NULL if you don't want to change the name of the attribute if the attribute to update is missing, a new one will be added
|
|
||||||
XMLCSTR updateText_WOSD(XMLSTR lpszNewValue, int i=0); ///< if the text to update is missing, a new one will be added
|
|
||||||
XMLCSTR updateText_WOSD(XMLSTR lpszNewValue, XMLCSTR lpszOldValue); ///< if the text to update is missing, a new one will be added
|
|
||||||
XMLClear *updateClear_WOSD(XMLSTR lpszNewContent, int i=0); ///< if the clearTag to update is missing, a new one will be added
|
|
||||||
XMLClear *updateClear_WOSD(XMLClear *newP,XMLClear *oldP); ///< if the clearTag to update is missing, a new one will be added
|
|
||||||
XMLClear *updateClear_WOSD(XMLSTR lpszNewValue, XMLCSTR lpszOldValue); ///< if the clearTag to update is missing, a new one will be added
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
/** @defgroup xmlPosition Position helper functions (use in conjunction with the update&add functions
|
|
||||||
* @ingroup xmlModify
|
|
||||||
* These are some useful functions when you want to insert a childNode, a text or a XMLClearTag in the
|
|
||||||
* middle (at a specified position) of a XMLNode tree already constructed. The value returned by these
|
|
||||||
* methods is to be used as last parameter (parameter 'pos') of addChild, addText or addClear.
|
|
||||||
* @{ */
|
|
||||||
XMLElementPosition positionOfText(int i=0) const;
|
|
||||||
XMLElementPosition positionOfText(XMLCSTR lpszValue) const;
|
|
||||||
XMLElementPosition positionOfClear(int i=0) const;
|
|
||||||
XMLElementPosition positionOfClear(XMLCSTR lpszValue) const;
|
|
||||||
XMLElementPosition positionOfClear(XMLClear *a) const;
|
|
||||||
XMLElementPosition positionOfChildNode(int i=0) const;
|
|
||||||
XMLElementPosition positionOfChildNode(XMLNode x) const;
|
|
||||||
XMLElementPosition positionOfChildNode(XMLCSTR name, int i=0) const; ///< return the position of the ith childNode with the specified name if (name==NULL) return the position of the ith childNode
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
/// Enumeration for XML character encoding.
|
|
||||||
typedef enum XMLCharEncoding
|
|
||||||
{
|
|
||||||
char_encoding_error=0,
|
|
||||||
char_encoding_UTF8=1,
|
|
||||||
char_encoding_legacy=2,
|
|
||||||
char_encoding_ShiftJIS=3,
|
|
||||||
char_encoding_GB2312=4,
|
|
||||||
char_encoding_Big5=5,
|
|
||||||
char_encoding_GBK=6 // this is actually the same as Big5
|
|
||||||
} XMLCharEncoding;
|
|
||||||
|
|
||||||
/** \addtogroup conversions
|
|
||||||
* @{ */
|
|
||||||
|
|
||||||
/// Sets the global options for the conversions
|
|
||||||
static char setGlobalOptions(XMLCharEncoding characterEncoding=XMLNode::char_encoding_UTF8, char guessWideCharChars=1,
|
|
||||||
char dropWhiteSpace=1, char removeCommentsInMiddleOfText=1);
|
|
||||||
/**< The "setGlobalOptions" function allows you to change four global parameters that affect string & file
|
|
||||||
* parsing. First of all, you most-probably will never have to change these 3 global parameters.
|
|
||||||
*
|
|
||||||
* @param guessWideCharChars If "guessWideCharChars"=1 and if this library is compiled in WideChar mode, then the
|
|
||||||
* XMLNode::parseFile and XMLNode::openFileHelper functions will test if the file contains ASCII
|
|
||||||
* characters. If this is the case, then the file will be loaded and converted in memory to
|
|
||||||
* WideChar before being parsed. If 0, no conversion will be performed.
|
|
||||||
*
|
|
||||||
* @param guessWideCharChars If "guessWideCharChars"=1 and if this library is compiled in ASCII/UTF8/char* mode, then the
|
|
||||||
* XMLNode::parseFile and XMLNode::openFileHelper functions will test if the file contains WideChar
|
|
||||||
* characters. If this is the case, then the file will be loaded and converted in memory to
|
|
||||||
* ASCII/UTF8/char* before being parsed. If 0, no conversion will be performed.
|
|
||||||
*
|
|
||||||
* @param characterEncoding This parameter is only meaningful when compiling in char* mode (multibyte character mode).
|
|
||||||
* In wchar_t* (wide char mode), this parameter is ignored. This parameter should be one of the
|
|
||||||
* three currently recognized encodings: XMLNode::encoding_UTF8, XMLNode::encoding_ascii,
|
|
||||||
* XMLNode::encoding_ShiftJIS.
|
|
||||||
*
|
|
||||||
* @param dropWhiteSpace In most situations, text fields containing only white spaces (and carriage returns)
|
|
||||||
* are useless. Even more, these "empty" text fields are annoying because they increase the
|
|
||||||
* complexity of the user's code for parsing. So, 99% of the time, it's better to drop
|
|
||||||
* the "empty" text fields. However The XML specification indicates that no white spaces
|
|
||||||
* should be lost when parsing the file. So to be perfectly XML-compliant, you should set
|
|
||||||
* dropWhiteSpace=0. A note of caution: if you set "dropWhiteSpace=0", the parser will be
|
|
||||||
* slower and your code will be more complex.
|
|
||||||
*
|
|
||||||
* @param removeCommentsInMiddleOfText To explain this parameter, let's consider this code:
|
|
||||||
* \code
|
|
||||||
* XMLNode x=XMLNode::parseString("<a>foo<!-- hello -->bar<!DOCTYPE world >chu</a>","a");
|
|
||||||
* \endcode
|
|
||||||
* If removeCommentsInMiddleOfText=0, then we will have:
|
|
||||||
* \code
|
|
||||||
* x.getText(0) -> "foo"
|
|
||||||
* x.getText(1) -> "bar"
|
|
||||||
* x.getText(2) -> "chu"
|
|
||||||
* x.getClear(0) --> "<!-- hello -->"
|
|
||||||
* x.getClear(1) --> "<!DOCTYPE world >"
|
|
||||||
* \endcode
|
|
||||||
* If removeCommentsInMiddleOfText=1, then we will have:
|
|
||||||
* \code
|
|
||||||
* x.getText(0) -> "foobar"
|
|
||||||
* x.getText(1) -> "chu"
|
|
||||||
* x.getClear(0) --> "<!DOCTYPE world >"
|
|
||||||
* \endcode
|
|
||||||
*
|
|
||||||
* \return "0" when there are no errors. If you try to set an unrecognized encoding then the return value will be "1" to signal an error.
|
|
||||||
*
|
|
||||||
* \note Sometime, it's useful to set "guessWideCharChars=0" to disable any conversion
|
|
||||||
* because the test to detect the file-type (ASCII/UTF8/char* or WideChar) may fail (rarely). */
|
|
||||||
|
|
||||||
/// Guess the character encoding of the string (ascii, utf8 or shift-JIS)
|
|
||||||
static XMLCharEncoding guessCharEncoding(void *buffer, int bufLen, char useXMLEncodingAttribute=1);
|
|
||||||
/**< The "guessCharEncoding" function try to guess the character encoding. You most-probably will never
|
|
||||||
* have to use this function. It then returns the appropriate value of the global parameter
|
|
||||||
* "characterEncoding" described in the XMLNode::setGlobalOptions. The guess is based on the content of a buffer of length
|
|
||||||
* "bufLen" bytes that contains the first bytes (minimum 25 bytes; 200 bytes is a good value) of the
|
|
||||||
* file to be parsed. The XMLNode::openFileHelper function is using this function to automatically compute
|
|
||||||
* the value of the "characterEncoding" global parameter. There are several heuristics used to do the
|
|
||||||
* guess. One of the heuristic is based on the "encoding" attribute. The original XML specifications
|
|
||||||
* forbids to use this attribute to do the guess but you can still use it if you set
|
|
||||||
* "useXMLEncodingAttribute" to 1 (this is the default behavior and the behavior of most parsers).
|
|
||||||
* If an inconsistency in the encoding is detected, then the return value is "0". */
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
private:
|
|
||||||
// these are functions and structures used internally by the XMLNode class (don't bother about them):
|
|
||||||
|
|
||||||
typedef struct XMLNodeDataTag // to allow shallow copy and "intelligent/smart" pointers (automatic delete):
|
|
||||||
{
|
|
||||||
XMLCSTR lpszName; // Element name (=NULL if root)
|
|
||||||
int nChild, // Number of child nodes
|
|
||||||
nText, // Number of text fields
|
|
||||||
nClear, // Number of Clear fields (comments)
|
|
||||||
nAttribute; // Number of attributes
|
|
||||||
char isDeclaration; // Whether node is an XML declaration - '<?xml ?>'
|
|
||||||
struct XMLNodeDataTag *pParent; // Pointer to parent element (=NULL if root)
|
|
||||||
XMLNode *pChild; // Array of child nodes
|
|
||||||
XMLCSTR *pText; // Array of text fields
|
|
||||||
XMLClear *pClear; // Array of clear fields
|
|
||||||
XMLAttribute *pAttribute; // Array of attributes
|
|
||||||
int *pOrder; // order of the child_nodes,text_fields,clear_fields
|
|
||||||
int ref_count; // for garbage collection (smart pointers)
|
|
||||||
} XMLNodeData;
|
|
||||||
XMLNodeData *d;
|
|
||||||
|
|
||||||
char parseClearTag(void *px, void *pa);
|
|
||||||
char maybeAddTxT(void *pa, XMLCSTR tokenPStr);
|
|
||||||
int ParseXMLElement(void *pXML);
|
|
||||||
void *addToOrder(int memInc, int *_pos, int nc, void *p, int size, XMLElementType xtype);
|
|
||||||
int indexText(XMLCSTR lpszValue) const;
|
|
||||||
int indexClear(XMLCSTR lpszValue) const;
|
|
||||||
XMLNode addChild_priv(int,XMLSTR,char,int);
|
|
||||||
XMLAttribute *addAttribute_priv(int,XMLSTR,XMLSTR);
|
|
||||||
XMLCSTR addText_priv(int,XMLSTR,int);
|
|
||||||
XMLClear *addClear_priv(int,XMLSTR,XMLCSTR,XMLCSTR,int);
|
|
||||||
void emptyTheNode(char force);
|
|
||||||
static inline XMLElementPosition findPosition(XMLNodeData *d, int index, XMLElementType xtype);
|
|
||||||
static int CreateXMLStringR(XMLNodeData *pEntry, XMLSTR lpszMarker, int nFormat);
|
|
||||||
static int removeOrderElement(XMLNodeData *d, XMLElementType t, int index);
|
|
||||||
static void exactMemory(XMLNodeData *d);
|
|
||||||
static int detachFromParent(XMLNodeData *d);
|
|
||||||
} XMLNode;
|
|
||||||
|
|
||||||
/// This structure is given by the function XMLNode::enumContents.
|
|
||||||
typedef struct XMLNodeContents
|
|
||||||
{
|
|
||||||
/// This dictates what's the content of the XMLNodeContent
|
|
||||||
enum XMLElementType etype;
|
|
||||||
/**< should be an union to access the appropriate data. Compiler does not allow union of object with constructor... too bad. */
|
|
||||||
XMLNode child;
|
|
||||||
XMLAttribute attrib;
|
|
||||||
XMLCSTR text;
|
|
||||||
XMLClear clear;
|
|
||||||
|
|
||||||
} XMLNodeContents;
|
|
||||||
|
|
||||||
/** @defgroup StringAlloc String Allocation/Free functions
|
|
||||||
* @ingroup xmlModify
|
|
||||||
* @{ */
|
|
||||||
/// Duplicate (copy in a new allocated buffer) the source string.
|
|
||||||
XMLDLLENTRY XMLSTR stringDup(XMLCSTR source, int cbData=-1);
|
|
||||||
/**< This is
|
|
||||||
* a very handy function when used with all the "XMLNode::*_WOSD" functions (\link xmlWOSD \endlink).
|
|
||||||
* @param cbData If !=0 then cbData is the number of chars to duplicate. New strings allocated with
|
|
||||||
* this function should be free'd using the "freeXMLString" function. */
|
|
||||||
|
|
||||||
/// to free the string allocated inside the "stringDup" function or the "createXMLString" function.
|
|
||||||
XMLDLLENTRY void freeXMLString(XMLSTR t); // {free(t);}
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
/** @defgroup atoX ato? like functions
|
|
||||||
* @ingroup XMLParserGeneral
|
|
||||||
* The "xmlto?" functions are equivalents to the atoi, atol, atof functions.
|
|
||||||
* The only difference is: If the variable "xmlString" is NULL, than the return value
|
|
||||||
* is "defautValue". These 6 functions are only here as "convenience" functions for the
|
|
||||||
* user (they are not used inside the XMLparser). If you don't need them, you can
|
|
||||||
* delete them without any trouble.
|
|
||||||
*
|
|
||||||
* @{ */
|
|
||||||
XMLDLLENTRY char xmltob(XMLCSTR xmlString,char defautValue=0);
|
|
||||||
XMLDLLENTRY int xmltoi(XMLCSTR xmlString,int defautValue=0);
|
|
||||||
XMLDLLENTRY long xmltol(XMLCSTR xmlString,long defautValue=0);
|
|
||||||
XMLDLLENTRY double xmltof(XMLCSTR xmlString,double defautValue=.0);
|
|
||||||
XMLDLLENTRY XMLCSTR xmltoa(XMLCSTR xmlString,XMLCSTR defautValue=_CXML(""));
|
|
||||||
XMLDLLENTRY XMLCHAR xmltoc(XMLCSTR xmlString,const XMLCHAR defautValue=_CXML('\0'));
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
/** @defgroup ToXMLStringTool Helper class to create XML files using "printf", "fprintf", "cout",... functions.
|
|
||||||
* @ingroup XMLParserGeneral
|
|
||||||
* @{ */
|
|
||||||
/// Helper class to create XML files using "printf", "fprintf", "cout",... functions.
|
|
||||||
/** The ToXMLStringTool class helps you creating XML files using "printf", "fprintf", "cout",... functions.
|
|
||||||
* The "ToXMLStringTool" class is processing strings so that all the characters
|
|
||||||
* &,",',<,> are replaced by their XML equivalent:
|
|
||||||
* \verbatim &, ", ', <, > \endverbatim
|
|
||||||
* Using the "ToXMLStringTool class" and the "fprintf function" is THE most efficient
|
|
||||||
* way to produce VERY large XML documents VERY fast.
|
|
||||||
* \note If you are creating from scratch an XML file using the provided XMLNode class
|
|
||||||
* you must not use the "ToXMLStringTool" class (because the "XMLNode" class does the
|
|
||||||
* processing job for you during rendering).*/
|
|
||||||
typedef struct XMLDLLENTRY ToXMLStringTool
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
ToXMLStringTool(): buf(NULL),buflen(0){}
|
|
||||||
~ToXMLStringTool();
|
|
||||||
void freeBuffer();///<call this function when you have finished using this object to release memory used by the internal buffer.
|
|
||||||
|
|
||||||
XMLSTR toXML(XMLCSTR source);///< returns a pointer to an internal buffer that contains a XML-encoded string based on the "source" parameter.
|
|
||||||
|
|
||||||
/** The "toXMLUnSafe" function is deprecated because there is a possibility of
|
|
||||||
* "destination-buffer-overflow". It converts the string
|
|
||||||
* "source" to the string "dest". */
|
|
||||||
static XMLSTR toXMLUnSafe(XMLSTR dest,XMLCSTR source); ///< deprecated: use "toXML" instead
|
|
||||||
static int lengthXMLString(XMLCSTR source); ///< deprecated: use "toXML" instead
|
|
||||||
|
|
||||||
private:
|
|
||||||
XMLSTR buf;
|
|
||||||
int buflen;
|
|
||||||
} ToXMLStringTool;
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
/** @defgroup XMLParserBase64Tool Helper class to include binary data inside XML strings using "Base64 encoding".
|
|
||||||
* @ingroup XMLParserGeneral
|
|
||||||
* @{ */
|
|
||||||
/// Helper class to include binary data inside XML strings using "Base64 encoding".
|
|
||||||
/** The "XMLParserBase64Tool" class allows you to include any binary data (images, sounds,...)
|
|
||||||
* into an XML document using "Base64 encoding". This class is completely
|
|
||||||
* separated from the rest of the xmlParser library and can be removed without any problem.
|
|
||||||
* To include some binary data into an XML file, you must convert the binary data into
|
|
||||||
* standard text (using "encode"). To retrieve the original binary data from the
|
|
||||||
* b64-encoded text included inside the XML file, use "decode". Alternatively, these
|
|
||||||
* functions can also be used to "encrypt/decrypt" some critical data contained inside
|
|
||||||
* the XML (it's not a strong encryption at all, but sometimes it can be useful). */
|
|
||||||
typedef struct XMLDLLENTRY XMLParserBase64Tool
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
XMLParserBase64Tool(): buf(NULL),buflen(0){}
|
|
||||||
~XMLParserBase64Tool();
|
|
||||||
void freeBuffer();///< Call this function when you have finished using this object to release memory used by the internal buffer.
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param formatted If "formatted"=true, some space will be reserved for a carriage-return every 72 chars. */
|
|
||||||
static int encodeLength(int inBufLen, char formatted=0); ///< return the length of the base64 string that encodes a data buffer of size inBufLen bytes.
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The "base64Encode" function returns a string containing the base64 encoding of "inByteLen" bytes
|
|
||||||
* from "inByteBuf". If "formatted" parameter is true, then there will be a carriage-return every 72 chars.
|
|
||||||
* The string will be free'd when the XMLParserBase64Tool object is deleted.
|
|
||||||
* All returned strings are sharing the same memory space. */
|
|
||||||
XMLSTR encode(unsigned char *inByteBuf, unsigned int inByteLen, char formatted=0); ///< returns a pointer to an internal buffer containing the base64 string containing the binary data encoded from "inByteBuf"
|
|
||||||
|
|
||||||
/// returns the number of bytes which will be decoded from "inString".
|
|
||||||
static unsigned int decodeSize(XMLCSTR inString, XMLError *xe=NULL);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The "decode" function returns a pointer to a buffer containing the binary data decoded from "inString"
|
|
||||||
* The output buffer will be free'd when the XMLParserBase64Tool object is deleted.
|
|
||||||
* All output buffer are sharing the same memory space.
|
|
||||||
* @param inString If "instring" is malformed, NULL will be returned */
|
|
||||||
unsigned char* decode(XMLCSTR inString, int *outByteLen=NULL, XMLError *xe=NULL); ///< returns a pointer to an internal buffer containing the binary data decoded from "inString"
|
|
||||||
|
|
||||||
/**
|
|
||||||
* decodes data from "inString" to "outByteBuf". You need to provide the size (in byte) of "outByteBuf"
|
|
||||||
* in "inMaxByteOutBuflen". If "outByteBuf" is not large enough or if data is malformed, then "FALSE"
|
|
||||||
* will be returned; otherwise "TRUE". */
|
|
||||||
static unsigned char decode(XMLCSTR inString, unsigned char *outByteBuf, int inMaxByteOutBuflen, XMLError *xe=NULL); ///< deprecated.
|
|
||||||
|
|
||||||
private:
|
|
||||||
void *buf;
|
|
||||||
int buflen;
|
|
||||||
void alloc(int newsize);
|
|
||||||
}XMLParserBase64Tool;
|
|
||||||
/** @} */
|
|
||||||
|
|
||||||
#undef XMLDLLENTRY
|
|
||||||
|
|
||||||
#endif
|
|
@ -50,7 +50,7 @@ init_fove_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
|
|||||||
get_dist_ids(Parfactors, DistIds0),
|
get_dist_ids(Parfactors, DistIds0),
|
||||||
sort(DistIds0, DistIds),
|
sort(DistIds0, DistIds),
|
||||||
get_observed_vars(AllAttVars, ObservedVars),
|
get_observed_vars(AllAttVars, ObservedVars),
|
||||||
writeln(factors:Parfactors:'\n'),
|
writeln(parfactors:Parfactors:'\n'),
|
||||||
writeln(evidence:ObservedVars:'\n'),
|
writeln(evidence:ObservedVars:'\n'),
|
||||||
create_lifted_network(Parfactors,ObservedVars,ParfactorList).
|
create_lifted_network(Parfactors,ObservedVars,ParfactorList).
|
||||||
|
|
||||||
@ -139,11 +139,11 @@ get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
|||||||
|
|
||||||
|
|
||||||
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
|
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
|
||||||
get_dists_parameters(DistIds, DistsParams),
|
|
||||||
writeln(distParams:DistsParams),
|
|
||||||
set_parfactors_params(ParfactorList, DistsParams),
|
|
||||||
get_query_vars(QueryVarsAtts, QueryVars),
|
get_query_vars(QueryVarsAtts, QueryVars),
|
||||||
writeln(queryVars:QueryVars), writeln(''),
|
writeln(queryVars:QueryVars), writeln(''),
|
||||||
|
get_dists_parameters(DistIds, DistsParams),
|
||||||
|
writeln(dists:DistsParams), writeln(''),
|
||||||
|
set_parfactors_params(ParfactorList, DistsParams),
|
||||||
run_lifted_solver(ParfactorList, QueryVars, Solutions).
|
run_lifted_solver(ParfactorList, QueryVars, Solutions).
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,32 +6,45 @@
|
|||||||
********************************************************/
|
********************************************************/
|
||||||
|
|
||||||
:- module(clpbn_horus,
|
:- module(clpbn_horus,
|
||||||
[create_lifted_network/3,
|
[set_solver/1,
|
||||||
create_ground_network/2,
|
create_lifted_network/3,
|
||||||
|
create_ground_network/4,
|
||||||
set_parfactors_params/2,
|
set_parfactors_params/2,
|
||||||
set_bayes_net_params/2,
|
set_factors_params/2,
|
||||||
run_lifted_solver/3,
|
run_lifted_solver/3,
|
||||||
run_ground_solver/3,
|
run_ground_solver/3,
|
||||||
set_extra_vars_info/2,
|
set_vars_information/2,
|
||||||
set_horus_flag/2,
|
set_horus_flag/2,
|
||||||
free_parfactors/1,
|
free_parfactors/1,
|
||||||
free_bayesian_network/1
|
free_ground_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
|
||||||
|
:- use_module(library(pfl),
|
||||||
|
[set_pfl_flag/2]).
|
||||||
|
|
||||||
|
|
||||||
patch_things_up :-
|
patch_things_up :-
|
||||||
assert_static(clpbn_horus:set_horus_flag(_,_)).
|
assert_static(clpbn_horus:set_horus_flag(_,_)).
|
||||||
|
|
||||||
warning :-
|
warning :-
|
||||||
format(user_error,"Horus library not installed: cannot use bp, fove~n.",[]).
|
format(user_error,"Horus library not installed: cannot use bp, fove~n.",[]).
|
||||||
|
|
||||||
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up) -> true ; warning.
|
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up) -> true ; warning.
|
||||||
|
|
||||||
|
|
||||||
|
set_solver(ve) :- set_pfl_flag(solver,ve).
|
||||||
|
set_solver(jt) :- set_pfl_flag(solver,jt).
|
||||||
|
set_solver(gibbs) :- set_pfl_flag(solver,gibbs).
|
||||||
|
set_solver(fove) :- set_pfl_flag(solver,fove).
|
||||||
|
set_solver(hve) :- set_pfl_flag(solver,bp), set_horus_flag(inf_alg, ve).
|
||||||
|
set_solver(bp) :- set_pfl_flag(solver,bp), set_horus_flag(inf_alg, bp).
|
||||||
|
set_solver(cbp) :- set_pfl_flag(solver,bp), set_horus_flag(inf_alg, cbp).
|
||||||
|
set_solver(S) :- throw(error('unknow solver ', S)).
|
||||||
|
|
||||||
|
|
||||||
%:- set_horus_flag(inf_alg, ve).
|
%:- set_horus_flag(inf_alg, ve).
|
||||||
:- set_horus_flag(inf_alg, bn_bp).
|
%:- set_horus_flag(inf_alg, bp).
|
||||||
%:- set_horus_flag(inf_alg, fg_bp).
|
|
||||||
%: -set_horus_flag(inf_alg, cbp).
|
%: -set_horus_flag(inf_alg, cbp).
|
||||||
|
|
||||||
:- set_horus_flag(schedule, seq_fixed).
|
:- set_horus_flag(schedule, seq_fixed).
|
||||||
@ -46,7 +59,6 @@ warning :-
|
|||||||
:- set_horus_flag(order_factor_variables, false).
|
:- set_horus_flag(order_factor_variables, false).
|
||||||
%:- set_horus_flag(order_factor_variables, true).
|
%:- set_horus_flag(order_factor_variables, true).
|
||||||
|
|
||||||
|
|
||||||
:- set_horus_flag(use_logarithms, false).
|
:- set_horus_flag(use_logarithms, false).
|
||||||
% :- set_horus_flag(use_logarithms, true).
|
% :- set_horus_flag(use_logarithms, true).
|
||||||
|
|
||||||
|
@ -41,18 +41,19 @@
|
|||||||
user:term_expansion( bayes((Formula ; Phi ; Constraints)), pfl:factor(bayes,Id,FList,FV,Phi,Constraints)) :-
|
user:term_expansion( bayes((Formula ; Phi ; Constraints)), pfl:factor(bayes,Id,FList,FV,Phi,Constraints)) :-
|
||||||
!,
|
!,
|
||||||
term_variables(Formula, FreeVars),
|
term_variables(Formula, FreeVars),
|
||||||
FV =.. [fv|FreeVars],
|
FV =.. [''|FreeVars],
|
||||||
new_id(Id),
|
new_id(Id),
|
||||||
process_args(Formula, Id, 0, _, FList, []).
|
process_args(Formula, Id, 0, _, FList, []).
|
||||||
user:term_expansion( markov((Formula ; Phi ; Constraints)), pfl:factor(markov,Id,FList,FV,Phi,Constraints)) :-
|
user:term_expansion( markov((Formula ; Phi ; Constraints)), pfl:factor(markov,Id,FList,FV,Phi,Constraints)) :-
|
||||||
!,
|
!,
|
||||||
term_variables(Formula, FreeVars),
|
term_variables(Formula, FreeVars),
|
||||||
FV =.. [fv|FreeVars],
|
FV =.. [''|FreeVars],
|
||||||
new_id(Id),
|
new_id(Id),
|
||||||
process_args(Formula, Id, 0, _, FList, []).
|
process_args(Formula, Id, 0, _, FList, []).
|
||||||
user:term_expansion( Id@N, L ) :-
|
user:term_expansion( Id@N, L ) :-
|
||||||
atom(Id), number(N), !,
|
atom(Id), number(N), !,
|
||||||
findall(G,generate_entity(0, N, Id, G), L).
|
N1 is N + 1,
|
||||||
|
findall(G,generate_entity(1, N1, Id, G), L).
|
||||||
user:term_expansion( Goal, [] ) :-
|
user:term_expansion( Goal, [] ) :-
|
||||||
preprocess(Goal, Sk,Var), !,
|
preprocess(Goal, Sk,Var), !,
|
||||||
(ground(Goal) -> true ; throw(error('non ground evidence',Goal))),
|
(ground(Goal) -> true ; throw(error('non ground evidence',Goal))),
|
||||||
@ -78,7 +79,7 @@ defined_in_factor(Key, Factor) :-
|
|||||||
|
|
||||||
generate_entity(N, N, _, _) :- !.
|
generate_entity(N, N, _, _) :- !.
|
||||||
generate_entity(I0, _N, Id, T) :-
|
generate_entity(I0, _N, Id, T) :-
|
||||||
atomic_concat(person_, I0, P),
|
atomic_concat(p, I0, P),
|
||||||
T =.. [Id, P].
|
T =.. [Id, P].
|
||||||
generate_entity(I0, N, Id, T) :-
|
generate_entity(I0, N, Id, T) :-
|
||||||
I is I0+1,
|
I is I0+1,
|
||||||
@ -145,7 +146,7 @@ add_evidence(Sk,Var) :-
|
|||||||
|
|
||||||
get_pfl_parameters(Id,Out) :-
|
get_pfl_parameters(Id,Out) :-
|
||||||
factor(_Type,Id,_FList,_FV,Phi,_Constraints),
|
factor(_Type,Id,_FList,_FV,Phi,_Constraints),
|
||||||
writeln(factor(_Type,Id,_FList,_FV,_Phi,_Constraints)),
|
%writeln(factor(_Type,Id,_FList,_FV,_Phi,_Constraints)),
|
||||||
( is_list(Phi) -> Out = Phi ; call(user:Phi, Out) ).
|
( is_list(Phi) -> Out = Phi ; call(user:Phi, Out) ).
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user