Improvements
Factor nodes now contain a factor object instead of a pointer. Refactor the way .fg and .uai formats are readed.
This commit is contained in:
parent
f1d52c0389
commit
6986e8c0d7
@ -40,7 +40,6 @@
|
|||||||
get_pfl_parameters/2
|
get_pfl_parameters/2
|
||||||
]).
|
]).
|
||||||
|
|
||||||
% :- use_module(library(clpbn/horus)).
|
|
||||||
|
|
||||||
:- use_module(library(lists)).
|
:- use_module(library(lists)).
|
||||||
|
|
||||||
@ -53,45 +52,42 @@
|
|||||||
[create_ground_network/4,
|
[create_ground_network/4,
|
||||||
set_bayes_net_params/2,
|
set_bayes_net_params/2,
|
||||||
run_ground_solver/3,
|
run_ground_solver/3,
|
||||||
set_extra_vars_info/2,
|
set_vars_information/2,
|
||||||
free_ground_network/1
|
free_ground_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
|
||||||
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Solutions) :-
|
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
||||||
b_hash_new(Hash0),
|
b_hash_new(Hash0),
|
||||||
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
||||||
%InvMap =.. [view|AllKeys],
|
|
||||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
|
||||||
evidence_to_ids(Evidence, Hash, EvIds),
|
|
||||||
factors_to_ids(Factors, Hash, FactorIds),
|
|
||||||
get_factors_type(Factors, Type),
|
get_factors_type(Factors, Type),
|
||||||
|
evidence_to_ids(Evidence, Hash, EvidenceIds),
|
||||||
|
factors_to_ids(Factors, Hash, FactorIds),
|
||||||
writeln(type:Type), writeln(''),
|
writeln(type:Type), writeln(''),
|
||||||
writeln(allKeys:AllKeys), writeln(''),
|
writeln(allKeys:AllKeys), writeln(''),
|
||||||
%writeln(allKeysIds:Hash), writeln(''),
|
|
||||||
writeln(queryKeys:QueryKeys), writeln(''),
|
|
||||||
writeln(queryIds:QueryIds), writeln(''),
|
|
||||||
writeln(factors:Factors), writeln(''),
|
writeln(factors:Factors), writeln(''),
|
||||||
writeln(factorIds:FactorIds), writeln(''),
|
writeln(factorIds:FactorIds), writeln(''),
|
||||||
writeln(evidence:Evidence), writeln(''),
|
writeln(evidence:Evidence), writeln(''),
|
||||||
writeln(evIds:EvIds),
|
writeln(evidenceIds:EvidenceIds), writeln(''),
|
||||||
create_ground_network(Type, FactorIds, EvIds, Network),
|
create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
||||||
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions).
|
%get_vars_information(AllKeys, StatesNames),
|
||||||
%free_graphical_model(Network).
|
%set_vars_information(AllKeys, StatesNames),
|
||||||
|
run_solver(ground(Network,Hash), QueryKeys, Solutions),
|
||||||
|
writeln(answer:Solutions),
|
||||||
|
%clpbn_bind_vals([QueryKeys], Solutions, Output).
|
||||||
|
free_ground_network(Network).
|
||||||
|
|
||||||
|
|
||||||
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||||
%get_dists_parameters(DistIds, DistsParams),
|
%get_dists_parameters(DistIds, DistsParams),
|
||||||
%set_bayes_net_params(Network, DistsParams),
|
%set_bayes_net_params(Network, DistsParams),
|
||||||
%vars_to_ids(QueryVars, QueryVarsIds),
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||||
|
writeln(queryKeys:QueryKeys), writeln(''),
|
||||||
|
writeln(queryIds:QueryIds), writeln(''),
|
||||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||||
run_ground_solver(Network, [QueryIds], Solutions).
|
run_ground_solver(Network, [QueryIds], Solutions).
|
||||||
|
|
||||||
|
|
||||||
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
|
|
||||||
get_factors_type([f(markov, _, _)|_], markov) :- ! .
|
|
||||||
|
|
||||||
|
|
||||||
keys_to_ids([], _, Hash, Hash).
|
keys_to_ids([], _, Hash, Hash).
|
||||||
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
|
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
|
||||||
b_hash_insert(Hash0, Key, I0, HashI),
|
b_hash_insert(Hash0, Key, I0, HashI),
|
||||||
@ -99,27 +95,16 @@ keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
|
|||||||
keys_to_ids(AllKeys, I, HashI, Hash).
|
keys_to_ids(AllKeys, I, HashI, Hash).
|
||||||
|
|
||||||
|
|
||||||
|
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
|
||||||
|
get_factors_type([f(markov, _, _)|_], markov) :- ! .
|
||||||
|
|
||||||
|
|
||||||
list_of_keys_to_ids([], _, []).
|
list_of_keys_to_ids([], _, []).
|
||||||
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
|
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
|
||||||
b_hash_lookup(Key, Id, Hash),
|
b_hash_lookup(Key, Id, Hash),
|
||||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
|
||||||
|
|
||||||
|
|
||||||
evidence_to_ids([], _, []).
|
|
||||||
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
|
|
||||||
b_hash_lookup(Key, Id, Hash),
|
|
||||||
evidence_to_ids(QueryKeys, Hash, QueryIds).
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
%evidence_to_ids([], _, [], []).
|
|
||||||
%evidence_to_ids([Key=V|QueryKeys], Hash, [Id=V|QueryIds], [Id=Name|QueryNames]) :-
|
|
||||||
% b_hash_lookup(Key, Id, Hash),
|
|
||||||
% pfl:skolem(Key,Dom),
|
|
||||||
% nth0(V, Dom, Name),
|
|
||||||
% evidence_to_ids(QueryKeys, Hash, QueryIds, QueryNames).
|
|
||||||
|
|
||||||
|
|
||||||
factors_to_ids([], _, []).
|
factors_to_ids([], _, []).
|
||||||
factors_to_ids([f(_, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
|
factors_to_ids([f(_, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
|
||||||
list_of_keys_to_ids(Keys, Hash, Ids),
|
list_of_keys_to_ids(Keys, Hash, Ids),
|
||||||
@ -135,6 +120,22 @@ get_ranges(K.Ks, Range.Rs) :- !,
|
|||||||
get_ranges(Ks, Rs).
|
get_ranges(Ks, Rs).
|
||||||
|
|
||||||
|
|
||||||
|
evidence_to_ids([], _, []).
|
||||||
|
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
|
||||||
|
b_hash_lookup(Key, Id, Hash),
|
||||||
|
evidence_to_ids(QueryKeys, Hash, QueryIds).
|
||||||
|
|
||||||
|
|
||||||
|
get_vars_information([], []).
|
||||||
|
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
|
||||||
|
pfl:skolem(Key, Domain),
|
||||||
|
get_vars_information(QueryKeys, StatesNames).
|
||||||
|
|
||||||
|
|
||||||
|
finalize_bp_solver(bp(Network, _)) :-
|
||||||
|
free_ground_network(Network).
|
||||||
|
|
||||||
|
|
||||||
bp([[]],_,_) :- !.
|
bp([[]],_,_) :- !.
|
||||||
bp([QueryVars], AllVars, Output) :-
|
bp([QueryVars], AllVars, Output) :-
|
||||||
init_bp_solver(_, AllVars, _, Network),
|
init_bp_solver(_, AllVars, _, Network),
|
||||||
@ -144,53 +145,22 @@ bp([QueryVars], AllVars, Output) :-
|
|||||||
|
|
||||||
|
|
||||||
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||||
%writeln('init_bp_solver'),
|
|
||||||
%check_for_agg_vars(AllVars0, AllVars),
|
%check_for_agg_vars(AllVars0, AllVars),
|
||||||
%writeln('clpbn_vars:'), print_clpbn_vars(AllVars),
|
|
||||||
get_vars_info(AllVars, VarsInfo, DistIds0),
|
get_vars_info(AllVars, VarsInfo, DistIds0),
|
||||||
sort(DistIds0, DistIds),
|
sort(DistIds0, DistIds),
|
||||||
create_ground_network(VarsInfo, BayesNet),
|
create_ground_network(VarsInfo, BayesNet),
|
||||||
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
|
||||||
%set_extra_vars_info(BayesNet, ExtraVarsInfo),
|
|
||||||
%writeln(extravarsinfo:ExtraVarsInfo),
|
|
||||||
true.
|
true.
|
||||||
|
|
||||||
|
|
||||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||||
%writeln('-> run_bp_solver'),
|
|
||||||
get_dists_parameters(DistIds, DistsParams),
|
get_dists_parameters(DistIds, DistsParams),
|
||||||
set_bayes_net_params(Network, DistsParams),
|
set_bayes_net_params(Network, DistsParams),
|
||||||
vars_to_ids(QueryVars, QueryVarsIds),
|
vars_to_ids(QueryVars, QueryVarsIds),
|
||||||
run_ground_solver(Network, QueryVarsIds, Solutions).
|
run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||||
|
|
||||||
|
|
||||||
finalize_bp_solver(bp(Network, _)) :-
|
|
||||||
free_ground_network(Network).
|
|
||||||
|
|
||||||
|
|
||||||
get_extra_vars_info([], []).
|
|
||||||
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
|
|
||||||
get_atts(V, [id(VarId)]), !,
|
|
||||||
clpbn:get_atts(V, [key(Key), dist(DistId, _)]),
|
|
||||||
term_to_atom(Key, Label),
|
|
||||||
get_dist_domain(DistId, Domain0),
|
|
||||||
numbers_to_atoms(Domain0, Domain),
|
|
||||||
get_extra_vars_info(Vs, VarsInfo).
|
|
||||||
get_extra_vars_info([_|Vs], VarsInfo) :-
|
|
||||||
get_extra_vars_info(Vs, VarsInfo).
|
|
||||||
|
|
||||||
|
|
||||||
get_dists_parameters([],[]).
|
get_dists_parameters([],[]).
|
||||||
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
||||||
get_dist_params(Id, Params),
|
get_dist_params(Id, Params),
|
||||||
get_dists_parameters(Ids, DistsInfo).
|
get_dists_parameters(Ids, DistsInfo).
|
||||||
|
|
||||||
|
|
||||||
numbers_to_atoms([], []).
|
|
||||||
numbers_to_atoms([Atom|L0], [Atom|L]) :-
|
|
||||||
atom(Atom), !,
|
|
||||||
numbers_to_atoms(L0, L).
|
|
||||||
numbers_to_atoms([Number|L0], [Atom|L]) :-
|
|
||||||
number_atom(Number, Atom),
|
|
||||||
numbers_to_atoms(L0, L).
|
|
||||||
|
|
||||||
|
@ -63,12 +63,12 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
|||||||
const FactorNodes& facNodes = fg_.factorNodes();
|
const FactorNodes& facNodes = fg_.factorNodes();
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
const DAGraphNode* n = dag_.getNode (
|
const DAGraphNode* n = dag_.getNode (
|
||||||
facNodes[i]->factor()->argument (0));
|
facNodes[i]->factor().argument (0));
|
||||||
if (n->isMarkedOnTop()) {
|
if (n->isMarkedOnTop()) {
|
||||||
fg->addFactor (Factor (*(facNodes[i]->factor())));
|
fg->addFactor (Factor (facNodes[i]->factor()));
|
||||||
} else if (n->hasEvidence() && n->isVisited()) {
|
} else if (n->hasEvidence() && n->isVisited()) {
|
||||||
VarIds varIds = { facNodes[i]->factor()->argument (0) };
|
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||||
Ranges ranges = { facNodes[i]->factor()->range (0) };
|
Ranges ranges = { facNodes[i]->factor().range (0) };
|
||||||
Params params (ranges[0], LogAware::noEvidence());
|
Params params (ranges[0], LogAware::noEvidence());
|
||||||
params[n->getEvidence()] = LogAware::withEvidence();
|
params[n->getEvidence()] = LogAware::withEvidence();
|
||||||
fg->addFactor (Factor (varIds, ranges, params));
|
fg->addFactor (Factor (varIds, ranges, params));
|
||||||
|
@ -84,5 +84,5 @@ class DAGraph
|
|||||||
unordered_map<VarId, DAGraphNode*> varMap_;
|
unordered_map<VarId, DAGraphNode*> varMap_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
#endif // HORUS_BAYESNET_H
|
#endif // HORUS_BAYESNET_H
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
|
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
|
||||||
const FactorNodes& factorNodes = vn->neighbors();
|
const FactorNodes& factorNodes = vn->neighbors();
|
||||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
if (factorNodes[i]->factor().contains (jointVarIds)) {
|
||||||
idx = i;
|
idx = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -109,7 +109,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
return getJointByConditioning (jointVarIds);
|
return getJointByConditioning (jointVarIds);
|
||||||
} else {
|
} else {
|
||||||
Factor res (*factorNodes[idx]->factor());
|
Factor res (factorNodes[idx]->factor());
|
||||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Factor msg (links[i]->getVariable()->varId(),
|
Factor msg (links[i]->getVariable()->varId(),
|
||||||
@ -316,9 +316,9 @@ BpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||||
{
|
{
|
||||||
const FactorNode* src = link->getFactor();
|
FactorNode* src = link->getFactor();
|
||||||
const VarNode* dst = link->getVariable();
|
const VarNode* dst = link->getVariable();
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
const SpLinkSet& links = ninf(src)->getLinks();
|
||||||
// calculate the product of messages that were sent
|
// calculate the product of messages that were sent
|
||||||
@ -357,10 +357,9 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Factor result (src->factor()->arguments(),
|
Factor result (src->factor().arguments(),
|
||||||
src->factor()->ranges(),
|
src->factor().ranges(), msgProduct);
|
||||||
msgProduct);
|
result.multiply (src->factor());
|
||||||
result.multiply (*(src->factor()));
|
|
||||||
if (Constants::DEBUG >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " message product: " << msgProduct << endl;
|
cout << " message product: " << msgProduct << endl;
|
||||||
cout << " original factor: " << src->params() << endl;
|
cout << " original factor: " << src->params() << endl;
|
||||||
|
@ -108,7 +108,7 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
virtual void maxResidualSchedule (void);
|
virtual void maxResidualSchedule (void);
|
||||||
|
|
||||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
virtual void calculateFactor2VariableMsg (SpLink*);
|
||||||
|
|
||||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||||
|
|
||||||
|
@ -74,20 +74,20 @@ CFactorGraph::setInitialColors (void)
|
|||||||
if (checkForIdenticalFactors) {
|
if (checkForIdenticalFactors) {
|
||||||
unsigned groupCount = 1;
|
unsigned groupCount = 1;
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
Factor* f1 = facNodes[i]->factor();
|
Factor& f1 = facNodes[i]->factor();
|
||||||
if (f1->distId() != Util::maxUnsigned()) {
|
if (f1.distId() != Util::maxUnsigned()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
f1->setDistId (groupCount);
|
f1.setDistId (groupCount);
|
||||||
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
||||||
Factor* f2 = facNodes[j]->factor();
|
Factor& f2 = facNodes[j]->factor();
|
||||||
if (f2->distId() != Util::maxUnsigned()) {
|
if (f2.distId() != Util::maxUnsigned()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (f1->size() == f2->size() &&
|
if (f1.size() == f2.size() &&
|
||||||
f1->ranges() == f2->ranges() &&
|
f1.ranges() == f2.ranges() &&
|
||||||
f1->params() == f2->params()) {
|
f1.params() == f2.params()) {
|
||||||
f2->setDistId (groupCount);
|
f2.setDistId (groupCount);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
groupCount ++;
|
groupCount ++;
|
||||||
@ -96,7 +96,7 @@ CFactorGraph::setInitialColors (void)
|
|||||||
// create the initial factor colors
|
// create the initial factor colors
|
||||||
DistColorMap distColors;
|
DistColorMap distColors;
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
unsigned distId = facNodes[i]->factor()->distId();
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
DistColorMap::iterator it = distColors.find (distId);
|
DistColorMap::iterator it = distColors.find (distId);
|
||||||
if (it == distColors.end()) {
|
if (it == distColors.end()) {
|
||||||
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
||||||
@ -211,7 +211,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
|
|||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
*it = getColor (neighs[i]);
|
*it = getColor (neighs[i]);
|
||||||
it ++;
|
it ++;
|
||||||
*it = neighs[i]->factor()->indexOf (varNode->varId());
|
*it = neighs[i]->factor().indexOf (varNode->varId());
|
||||||
it ++;
|
it ++;
|
||||||
}
|
}
|
||||||
*it = getColor (varNode);
|
*it = getColor (varNode);
|
||||||
@ -244,7 +244,7 @@ CFactorGraph::getCompressedFactorGraph (void)
|
|||||||
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
||||||
VarNode* newVar = new VarNode (var);
|
VarNode* newVar = new VarNode (var);
|
||||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||||
fg->addVariable (newVar);
|
fg->addVarNode (newVar);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||||
@ -255,11 +255,10 @@ CFactorGraph::getCompressedFactorGraph (void)
|
|||||||
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||||
myGroundVars.push_back (v);
|
myGroundVars.push_back (v);
|
||||||
}
|
}
|
||||||
Factor* newFactor = new Factor (myGroundVars,
|
FactorNode* fn = new FactorNode (Factor (myGroundVars,
|
||||||
facClusters_[i]->getGroundFactors()[0]->params());
|
facClusters_[i]->getGroundFactors()[0]->params()));
|
||||||
FactorNode* fn = new FactorNode (newFactor);
|
|
||||||
facClusters_[i]->setRepresentativeFactor (fn);
|
facClusters_[i]->setRepresentativeFactor (fn);
|
||||||
fg->addFactor (fn);
|
fg->addFactorNode (fn);
|
||||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||||
fg->addEdge (fn, static_cast<VarNode*> (myGroundVars[j]));
|
fg->addEdge (fn, static_cast<VarNode*> (myGroundVars[j]));
|
||||||
}
|
}
|
||||||
@ -279,7 +278,7 @@ CFactorGraph::getGroundEdgeCount (
|
|||||||
VarNode* varNode = vc->getGroundVarNodes()[0];
|
VarNode* varNode = vc->getGroundVarNodes()[0];
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
|
if (clusterGroundFactors[i]->factor().indexOf (varNode->varId()) != -1) {
|
||||||
count ++;
|
count ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,21 +18,20 @@ bool FactorGraph::orderFactorVariables = false;
|
|||||||
|
|
||||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||||
{
|
{
|
||||||
const VarNodes& vars = fg.varNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
VarNode* varNode = new VarNode (vars[i]);
|
addVarNode (new VarNode (varNodes[i]));
|
||||||
addVariable (varNode);
|
|
||||||
}
|
}
|
||||||
|
const FactorNodes& facNodes = fg.factorNodes();
|
||||||
const FactorNodes& facs = fg.factorNodes();
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
for (unsigned i = 0; i < facs.size(); i++) {
|
FactorNode* facNode = new FactorNode (facNodes[i]->factor());
|
||||||
FactorNode* facNode = new FactorNode (facs[i]);
|
addFactorNode (facNode);
|
||||||
addFactor (facNode);
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
const VarNodes& neighs = facs[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
setIndexes();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -40,82 +39,70 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
|
|||||||
void
|
void
|
||||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||||
{
|
{
|
||||||
ifstream is (fileName);
|
ifstream is (fileName);
|
||||||
if (!is.is_open()) {
|
if (!is.is_open()) {
|
||||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
ignoreLines (is);
|
||||||
string line;
|
string line;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
|
||||||
getline (is, line);
|
getline (is, line);
|
||||||
if (line != "MARKOV") {
|
if (line != "MARKOV") {
|
||||||
cerr << "error: the network must be a MARKOV network " << endl;
|
cerr << "error: the network must be a MARKOV network " << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
// read the number of vars
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
ignoreLines (is);
|
||||||
unsigned nVars;
|
unsigned nrVars;
|
||||||
is >> nVars;
|
is >> nrVars;
|
||||||
|
// read the range of each var
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
ignoreLines (is);
|
||||||
vector<int> domainSizes (nVars);
|
Ranges ranges (nrVars);
|
||||||
for (unsigned i = 0; i < nVars; i++) {
|
for (unsigned i = 0; i < nrVars; i++) {
|
||||||
unsigned ds;
|
is >> ranges[i];
|
||||||
is >> ds;
|
|
||||||
domainSizes[i] = ds;
|
|
||||||
}
|
}
|
||||||
|
unsigned nrFactors;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
unsigned nrArgs;
|
||||||
for (unsigned i = 0; i < nVars; i++) {
|
unsigned vid;
|
||||||
addVariable (new VarNode (i, domainSizes[i]));
|
is >> nrFactors;
|
||||||
}
|
vector<VarIds> factorVarIds;
|
||||||
|
vector<Ranges> factorRanges;
|
||||||
unsigned nFactors;
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
is >> nFactors;
|
ignoreLines (is);
|
||||||
for (unsigned i = 0; i < nFactors; i++) {
|
is >> nrArgs;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
factorVarIds.push_back ({ });
|
||||||
unsigned nFactorVars;
|
factorRanges.push_back ({ });
|
||||||
is >> nFactorVars;
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
Vars neighs;
|
|
||||||
for (unsigned j = 0; j < nFactorVars; j++) {
|
|
||||||
unsigned vid;
|
|
||||||
is >> vid;
|
is >> vid;
|
||||||
VarNode* neigh = getVarNode (vid);
|
if (vid >= ranges.size()) {
|
||||||
if (!neigh) {
|
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
|
||||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||||
|
cerr << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
neighs.push_back (neigh);
|
factorVarIds.back().push_back (vid);
|
||||||
}
|
factorRanges.back().push_back (ranges[vid]);
|
||||||
FactorNode* fn = new FactorNode (new Factor (neighs));
|
|
||||||
addFactor (fn);
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// read the parameters
|
||||||
for (unsigned i = 0; i < nFactors; i++) {
|
unsigned nrParams;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
unsigned nParams;
|
ignoreLines (is);
|
||||||
is >> nParams;
|
is >> nrParams;
|
||||||
if (facNodes_[i]->params().size() != nParams) {
|
if (nrParams != Util::expectedSize (factorRanges[i])) {
|
||||||
cerr << "error: invalid number of parameters for factor " ;
|
cerr << "error: invalid number of parameters for factor nº " << i ;
|
||||||
cerr << facNodes_[i]->getLabel() ;
|
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
|
||||||
cerr << ", expected: " << facNodes_[i]->params().size();
|
cerr << ", given: " << nrParams << endl;
|
||||||
cerr << ", given: " << nParams << endl;
|
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
Params params (nParams);
|
Params params (nrParams);
|
||||||
for (unsigned j = 0; j < nParams; j++) {
|
for (unsigned j = 0; j < nrParams; j++) {
|
||||||
double param;
|
is >> params[j];
|
||||||
is >> param;
|
|
||||||
params[j] = param;
|
|
||||||
}
|
}
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::toLog (params);
|
Util::toLog (params);
|
||||||
}
|
}
|
||||||
facNodes_[i]->factor()->setParams (params);
|
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||||
}
|
}
|
||||||
is.close();
|
is.close();
|
||||||
setIndexes();
|
setIndexes();
|
||||||
@ -131,79 +118,51 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
|||||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
ignoreLines (is);
|
||||||
string line;
|
unsigned nrFactors;
|
||||||
unsigned nFactors;
|
unsigned nrArgs;
|
||||||
|
VarId vid;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
is >> nrFactors;
|
||||||
is >> nFactors;
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
|
ignoreLines (is);
|
||||||
if (is.fail()) {
|
// read the factor arguments
|
||||||
cerr << "error: cannot read the number of factors" << endl;
|
is >> nrArgs;
|
||||||
abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
getline (is, line);
|
|
||||||
if (is.fail() || line.size() > 0) {
|
|
||||||
cerr << "error: cannot read the number of factors" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nFactors; i++) {
|
|
||||||
unsigned nVars;
|
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
|
|
||||||
is >> nVars;
|
|
||||||
VarIds vids;
|
VarIds vids;
|
||||||
for (unsigned j = 0; j < nVars; j++) {
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
VarId vid;
|
ignoreLines (is);
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> vid;
|
is >> vid;
|
||||||
vids.push_back (vid);
|
vids.push_back (vid);
|
||||||
}
|
}
|
||||||
|
// read ranges
|
||||||
Vars neighs;
|
Ranges ranges (nrArgs);
|
||||||
unsigned nParams = 1;
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
for (unsigned j = 0; j < nVars; j++) {
|
ignoreLines (is);
|
||||||
unsigned dsize;
|
is >> ranges[j];
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> dsize;
|
|
||||||
VarNode* var = getVarNode (vids[j]);
|
VarNode* var = getVarNode (vids[j]);
|
||||||
if (var == 0) {
|
if (var != 0 && ranges[j] != var->range()) {
|
||||||
var = new VarNode (vids[j], dsize);
|
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||||
addVariable (var);
|
cerr << "more factors with a different range" << endl;
|
||||||
} else {
|
|
||||||
if (var->range() != dsize) {
|
|
||||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
|
||||||
cerr << "more factors with different domain sizes" << endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
neighs.push_back (var);
|
|
||||||
nParams *= var->range();
|
|
||||||
}
|
}
|
||||||
Params params (nParams, 0);
|
// read parameters
|
||||||
|
ignoreLines (is);
|
||||||
unsigned nNonzeros;
|
unsigned nNonzeros;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> nNonzeros;
|
is >> nNonzeros;
|
||||||
|
Params params (Util::expectedSize (ranges), 0);
|
||||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||||
|
ignoreLines (is);
|
||||||
unsigned index;
|
unsigned index;
|
||||||
double val;
|
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> index;
|
is >> index;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
ignoreLines (is);
|
||||||
|
double val;
|
||||||
is >> val;
|
is >> val;
|
||||||
params[index] = val;
|
params[index] = val;
|
||||||
}
|
}
|
||||||
reverse (neighs.begin(), neighs.end());
|
reverse (vids.begin(), vids.end());
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::toLog (params);
|
Util::toLog (params);
|
||||||
}
|
}
|
||||||
FactorNode* fn = new FactorNode (new Factor (neighs, params));
|
addFactor (Factor (vids, ranges, params));
|
||||||
addFactor (fn);
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
is.close();
|
is.close();
|
||||||
setIndexes();
|
setIndexes();
|
||||||
@ -223,30 +182,11 @@ FactorGraph::~FactorGraph (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FactorGraph::addVariable (VarNode* vn)
|
|
||||||
{
|
|
||||||
varNodes_.push_back (vn);
|
|
||||||
vn->setIndex (varNodes_.size() - 1);
|
|
||||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FactorGraph::addFactor (FactorNode* fn)
|
|
||||||
{
|
|
||||||
facNodes_.push_back (fn);
|
|
||||||
fn->setIndex (facNodes_.size() - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addFactor (const Factor& factor)
|
FactorGraph::addFactor (const Factor& factor)
|
||||||
{
|
{
|
||||||
FactorNode* fn = new FactorNode (factor);
|
FactorNode* fn = new FactorNode (factor);
|
||||||
addFactor (fn);
|
addFactorNode (fn);
|
||||||
const VarIds& vids = factor.arguments();
|
const VarIds& vids = factor.arguments();
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
bool found = false;
|
bool found = false;
|
||||||
@ -258,7 +198,7 @@ FactorGraph::addFactor (const Factor& factor)
|
|||||||
}
|
}
|
||||||
if (found == false) {
|
if (found == false) {
|
||||||
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
||||||
addVariable (vn);
|
addVarNode (vn);
|
||||||
addEdge (vn, fn);
|
addEdge (vn, fn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -266,6 +206,25 @@ FactorGraph::addFactor (const Factor& factor)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addVarNode (VarNode* vn)
|
||||||
|
{
|
||||||
|
varNodes_.push_back (vn);
|
||||||
|
vn->setIndex (varNodes_.size() - 1);
|
||||||
|
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addFactorNode (FactorNode* fn)
|
||||||
|
{
|
||||||
|
facNodes_.push_back (fn);
|
||||||
|
fn->setIndex (facNodes_.size() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addEdge (VarNode* vn, FactorNode* fn)
|
FactorGraph::addEdge (VarNode* vn, FactorNode* fn)
|
||||||
{
|
{
|
||||||
@ -301,7 +260,7 @@ FactorGraph::getStructure (void)
|
|||||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||||
}
|
}
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const VarIds& vids = facNodes_[i]->factor()->arguments();
|
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||||
for (unsigned j = 1; j < vids.size(); j++) {
|
for (unsigned j = 1; j < vids.size(); j++) {
|
||||||
structure_.addEdge (vids[j], vids[0]);
|
structure_.addEdge (vids[j], vids[0]);
|
||||||
}
|
}
|
||||||
@ -340,7 +299,7 @@ FactorGraph::print (void) const
|
|||||||
cout << endl << endl;
|
cout << endl << endl;
|
||||||
}
|
}
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
facNodes_[i]->factor()->print();
|
facNodes_[i]->factor().print();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -446,7 +405,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
|||||||
out << factorVars[j]->range() << " " ;
|
out << factorVars[j]->range() << " " ;
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
Params params = facNodes_[i]->factor()->params();
|
Params params = facNodes_[i]->factor().params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
@ -461,6 +420,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||||
|
{
|
||||||
|
string ignoreStr;
|
||||||
|
while (is.peek() == '#' || is.peek() == '\n') {
|
||||||
|
getline (is, ignoreStr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::containsCycle (void) const
|
FactorGraph::containsCycle (void) const
|
||||||
{
|
{
|
||||||
|
@ -34,47 +34,28 @@ class VarNode : public Var
|
|||||||
class FactorNode
|
class FactorNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FactorNode (const FactorNode* fn)
|
FactorNode (const Factor& f) : factor_(f), index_(-1) { }
|
||||||
{
|
|
||||||
factor_ = new Factor (*fn->factor());
|
|
||||||
index_ = -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
FactorNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
const Factor& factor (void) const { return factor_; }
|
||||||
|
|
||||||
FactorNode (const Factor& f) : factor_(new Factor (f)), index_(-1) { }
|
Factor& factor (void) { return factor_; }
|
||||||
|
|
||||||
Factor* factor() const { return factor_; }
|
|
||||||
|
|
||||||
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||||
|
|
||||||
const VarNodes& neighbors (void) const { return neighs_; }
|
const VarNodes& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
int getIndex (void) const
|
int getIndex (void) const { return index_; }
|
||||||
{
|
|
||||||
assert (index_ != -1);
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void setIndex (int index)
|
void setIndex (int index) { index_ = index; }
|
||||||
{
|
|
||||||
index_ = index;
|
|
||||||
}
|
|
||||||
|
|
||||||
const Params& params (void) const
|
const Params& params (void) const { return factor_.params(); }
|
||||||
{
|
|
||||||
return factor_->params();
|
|
||||||
}
|
|
||||||
|
|
||||||
string getLabel (void)
|
string getLabel (void) { return factor_.getLabel(); }
|
||||||
{
|
|
||||||
return factor_->getLabel();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FactorNode);
|
DISALLOW_COPY_AND_ASSIGN (FactorNode);
|
||||||
|
|
||||||
Factor* factor_;
|
Factor factor_;
|
||||||
VarNodes neighs_;
|
VarNodes neighs_;
|
||||||
int index_;
|
int index_;
|
||||||
};
|
};
|
||||||
@ -116,12 +97,12 @@ class FactorGraph
|
|||||||
|
|
||||||
void readFromLibDaiFormat (const char*);
|
void readFromLibDaiFormat (const char*);
|
||||||
|
|
||||||
void addVariable (VarNode*);
|
|
||||||
|
|
||||||
void addFactor (FactorNode*);
|
|
||||||
|
|
||||||
void addFactor (const Factor& factor);
|
void addFactor (const Factor& factor);
|
||||||
|
|
||||||
|
void addVarNode (VarNode*);
|
||||||
|
|
||||||
|
void addFactorNode (FactorNode*);
|
||||||
|
|
||||||
void addEdge (VarNode*, FactorNode*);
|
void addEdge (VarNode*, FactorNode*);
|
||||||
|
|
||||||
void addEdge (FactorNode*, VarNode*);
|
void addEdge (FactorNode*, VarNode*);
|
||||||
@ -145,6 +126,8 @@ class FactorGraph
|
|||||||
private:
|
private:
|
||||||
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||||
|
|
||||||
|
void ignoreLines (std::ifstream&) const;
|
||||||
|
|
||||||
bool containsCycle (void) const;
|
bool containsCycle (void) const;
|
||||||
|
|
||||||
bool containsCycle (const VarNode*, const FactorNode*,
|
bool containsCycle (const VarNode*, const FactorNode*,
|
||||||
@ -153,7 +136,7 @@ class FactorGraph
|
|||||||
bool containsCycle (const FactorNode*, const VarNode*,
|
bool containsCycle (const FactorNode*, const VarNode*,
|
||||||
vector<bool>&, vector<bool>&) const;
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
VarNodes varNodes_;
|
VarNodes varNodes_;
|
||||||
FactorNodes facNodes_;
|
FactorNodes facNodes_;
|
||||||
|
|
||||||
bool fromBayesNet_;
|
bool fromBayesNet_;
|
||||||
|
@ -31,15 +31,16 @@ main (int argc, const char* argv[])
|
|||||||
FactorGraph fg;
|
FactorGraph fg;
|
||||||
if (extension == "uai") {
|
if (extension == "uai") {
|
||||||
fg.readFromUaiFormat (argv[1]);
|
fg.readFromUaiFormat (argv[1]);
|
||||||
processArguments (fg, argc, argv);
|
|
||||||
} else if (extension == "fg") {
|
} else if (extension == "fg") {
|
||||||
fg.readFromLibDaiFormat (argv[1]);
|
fg.readFromLibDaiFormat (argv[1]);
|
||||||
processArguments (fg, argc, argv);
|
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: the graphical model must be defined either " ;
|
cerr << "error: the graphical model must be defined either " ;
|
||||||
cerr << "in a UAI or libDAI file" << endl;
|
cerr << "in a UAI or libDAI file" << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
|
fg.print();
|
||||||
|
assert (false);
|
||||||
|
processArguments (fg, argc, argv);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,8 +214,6 @@ createGroundNetwork (void)
|
|||||||
{
|
{
|
||||||
FactorGraph* fg = new FactorGraph();;
|
FactorGraph* fg = new FactorGraph();;
|
||||||
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||||
cout << "factors type: '" << factorsType << "'" << endl;
|
|
||||||
|
|
||||||
YAP_Term factorList = YAP_ARG2;
|
YAP_Term factorList = YAP_ARG2;
|
||||||
while (factorList != YAP_TermNil()) {
|
while (factorList != YAP_TermNil()) {
|
||||||
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||||
@ -240,7 +238,6 @@ createGroundNetwork (void)
|
|||||||
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
|
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
|
||||||
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||||
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||||
cout << vid << " == " << ev << endl;
|
|
||||||
assert (fg->getVarNode (vid));
|
assert (fg->getVarNode (vid));
|
||||||
fg->getVarNode (vid)->setEvidence (ev);
|
fg->getVarNode (vid)->setEvidence (ev);
|
||||||
evidenceList = YAP_TailOfTerm (evidenceList);
|
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||||
@ -354,7 +351,7 @@ runGroundSolver (void)
|
|||||||
} else {
|
} else {
|
||||||
runBpSolver (fg, tasks, results);
|
runBpSolver (fg, tasks, results);
|
||||||
}
|
}
|
||||||
cout << "results: " << results << endl;
|
|
||||||
YAP_Term list = YAP_TermNil();
|
YAP_Term list = YAP_TermNil();
|
||||||
for (int i = results.size() - 1; i >= 0; i--) {
|
for (int i = results.size() - 1; i >= 0; i--) {
|
||||||
const Params& beliefs = results[i];
|
const Params& beliefs = results[i];
|
||||||
@ -491,24 +488,29 @@ setBayesNetParams (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
setExtraVarsInfo (void)
|
setVarsInformation (void)
|
||||||
{
|
{
|
||||||
Var::clearVariablesInformation();
|
Var::clearVarsInfo();
|
||||||
YAP_Term varsInfoL = YAP_ARG2;
|
YAP_Term labelsL = YAP_ARG1;
|
||||||
while (varsInfoL != YAP_TermNil()) {
|
vector<string> labels;
|
||||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
while (labelsL != YAP_TermNil()) {
|
||||||
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
|
||||||
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
labels.push_back ((char*) YAP_AtomName (atom));
|
||||||
YAP_Term statesL = YAP_ArgOfTerm (3, head);
|
labelsL = YAP_TailOfTerm (labelsL);
|
||||||
|
}
|
||||||
|
unsigned count = 0;
|
||||||
|
YAP_Term stateNamesL = YAP_ARG2;
|
||||||
|
while (stateNamesL != YAP_TermNil()) {
|
||||||
States states;
|
States states;
|
||||||
while (statesL != YAP_TermNil()) {
|
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
|
||||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
|
while (namesL != YAP_TermNil()) {
|
||||||
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
|
||||||
states.push_back ((char*) YAP_AtomName (atom));
|
states.push_back ((char*) YAP_AtomName (atom));
|
||||||
statesL = YAP_TailOfTerm (statesL);
|
namesL = YAP_TailOfTerm (namesL);
|
||||||
}
|
}
|
||||||
Var::addVariableInformation (vid,
|
Var::addVarInfo (count, labels[count], states);
|
||||||
(char*) YAP_AtomName (label), states);
|
count ++;
|
||||||
varsInfoL = YAP_TailOfTerm (varsInfoL);
|
stateNamesL = YAP_TailOfTerm (stateNamesL);
|
||||||
}
|
}
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
@ -627,7 +629,7 @@ init_predicates (void)
|
|||||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
||||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
|
||||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||||
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||||
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
|
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
|
||||||
|
@ -366,12 +366,12 @@ Statistics::printStatistics (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Statistics::writeStatisticsToFile (const char* fileName)
|
Statistics::writeStatistics (const char* fileName)
|
||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "Statistics::writeStatisticsToFile()" << endl;
|
cerr << "Statistics::writeStats()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
out << getStatisticString();
|
out << getStatisticString();
|
||||||
|
@ -359,7 +359,7 @@ class Statistics
|
|||||||
|
|
||||||
static void printStatistics (void);
|
static void printStatistics (void);
|
||||||
|
|
||||||
static void writeStatisticsToFile (const char*);
|
static void writeStatistics (const char*);
|
||||||
|
|
||||||
static void updateCompressingStatistics (
|
static void updateCompressingStatistics (
|
||||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
unsigned, unsigned, unsigned, unsigned, unsigned);
|
||||||
|
@ -42,7 +42,7 @@ Var::isValidState (int stateIndex)
|
|||||||
bool
|
bool
|
||||||
Var::isValidState (const string& stateName)
|
Var::isValidState (const string& stateName)
|
||||||
{
|
{
|
||||||
States states = Var::getVarInformation (varId_).states;
|
States states = Var::getVarInfo (varId_).states;
|
||||||
return Util::contains (states, stateName);
|
return Util::contains (states, stateName);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ Var::setEvidence (int ev)
|
|||||||
void
|
void
|
||||||
Var::setEvidence (const string& ev)
|
Var::setEvidence (const string& ev)
|
||||||
{
|
{
|
||||||
States states = Var::getVarInformation (varId_).states;
|
States states = Var::getVarInfo (varId_).states;
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
for (unsigned i = 0; i < states.size(); i++) {
|
||||||
if (states[i] == ev) {
|
if (states[i] == ev) {
|
||||||
evidence_ = i;
|
evidence_ = i;
|
||||||
@ -75,8 +75,8 @@ Var::setEvidence (const string& ev)
|
|||||||
string
|
string
|
||||||
Var::label (void) const
|
Var::label (void) const
|
||||||
{
|
{
|
||||||
if (Var::variablesHaveInformation()) {
|
if (Var::varsHaveInfo()) {
|
||||||
return Var::getVarInformation (varId_).label;
|
return Var::getVarInfo (varId_).label;
|
||||||
}
|
}
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "x" << varId_;
|
ss << "x" << varId_;
|
||||||
@ -88,8 +88,8 @@ Var::label (void) const
|
|||||||
States
|
States
|
||||||
Var::states (void) const
|
Var::states (void) const
|
||||||
{
|
{
|
||||||
if (Var::variablesHaveInformation()) {
|
if (Var::varsHaveInfo()) {
|
||||||
return Var::getVarInformation (varId_).states;
|
return Var::getVarInfo (varId_).states;
|
||||||
}
|
}
|
||||||
States states;
|
States states;
|
||||||
for (unsigned i = 0; i < range_; i++) {
|
for (unsigned i = 0; i < range_; i++) {
|
||||||
|
@ -71,25 +71,25 @@ class Var
|
|||||||
|
|
||||||
States states (void) const;
|
States states (void) const;
|
||||||
|
|
||||||
static void addVariableInformation (
|
static void addVarInfo (
|
||||||
VarId vid, string label, const States& states)
|
VarId vid, string label, const States& states)
|
||||||
{
|
{
|
||||||
assert (Util::contains (varsInfo_, vid) == false);
|
assert (Util::contains (varsInfo_, vid) == false);
|
||||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||||
}
|
}
|
||||||
|
|
||||||
static VarInfo getVarInformation (VarId vid)
|
static VarInfo getVarInfo (VarId vid)
|
||||||
{
|
{
|
||||||
assert (Util::contains (varsInfo_, vid));
|
assert (Util::contains (varsInfo_, vid));
|
||||||
return varsInfo_.find (vid)->second;
|
return varsInfo_.find (vid)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool variablesHaveInformation (void)
|
static bool varsHaveInfo (void)
|
||||||
{
|
{
|
||||||
return varsInfo_.size() != 0;
|
return varsInfo_.size() != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void clearVariablesInformation (void)
|
static void clearVarsInfo (void)
|
||||||
{
|
{
|
||||||
varsInfo_.clear();
|
varsInfo_.clear();
|
||||||
}
|
}
|
||||||
|
@ -60,7 +60,7 @@ VarElimSolver::createFactorList (void)
|
|||||||
const FactorNodes& factorNodes = factorGraph_->factorNodes();
|
const FactorNodes& factorNodes = factorGraph_->factorNodes();
|
||||||
factorList_.reserve (factorNodes.size() * 2);
|
factorList_.reserve (factorNodes.size() * 2);
|
||||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||||
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
|
factorList_.push_back (new Factor (factorNodes[i]->factor())); // FIXME
|
||||||
const VarNodes& neighs = factorNodes[i]->neighbors();
|
const VarNodes& neighs = factorNodes[i]->neighbors();
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||||
@ -95,7 +95,6 @@ VarElimSolver::absorveEvidence (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printActiveFactors();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
set_bayes_net_params/2,
|
set_bayes_net_params/2,
|
||||||
run_lifted_solver/3,
|
run_lifted_solver/3,
|
||||||
run_ground_solver/3,
|
run_ground_solver/3,
|
||||||
set_extra_vars_info/2,
|
set_vars_information/2,
|
||||||
set_horus_flag/2,
|
set_horus_flag/2,
|
||||||
free_parfactors/1,
|
free_parfactors/1,
|
||||||
free_ground_network/1
|
free_ground_network/1
|
||||||
|
Reference in New Issue
Block a user