Improvements

Factor nodes now contain a factor object instead of a pointer. Refactor the way .fg and .uai formats are readed.
This commit is contained in:
Tiago Gomes 2012-04-09 15:40:51 +01:00
parent f1d52c0389
commit 6986e8c0d7
16 changed files with 233 additions and 310 deletions

View File

@ -40,7 +40,6 @@
get_pfl_parameters/2 get_pfl_parameters/2
]). ]).
% :- use_module(library(clpbn/horus)).
:- use_module(library(lists)). :- use_module(library(lists)).
@ -53,45 +52,42 @@
[create_ground_network/4, [create_ground_network/4,
set_bayes_net_params/2, set_bayes_net_params/2,
run_ground_solver/3, run_ground_solver/3,
set_extra_vars_info/2, set_vars_information/2,
free_ground_network/1 free_ground_network/1
]). ]).
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Solutions) :- call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Output) :-
b_hash_new(Hash0), b_hash_new(Hash0),
keys_to_ids(AllKeys, 0, Hash0, Hash), keys_to_ids(AllKeys, 0, Hash0, Hash),
%InvMap =.. [view|AllKeys],
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
evidence_to_ids(Evidence, Hash, EvIds),
factors_to_ids(Factors, Hash, FactorIds),
get_factors_type(Factors, Type), get_factors_type(Factors, Type),
evidence_to_ids(Evidence, Hash, EvidenceIds),
factors_to_ids(Factors, Hash, FactorIds),
writeln(type:Type), writeln(''), writeln(type:Type), writeln(''),
writeln(allKeys:AllKeys), writeln(''), writeln(allKeys:AllKeys), writeln(''),
%writeln(allKeysIds:Hash), writeln(''),
writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''),
writeln(factors:Factors), writeln(''), writeln(factors:Factors), writeln(''),
writeln(factorIds:FactorIds), writeln(''), writeln(factorIds:FactorIds), writeln(''),
writeln(evidence:Evidence), writeln(''), writeln(evidence:Evidence), writeln(''),
writeln(evIds:EvIds), writeln(evidenceIds:EvidenceIds), writeln(''),
create_ground_network(Type, FactorIds, EvIds, Network), create_ground_network(Type, FactorIds, EvidenceIds, Network),
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions). %get_vars_information(AllKeys, StatesNames),
%free_graphical_model(Network). %set_vars_information(AllKeys, StatesNames),
run_solver(ground(Network,Hash), QueryKeys, Solutions),
writeln(answer:Solutions),
%clpbn_bind_vals([QueryKeys], Solutions, Output).
free_ground_network(Network).
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions) :- run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
%get_dists_parameters(DistIds, DistsParams), %get_dists_parameters(DistIds, DistsParams),
%set_bayes_net_params(Network, DistsParams), %set_bayes_net_params(Network, DistsParams),
%vars_to_ids(QueryVars, QueryVarsIds), list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds), list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
run_ground_solver(Network, [QueryIds], Solutions). run_ground_solver(Network, [QueryIds], Solutions).
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
get_factors_type([f(markov, _, _)|_], markov) :- ! .
keys_to_ids([], _, Hash, Hash). keys_to_ids([], _, Hash, Hash).
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :- keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
b_hash_insert(Hash0, Key, I0, HashI), b_hash_insert(Hash0, Key, I0, HashI),
@ -99,27 +95,16 @@ keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
keys_to_ids(AllKeys, I, HashI, Hash). keys_to_ids(AllKeys, I, HashI, Hash).
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
get_factors_type([f(markov, _, _)|_], markov) :- ! .
list_of_keys_to_ids([], _, []). list_of_keys_to_ids([], _, []).
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :- list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
b_hash_lookup(Key, Id, Hash), b_hash_lookup(Key, Id, Hash),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds). list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
evidence_to_ids([], _, []).
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
evidence_to_ids(QueryKeys, Hash, QueryIds).
%evidence_to_ids([], _, [], []).
%evidence_to_ids([Key=V|QueryKeys], Hash, [Id=V|QueryIds], [Id=Name|QueryNames]) :-
% b_hash_lookup(Key, Id, Hash),
% pfl:skolem(Key,Dom),
% nth0(V, Dom, Name),
% evidence_to_ids(QueryKeys, Hash, QueryIds, QueryNames).
factors_to_ids([], _, []). factors_to_ids([], _, []).
factors_to_ids([f(_, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :- factors_to_ids([f(_, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
list_of_keys_to_ids(Keys, Hash, Ids), list_of_keys_to_ids(Keys, Hash, Ids),
@ -135,6 +120,22 @@ get_ranges(K.Ks, Range.Rs) :- !,
get_ranges(Ks, Rs). get_ranges(Ks, Rs).
evidence_to_ids([], _, []).
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
evidence_to_ids(QueryKeys, Hash, QueryIds).
get_vars_information([], []).
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
pfl:skolem(Key, Domain),
get_vars_information(QueryKeys, StatesNames).
finalize_bp_solver(bp(Network, _)) :-
free_ground_network(Network).
bp([[]],_,_) :- !. bp([[]],_,_) :- !.
bp([QueryVars], AllVars, Output) :- bp([QueryVars], AllVars, Output) :-
init_bp_solver(_, AllVars, _, Network), init_bp_solver(_, AllVars, _, Network),
@ -144,53 +145,22 @@ bp([QueryVars], AllVars, Output) :-
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :- init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
%writeln('init_bp_solver'),
%check_for_agg_vars(AllVars0, AllVars), %check_for_agg_vars(AllVars0, AllVars),
%writeln('clpbn_vars:'), print_clpbn_vars(AllVars),
get_vars_info(AllVars, VarsInfo, DistIds0), get_vars_info(AllVars, VarsInfo, DistIds0),
sort(DistIds0, DistIds), sort(DistIds0, DistIds),
create_ground_network(VarsInfo, BayesNet), create_ground_network(VarsInfo, BayesNet),
%get_extra_vars_info(AllVars, ExtraVarsInfo),
%set_extra_vars_info(BayesNet, ExtraVarsInfo),
%writeln(extravarsinfo:ExtraVarsInfo),
true. true.
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :- run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
%writeln('-> run_bp_solver'),
get_dists_parameters(DistIds, DistsParams), get_dists_parameters(DistIds, DistsParams),
set_bayes_net_params(Network, DistsParams), set_bayes_net_params(Network, DistsParams),
vars_to_ids(QueryVars, QueryVarsIds), vars_to_ids(QueryVars, QueryVarsIds),
run_ground_solver(Network, QueryVarsIds, Solutions). run_ground_solver(Network, QueryVarsIds, Solutions).
finalize_bp_solver(bp(Network, _)) :-
free_ground_network(Network).
get_extra_vars_info([], []).
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
get_atts(V, [id(VarId)]), !,
clpbn:get_atts(V, [key(Key), dist(DistId, _)]),
term_to_atom(Key, Label),
get_dist_domain(DistId, Domain0),
numbers_to_atoms(Domain0, Domain),
get_extra_vars_info(Vs, VarsInfo).
get_extra_vars_info([_|Vs], VarsInfo) :-
get_extra_vars_info(Vs, VarsInfo).
get_dists_parameters([],[]). get_dists_parameters([],[]).
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :- get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
get_dist_params(Id, Params), get_dist_params(Id, Params),
get_dists_parameters(Ids, DistsInfo). get_dists_parameters(Ids, DistsInfo).
numbers_to_atoms([], []).
numbers_to_atoms([Atom|L0], [Atom|L]) :-
atom(Atom), !,
numbers_to_atoms(L0, L).
numbers_to_atoms([Number|L0], [Atom|L]) :-
number_atom(Number, Atom),
numbers_to_atoms(L0, L).

View File

@ -63,12 +63,12 @@ BayesBall::constructGraph (FactorGraph* fg) const
const FactorNodes& facNodes = fg_.factorNodes(); const FactorNodes& facNodes = fg_.factorNodes();
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
const DAGraphNode* n = dag_.getNode ( const DAGraphNode* n = dag_.getNode (
facNodes[i]->factor()->argument (0)); facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) { if (n->isMarkedOnTop()) {
fg->addFactor (Factor (*(facNodes[i]->factor()))); fg->addFactor (Factor (facNodes[i]->factor()));
} else if (n->hasEvidence() && n->isVisited()) { } else if (n->hasEvidence() && n->isVisited()) {
VarIds varIds = { facNodes[i]->factor()->argument (0) }; VarIds varIds = { facNodes[i]->factor().argument (0) };
Ranges ranges = { facNodes[i]->factor()->range (0) }; Ranges ranges = { facNodes[i]->factor().range (0) };
Params params (ranges[0], LogAware::noEvidence()); Params params (ranges[0], LogAware::noEvidence());
params[n->getEvidence()] = LogAware::withEvidence(); params[n->getEvidence()] = LogAware::withEvidence();
fg->addFactor (Factor (varIds, ranges, params)); fg->addFactor (Factor (varIds, ranges, params));

View File

@ -84,5 +84,5 @@ class DAGraph
unordered_map<VarId, DAGraphNode*> varMap_; unordered_map<VarId, DAGraphNode*> varMap_;
}; };
#endif // HORUS_BAYESNET_H #endif // HORUS_BAYESNET_H

View File

@ -101,7 +101,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]); VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
const FactorNodes& factorNodes = vn->neighbors(); const FactorNodes& factorNodes = vn->neighbors();
for (unsigned i = 0; i < factorNodes.size(); i++) { for (unsigned i = 0; i < factorNodes.size(); i++) {
if (factorNodes[i]->factor()->contains (jointVarIds)) { if (factorNodes[i]->factor().contains (jointVarIds)) {
idx = i; idx = i;
break; break;
} }
@ -109,7 +109,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == -1) { if (idx == -1) {
return getJointByConditioning (jointVarIds); return getJointByConditioning (jointVarIds);
} else { } else {
Factor res (*factorNodes[idx]->factor()); Factor res (factorNodes[idx]->factor());
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks(); const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
Factor msg (links[i]->getVariable()->varId(), Factor msg (links[i]->getVariable()->varId(),
@ -316,9 +316,9 @@ BpSolver::maxResidualSchedule (void)
void void
BpSolver::calculateFactor2VariableMsg (SpLink* link) const BpSolver::calculateFactor2VariableMsg (SpLink* link)
{ {
const FactorNode* src = link->getFactor(); FactorNode* src = link->getFactor();
const VarNode* dst = link->getVariable(); const VarNode* dst = link->getVariable();
const SpLinkSet& links = ninf(src)->getLinks(); const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent // calculate the product of messages that were sent
@ -357,10 +357,9 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link) const
} }
} }
Factor result (src->factor()->arguments(), Factor result (src->factor().arguments(),
src->factor()->ranges(), src->factor().ranges(), msgProduct);
msgProduct); result.multiply (src->factor());
result.multiply (*(src->factor()));
if (Constants::DEBUG >= 5) { if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl; cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->params() << endl; cout << " original factor: " << src->params() << endl;

View File

@ -108,7 +108,7 @@ class BpSolver : public Solver
virtual void maxResidualSchedule (void); virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*) const; virtual void calculateFactor2VariableMsg (SpLink*);
virtual Params getVar2FactorMsg (const SpLink*) const; virtual Params getVar2FactorMsg (const SpLink*) const;

View File

@ -74,20 +74,20 @@ CFactorGraph::setInitialColors (void)
if (checkForIdenticalFactors) { if (checkForIdenticalFactors) {
unsigned groupCount = 1; unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
Factor* f1 = facNodes[i]->factor(); Factor& f1 = facNodes[i]->factor();
if (f1->distId() != Util::maxUnsigned()) { if (f1.distId() != Util::maxUnsigned()) {
continue; continue;
} }
f1->setDistId (groupCount); f1.setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) { for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor* f2 = facNodes[j]->factor(); Factor& f2 = facNodes[j]->factor();
if (f2->distId() != Util::maxUnsigned()) { if (f2.distId() != Util::maxUnsigned()) {
continue; continue;
} }
if (f1->size() == f2->size() && if (f1.size() == f2.size() &&
f1->ranges() == f2->ranges() && f1.ranges() == f2.ranges() &&
f1->params() == f2->params()) { f1.params() == f2.params()) {
f2->setDistId (groupCount); f2.setDistId (groupCount);
} }
} }
groupCount ++; groupCount ++;
@ -96,7 +96,7 @@ CFactorGraph::setInitialColors (void)
// create the initial factor colors // create the initial factor colors
DistColorMap distColors; DistColorMap distColors;
for (unsigned i = 0; i < facNodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor()->distId(); unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId); DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) { if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getFreeColor())).first; it = distColors.insert (make_pair (distId, getFreeColor())).first;
@ -211,7 +211,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
for (unsigned i = 0; i < neighs.size(); i++) { for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]); *it = getColor (neighs[i]);
it ++; it ++;
*it = neighs[i]->factor()->indexOf (varNode->varId()); *it = neighs[i]->factor().indexOf (varNode->varId());
it ++; it ++;
} }
*it = getColor (varNode); *it = getColor (varNode);
@ -244,7 +244,7 @@ CFactorGraph::getCompressedFactorGraph (void)
VarNode* var = varClusters_[i]->getGroundVarNodes()[0]; VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
VarNode* newVar = new VarNode (var); VarNode* newVar = new VarNode (var);
varClusters_[i]->setRepresentativeVariable (newVar); varClusters_[i]->setRepresentativeVariable (newVar);
fg->addVariable (newVar); fg->addVarNode (newVar);
} }
for (unsigned i = 0; i < facClusters_.size(); i++) { for (unsigned i = 0; i < facClusters_.size(); i++) {
@ -255,11 +255,10 @@ CFactorGraph::getCompressedFactorGraph (void)
VarNode* v = myVarClusters[j]->getRepresentativeVariable(); VarNode* v = myVarClusters[j]->getRepresentativeVariable();
myGroundVars.push_back (v); myGroundVars.push_back (v);
} }
Factor* newFactor = new Factor (myGroundVars, FactorNode* fn = new FactorNode (Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->params()); facClusters_[i]->getGroundFactors()[0]->params()));
FactorNode* fn = new FactorNode (newFactor);
facClusters_[i]->setRepresentativeFactor (fn); facClusters_[i]->setRepresentativeFactor (fn);
fg->addFactor (fn); fg->addFactorNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) { for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (fn, static_cast<VarNode*> (myGroundVars[j])); fg->addEdge (fn, static_cast<VarNode*> (myGroundVars[j]));
} }
@ -279,7 +278,7 @@ CFactorGraph::getGroundEdgeCount (
VarNode* varNode = vc->getGroundVarNodes()[0]; VarNode* varNode = vc->getGroundVarNodes()[0];
unsigned count = 0; unsigned count = 0;
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) { if (clusterGroundFactors[i]->factor().indexOf (varNode->varId()) != -1) {
count ++; count ++;
} }
} }

View File

@ -18,21 +18,20 @@ bool FactorGraph::orderFactorVariables = false;
FactorGraph::FactorGraph (const FactorGraph& fg) FactorGraph::FactorGraph (const FactorGraph& fg)
{ {
const VarNodes& vars = fg.varNodes(); const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < vars.size(); i++) { for (unsigned i = 0; i < varNodes.size(); i++) {
VarNode* varNode = new VarNode (vars[i]); addVarNode (new VarNode (varNodes[i]));
addVariable (varNode);
} }
const FactorNodes& facNodes = fg.factorNodes();
const FactorNodes& facs = fg.factorNodes(); for (unsigned i = 0; i < facNodes.size(); i++) {
for (unsigned i = 0; i < facs.size(); i++) { FactorNode* facNode = new FactorNode (facNodes[i]->factor());
FactorNode* facNode = new FactorNode (facs[i]); addFactorNode (facNode);
addFactor (facNode); const VarNodes& neighs = facNodes[i]->neighbors();
const VarNodes& neighs = facs[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) { for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (facNode, varNodes_[neighs[j]->getIndex()]); addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
} }
} }
setIndexes();
} }
@ -40,82 +39,70 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
void void
FactorGraph::readFromUaiFormat (const char* fileName) FactorGraph::readFromUaiFormat (const char* fileName)
{ {
ifstream is (fileName); ifstream is (fileName);
if (!is.is_open()) { if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl; cerr << "error: cannot read from file " + std::string (fileName) << endl;
abort(); abort();
} }
ignoreLines (is);
string line; string line;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
getline (is, line); getline (is, line);
if (line != "MARKOV") { if (line != "MARKOV") {
cerr << "error: the network must be a MARKOV network " << endl; cerr << "error: the network must be a MARKOV network " << endl;
abort(); abort();
} }
// read the number of vars
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); ignoreLines (is);
unsigned nVars; unsigned nrVars;
is >> nVars; is >> nrVars;
// read the range of each var
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); ignoreLines (is);
vector<int> domainSizes (nVars); Ranges ranges (nrVars);
for (unsigned i = 0; i < nVars; i++) { for (unsigned i = 0; i < nrVars; i++) {
unsigned ds; is >> ranges[i];
is >> ds;
domainSizes[i] = ds;
} }
unsigned nrFactors;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); unsigned nrArgs;
for (unsigned i = 0; i < nVars; i++) { unsigned vid;
addVariable (new VarNode (i, domainSizes[i])); is >> nrFactors;
} vector<VarIds> factorVarIds;
vector<Ranges> factorRanges;
unsigned nFactors; for (unsigned i = 0; i < nrFactors; i++) {
is >> nFactors; ignoreLines (is);
for (unsigned i = 0; i < nFactors; i++) { is >> nrArgs;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); factorVarIds.push_back ({ });
unsigned nFactorVars; factorRanges.push_back ({ });
is >> nFactorVars; for (unsigned j = 0; j < nrArgs; j++) {
Vars neighs;
for (unsigned j = 0; j < nFactorVars; j++) {
unsigned vid;
is >> vid; is >> vid;
VarNode* neigh = getVarNode (vid); if (vid >= ranges.size()) {
if (!neigh) { cerr << "error: invalid variable identifier `" << vid << "'" << endl;
cerr << "error: invalid variable identifier (" << vid << ")" << endl; cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
cerr << endl;
abort(); abort();
} }
neighs.push_back (neigh); factorVarIds.back().push_back (vid);
} factorRanges.back().push_back (ranges[vid]);
FactorNode* fn = new FactorNode (new Factor (neighs));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<VarNode*> (neighs[j]));
} }
} }
// read the parameters
for (unsigned i = 0; i < nFactors; i++) { unsigned nrParams;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line); for (unsigned i = 0; i < nrFactors; i++) {
unsigned nParams; ignoreLines (is);
is >> nParams; is >> nrParams;
if (facNodes_[i]->params().size() != nParams) { if (nrParams != Util::expectedSize (factorRanges[i])) {
cerr << "error: invalid number of parameters for factor " ; cerr << "error: invalid number of parameters for factor nº " << i ;
cerr << facNodes_[i]->getLabel() ; cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
cerr << ", expected: " << facNodes_[i]->params().size(); cerr << ", given: " << nrParams << endl;
cerr << ", given: " << nParams << endl;
abort(); abort();
} }
Params params (nParams); Params params (nrParams);
for (unsigned j = 0; j < nParams; j++) { for (unsigned j = 0; j < nrParams; j++) {
double param; is >> params[j];
is >> param;
params[j] = param;
} }
if (Globals::logDomain) { if (Globals::logDomain) {
Util::toLog (params); Util::toLog (params);
} }
facNodes_[i]->factor()->setParams (params); addFactor (Factor (factorVarIds[i], factorRanges[i], params));
} }
is.close(); is.close();
setIndexes(); setIndexes();
@ -131,79 +118,51 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
cerr << "error: cannot read from file " + std::string (fileName) << endl; cerr << "error: cannot read from file " + std::string (fileName) << endl;
abort(); abort();
} }
ignoreLines (is);
string line; unsigned nrFactors;
unsigned nFactors; unsigned nrArgs;
VarId vid;
while ((is.peek()) == '#') getline (is, line); is >> nrFactors;
is >> nFactors; for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
if (is.fail()) { // read the factor arguments
cerr << "error: cannot read the number of factors" << endl; is >> nrArgs;
abort();
}
getline (is, line);
if (is.fail() || line.size() > 0) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
for (unsigned i = 0; i < nFactors; i++) {
unsigned nVars;
while ((is.peek()) == '#') getline (is, line);
is >> nVars;
VarIds vids; VarIds vids;
for (unsigned j = 0; j < nVars; j++) { for (unsigned j = 0; j < nrArgs; j++) {
VarId vid; ignoreLines (is);
while ((is.peek()) == '#') getline (is, line);
is >> vid; is >> vid;
vids.push_back (vid); vids.push_back (vid);
} }
// read ranges
Vars neighs; Ranges ranges (nrArgs);
unsigned nParams = 1; for (unsigned j = 0; j < nrArgs; j++) {
for (unsigned j = 0; j < nVars; j++) { ignoreLines (is);
unsigned dsize; is >> ranges[j];
while ((is.peek()) == '#') getline (is, line);
is >> dsize;
VarNode* var = getVarNode (vids[j]); VarNode* var = getVarNode (vids[j]);
if (var == 0) { if (var != 0 && ranges[j] != var->range()) {
var = new VarNode (vids[j], dsize); cerr << "error: variable `" << vids[j] << "' appears in two or " ;
addVariable (var); cerr << "more factors with a different range" << endl;
} else {
if (var->range() != dsize) {
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with different domain sizes" << endl;
}
} }
neighs.push_back (var);
nParams *= var->range();
} }
Params params (nParams, 0); // read parameters
ignoreLines (is);
unsigned nNonzeros; unsigned nNonzeros;
while ((is.peek()) == '#') getline (is, line);
is >> nNonzeros; is >> nNonzeros;
Params params (Util::expectedSize (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) { for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index; unsigned index;
double val;
while ((is.peek()) == '#') getline (is, line);
is >> index; is >> index;
while ((is.peek()) == '#') getline (is, line); ignoreLines (is);
double val;
is >> val; is >> val;
params[index] = val; params[index] = val;
} }
reverse (neighs.begin(), neighs.end()); reverse (vids.begin(), vids.end());
if (Globals::logDomain) { if (Globals::logDomain) {
Util::toLog (params); Util::toLog (params);
} }
FactorNode* fn = new FactorNode (new Factor (neighs, params)); addFactor (Factor (vids, ranges, params));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<VarNode*> (neighs[j]));
}
} }
is.close(); is.close();
setIndexes(); setIndexes();
@ -223,30 +182,11 @@ FactorGraph::~FactorGraph (void)
void
FactorGraph::addVariable (VarNode* vn)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
}
void
FactorGraph::addFactor (FactorNode* fn)
{
facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1);
}
void void
FactorGraph::addFactor (const Factor& factor) FactorGraph::addFactor (const Factor& factor)
{ {
FactorNode* fn = new FactorNode (factor); FactorNode* fn = new FactorNode (factor);
addFactor (fn); addFactorNode (fn);
const VarIds& vids = factor.arguments(); const VarIds& vids = factor.arguments();
for (unsigned i = 0; i < vids.size(); i++) { for (unsigned i = 0; i < vids.size(); i++) {
bool found = false; bool found = false;
@ -258,7 +198,7 @@ FactorGraph::addFactor (const Factor& factor)
} }
if (found == false) { if (found == false) {
VarNode* vn = new VarNode (vids[i], factor.range (i)); VarNode* vn = new VarNode (vids[i], factor.range (i));
addVariable (vn); addVarNode (vn);
addEdge (vn, fn); addEdge (vn, fn);
} }
} }
@ -266,6 +206,25 @@ FactorGraph::addFactor (const Factor& factor)
void
FactorGraph::addVarNode (VarNode* vn)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
}
void
FactorGraph::addFactorNode (FactorNode* fn)
{
facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1);
}
void void
FactorGraph::addEdge (VarNode* vn, FactorNode* fn) FactorGraph::addEdge (VarNode* vn, FactorNode* fn)
{ {
@ -301,7 +260,7 @@ FactorGraph::getStructure (void)
structure_.addNode (new DAGraphNode (varNodes_[i])); structure_.addNode (new DAGraphNode (varNodes_[i]));
} }
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
const VarIds& vids = facNodes_[i]->factor()->arguments(); const VarIds& vids = facNodes_[i]->factor().arguments();
for (unsigned j = 1; j < vids.size(); j++) { for (unsigned j = 1; j < vids.size(); j++) {
structure_.addEdge (vids[j], vids[0]); structure_.addEdge (vids[j], vids[0]);
} }
@ -340,7 +299,7 @@ FactorGraph::print (void) const
cout << endl << endl; cout << endl << endl;
} }
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->factor()->print(); facNodes_[i]->factor().print();
} }
} }
@ -446,7 +405,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
out << factorVars[j]->range() << " " ; out << factorVars[j]->range() << " " ;
} }
out << endl; out << endl;
Params params = facNodes_[i]->factor()->params(); Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::fromLog (params);
} }
@ -461,6 +420,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
void
FactorGraph::ignoreLines (std::ifstream& is) const
{
string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
}
}
bool bool
FactorGraph::containsCycle (void) const FactorGraph::containsCycle (void) const
{ {

View File

@ -34,47 +34,28 @@ class VarNode : public Var
class FactorNode class FactorNode
{ {
public: public:
FactorNode (const FactorNode* fn) FactorNode (const Factor& f) : factor_(f), index_(-1) { }
{
factor_ = new Factor (*fn->factor());
index_ = -1;
}
FactorNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { } const Factor& factor (void) const { return factor_; }
FactorNode (const Factor& f) : factor_(new Factor (f)), index_(-1) { } Factor& factor (void) { return factor_; }
Factor* factor() const { return factor_; }
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); } void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
const VarNodes& neighbors (void) const { return neighs_; } const VarNodes& neighbors (void) const { return neighs_; }
int getIndex (void) const int getIndex (void) const { return index_; }
{
assert (index_ != -1);
return index_;
}
void setIndex (int index) void setIndex (int index) { index_ = index; }
{
index_ = index;
}
const Params& params (void) const const Params& params (void) const { return factor_.params(); }
{
return factor_->params();
}
string getLabel (void) string getLabel (void) { return factor_.getLabel(); }
{
return factor_->getLabel();
}
private: private:
DISALLOW_COPY_AND_ASSIGN (FactorNode); DISALLOW_COPY_AND_ASSIGN (FactorNode);
Factor* factor_; Factor factor_;
VarNodes neighs_; VarNodes neighs_;
int index_; int index_;
}; };
@ -116,12 +97,12 @@ class FactorGraph
void readFromLibDaiFormat (const char*); void readFromLibDaiFormat (const char*);
void addVariable (VarNode*);
void addFactor (FactorNode*);
void addFactor (const Factor& factor); void addFactor (const Factor& factor);
void addVarNode (VarNode*);
void addFactorNode (FactorNode*);
void addEdge (VarNode*, FactorNode*); void addEdge (VarNode*, FactorNode*);
void addEdge (FactorNode*, VarNode*); void addEdge (FactorNode*, VarNode*);
@ -145,6 +126,8 @@ class FactorGraph
private: private:
// DISALLOW_COPY_AND_ASSIGN (FactorGraph); // DISALLOW_COPY_AND_ASSIGN (FactorGraph);
void ignoreLines (std::ifstream&) const;
bool containsCycle (void) const; bool containsCycle (void) const;
bool containsCycle (const VarNode*, const FactorNode*, bool containsCycle (const VarNode*, const FactorNode*,
@ -153,7 +136,7 @@ class FactorGraph
bool containsCycle (const FactorNode*, const VarNode*, bool containsCycle (const FactorNode*, const VarNode*,
vector<bool>&, vector<bool>&) const; vector<bool>&, vector<bool>&) const;
VarNodes varNodes_; VarNodes varNodes_;
FactorNodes facNodes_; FactorNodes facNodes_;
bool fromBayesNet_; bool fromBayesNet_;

View File

@ -31,15 +31,16 @@ main (int argc, const char* argv[])
FactorGraph fg; FactorGraph fg;
if (extension == "uai") { if (extension == "uai") {
fg.readFromUaiFormat (argv[1]); fg.readFromUaiFormat (argv[1]);
processArguments (fg, argc, argv);
} else if (extension == "fg") { } else if (extension == "fg") {
fg.readFromLibDaiFormat (argv[1]); fg.readFromLibDaiFormat (argv[1]);
processArguments (fg, argc, argv);
} else { } else {
cerr << "error: the graphical model must be defined either " ; cerr << "error: the graphical model must be defined either " ;
cerr << "in a UAI or libDAI file" << endl; cerr << "in a UAI or libDAI file" << endl;
exit (0); exit (0);
} }
fg.print();
assert (false);
processArguments (fg, argc, argv);
return 0; return 0;
} }

View File

@ -214,8 +214,6 @@ createGroundNetwork (void)
{ {
FactorGraph* fg = new FactorGraph();; FactorGraph* fg = new FactorGraph();;
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1))); string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
cout << "factors type: '" << factorsType << "'" << endl;
YAP_Term factorList = YAP_ARG2; YAP_Term factorList = YAP_ARG2;
while (factorList != YAP_TermNil()) { while (factorList != YAP_TermNil()) {
YAP_Term factor = YAP_HeadOfTerm (factorList); YAP_Term factor = YAP_HeadOfTerm (factorList);
@ -240,7 +238,6 @@ createGroundNetwork (void)
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList); YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm))); unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm))); unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
cout << vid << " == " << ev << endl;
assert (fg->getVarNode (vid)); assert (fg->getVarNode (vid));
fg->getVarNode (vid)->setEvidence (ev); fg->getVarNode (vid)->setEvidence (ev);
evidenceList = YAP_TailOfTerm (evidenceList); evidenceList = YAP_TailOfTerm (evidenceList);
@ -354,7 +351,7 @@ runGroundSolver (void)
} else { } else {
runBpSolver (fg, tasks, results); runBpSolver (fg, tasks, results);
} }
cout << "results: " << results << endl;
YAP_Term list = YAP_TermNil(); YAP_Term list = YAP_TermNil();
for (int i = results.size() - 1; i >= 0; i--) { for (int i = results.size() - 1; i >= 0; i--) {
const Params& beliefs = results[i]; const Params& beliefs = results[i];
@ -491,24 +488,29 @@ setBayesNetParams (void)
int int
setExtraVarsInfo (void) setVarsInformation (void)
{ {
Var::clearVariablesInformation(); Var::clearVarsInfo();
YAP_Term varsInfoL = YAP_ARG2; YAP_Term labelsL = YAP_ARG1;
while (varsInfoL != YAP_TermNil()) { vector<string> labels;
YAP_Term head = YAP_HeadOfTerm (varsInfoL); while (labelsL != YAP_TermNil()) {
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head)); YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head)); labels.push_back ((char*) YAP_AtomName (atom));
YAP_Term statesL = YAP_ArgOfTerm (3, head); labelsL = YAP_TailOfTerm (labelsL);
}
unsigned count = 0;
YAP_Term stateNamesL = YAP_ARG2;
while (stateNamesL != YAP_TermNil()) {
States states; States states;
while (statesL != YAP_TermNil()) { YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL)); while (namesL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
states.push_back ((char*) YAP_AtomName (atom)); states.push_back ((char*) YAP_AtomName (atom));
statesL = YAP_TailOfTerm (statesL); namesL = YAP_TailOfTerm (namesL);
} }
Var::addVariableInformation (vid, Var::addVarInfo (count, labels[count], states);
(char*) YAP_AtomName (label), states); count ++;
varsInfoL = YAP_TailOfTerm (varsInfoL); stateNamesL = YAP_TailOfTerm (stateNamesL);
} }
return TRUE; return TRUE;
} }
@ -627,7 +629,7 @@ init_predicates (void)
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3); YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2); YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2); YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2); YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2); YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1); YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1); YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);

View File

@ -366,12 +366,12 @@ Statistics::printStatistics (void)
void void
Statistics::writeStatisticsToFile (const char* fileName) Statistics::writeStatistics (const char* fileName)
{ {
ofstream out (fileName); ofstream out (fileName);
if (!out.is_open()) { if (!out.is_open()) {
cerr << "error: cannot open file to write at " ; cerr << "error: cannot open file to write at " ;
cerr << "Statistics::writeStatisticsToFile()" << endl; cerr << "Statistics::writeStats()" << endl;
abort(); abort();
} }
out << getStatisticString(); out << getStatisticString();

View File

@ -359,7 +359,7 @@ class Statistics
static void printStatistics (void); static void printStatistics (void);
static void writeStatisticsToFile (const char*); static void writeStatistics (const char*);
static void updateCompressingStatistics ( static void updateCompressingStatistics (
unsigned, unsigned, unsigned, unsigned, unsigned); unsigned, unsigned, unsigned, unsigned, unsigned);

View File

@ -42,7 +42,7 @@ Var::isValidState (int stateIndex)
bool bool
Var::isValidState (const string& stateName) Var::isValidState (const string& stateName)
{ {
States states = Var::getVarInformation (varId_).states; States states = Var::getVarInfo (varId_).states;
return Util::contains (states, stateName); return Util::contains (states, stateName);
} }
@ -60,7 +60,7 @@ Var::setEvidence (int ev)
void void
Var::setEvidence (const string& ev) Var::setEvidence (const string& ev)
{ {
States states = Var::getVarInformation (varId_).states; States states = Var::getVarInfo (varId_).states;
for (unsigned i = 0; i < states.size(); i++) { for (unsigned i = 0; i < states.size(); i++) {
if (states[i] == ev) { if (states[i] == ev) {
evidence_ = i; evidence_ = i;
@ -75,8 +75,8 @@ Var::setEvidence (const string& ev)
string string
Var::label (void) const Var::label (void) const
{ {
if (Var::variablesHaveInformation()) { if (Var::varsHaveInfo()) {
return Var::getVarInformation (varId_).label; return Var::getVarInfo (varId_).label;
} }
stringstream ss; stringstream ss;
ss << "x" << varId_; ss << "x" << varId_;
@ -88,8 +88,8 @@ Var::label (void) const
States States
Var::states (void) const Var::states (void) const
{ {
if (Var::variablesHaveInformation()) { if (Var::varsHaveInfo()) {
return Var::getVarInformation (varId_).states; return Var::getVarInfo (varId_).states;
} }
States states; States states;
for (unsigned i = 0; i < range_; i++) { for (unsigned i = 0; i < range_; i++) {

View File

@ -71,25 +71,25 @@ class Var
States states (void) const; States states (void) const;
static void addVariableInformation ( static void addVarInfo (
VarId vid, string label, const States& states) VarId vid, string label, const States& states)
{ {
assert (Util::contains (varsInfo_, vid) == false); assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states))); varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
} }
static VarInfo getVarInformation (VarId vid) static VarInfo getVarInfo (VarId vid)
{ {
assert (Util::contains (varsInfo_, vid)); assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second; return varsInfo_.find (vid)->second;
} }
static bool variablesHaveInformation (void) static bool varsHaveInfo (void)
{ {
return varsInfo_.size() != 0; return varsInfo_.size() != 0;
} }
static void clearVariablesInformation (void) static void clearVarsInfo (void)
{ {
varsInfo_.clear(); varsInfo_.clear();
} }

View File

@ -60,7 +60,7 @@ VarElimSolver::createFactorList (void)
const FactorNodes& factorNodes = factorGraph_->factorNodes(); const FactorNodes& factorNodes = factorGraph_->factorNodes();
factorList_.reserve (factorNodes.size() * 2); factorList_.reserve (factorNodes.size() * 2);
for (unsigned i = 0; i < factorNodes.size(); i++) { for (unsigned i = 0; i < factorNodes.size(); i++) {
factorList_.push_back (new Factor (*factorNodes[i]->factor())); factorList_.push_back (new Factor (factorNodes[i]->factor())); // FIXME
const VarNodes& neighs = factorNodes[i]->neighbors(); const VarNodes& neighs = factorNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) { for (unsigned j = 0; j < neighs.size(); j++) {
unordered_map<VarId,vector<unsigned> >::iterator it unordered_map<VarId,vector<unsigned> >::iterator it
@ -95,7 +95,6 @@ VarElimSolver::absorveEvidence (void)
} }
} }
} }
printActiveFactors();
} }

View File

@ -12,7 +12,7 @@
set_bayes_net_params/2, set_bayes_net_params/2,
run_lifted_solver/3, run_lifted_solver/3,
run_ground_solver/3, run_ground_solver/3,
set_extra_vars_info/2, set_vars_information/2,
set_horus_flag/2, set_horus_flag/2,
free_parfactors/1, free_parfactors/1,
free_ground_network/1 free_ground_network/1