Improvements

Factor nodes now contain a factor object instead of a pointer. Refactor the way .fg and .uai formats are readed.
This commit is contained in:
Tiago Gomes 2012-04-09 15:40:51 +01:00
parent f1d52c0389
commit 6986e8c0d7
16 changed files with 233 additions and 310 deletions

View File

@ -40,7 +40,6 @@
get_pfl_parameters/2
]).
% :- use_module(library(clpbn/horus)).
:- use_module(library(lists)).
@ -53,45 +52,42 @@
[create_ground_network/4,
set_bayes_net_params/2,
run_ground_solver/3,
set_extra_vars_info/2,
set_vars_information/2,
free_ground_network/1
]).
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Solutions) :-
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Output) :-
b_hash_new(Hash0),
keys_to_ids(AllKeys, 0, Hash0, Hash),
%InvMap =.. [view|AllKeys],
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
evidence_to_ids(Evidence, Hash, EvIds),
factors_to_ids(Factors, Hash, FactorIds),
get_factors_type(Factors, Type),
evidence_to_ids(Evidence, Hash, EvidenceIds),
factors_to_ids(Factors, Hash, FactorIds),
writeln(type:Type), writeln(''),
writeln(allKeys:AllKeys), writeln(''),
%writeln(allKeysIds:Hash), writeln(''),
writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''),
writeln(factors:Factors), writeln(''),
writeln(factorIds:FactorIds), writeln(''),
writeln(evidence:Evidence), writeln(''),
writeln(evIds:EvIds),
create_ground_network(Type, FactorIds, EvIds, Network),
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions).
%free_graphical_model(Network).
writeln(evidenceIds:EvidenceIds), writeln(''),
create_ground_network(Type, FactorIds, EvidenceIds, Network),
%get_vars_information(AllKeys, StatesNames),
%set_vars_information(AllKeys, StatesNames),
run_solver(ground(Network,Hash), QueryKeys, Solutions),
writeln(answer:Solutions),
%clpbn_bind_vals([QueryKeys], Solutions, Output).
free_ground_network(Network).
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions) :-
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
%get_dists_parameters(DistIds, DistsParams),
%set_bayes_net_params(Network, DistsParams),
%vars_to_ids(QueryVars, QueryVarsIds),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
run_ground_solver(Network, [QueryIds], Solutions).
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
get_factors_type([f(markov, _, _)|_], markov) :- ! .
keys_to_ids([], _, Hash, Hash).
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
b_hash_insert(Hash0, Key, I0, HashI),
@ -99,27 +95,16 @@ keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
keys_to_ids(AllKeys, I, HashI, Hash).
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
get_factors_type([f(markov, _, _)|_], markov) :- ! .
list_of_keys_to_ids([], _, []).
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
evidence_to_ids([], _, []).
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
evidence_to_ids(QueryKeys, Hash, QueryIds).
%evidence_to_ids([], _, [], []).
%evidence_to_ids([Key=V|QueryKeys], Hash, [Id=V|QueryIds], [Id=Name|QueryNames]) :-
% b_hash_lookup(Key, Id, Hash),
% pfl:skolem(Key,Dom),
% nth0(V, Dom, Name),
% evidence_to_ids(QueryKeys, Hash, QueryIds, QueryNames).
factors_to_ids([], _, []).
factors_to_ids([f(_, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
list_of_keys_to_ids(Keys, Hash, Ids),
@ -135,6 +120,22 @@ get_ranges(K.Ks, Range.Rs) :- !,
get_ranges(Ks, Rs).
evidence_to_ids([], _, []).
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
evidence_to_ids(QueryKeys, Hash, QueryIds).
get_vars_information([], []).
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
pfl:skolem(Key, Domain),
get_vars_information(QueryKeys, StatesNames).
finalize_bp_solver(bp(Network, _)) :-
free_ground_network(Network).
bp([[]],_,_) :- !.
bp([QueryVars], AllVars, Output) :-
init_bp_solver(_, AllVars, _, Network),
@ -144,53 +145,22 @@ bp([QueryVars], AllVars, Output) :-
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
%writeln('init_bp_solver'),
%check_for_agg_vars(AllVars0, AllVars),
%writeln('clpbn_vars:'), print_clpbn_vars(AllVars),
get_vars_info(AllVars, VarsInfo, DistIds0),
sort(DistIds0, DistIds),
create_ground_network(VarsInfo, BayesNet),
%get_extra_vars_info(AllVars, ExtraVarsInfo),
%set_extra_vars_info(BayesNet, ExtraVarsInfo),
%writeln(extravarsinfo:ExtraVarsInfo),
true.
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
%writeln('-> run_bp_solver'),
get_dists_parameters(DistIds, DistsParams),
set_bayes_net_params(Network, DistsParams),
vars_to_ids(QueryVars, QueryVarsIds),
run_ground_solver(Network, QueryVarsIds, Solutions).
finalize_bp_solver(bp(Network, _)) :-
free_ground_network(Network).
get_extra_vars_info([], []).
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
get_atts(V, [id(VarId)]), !,
clpbn:get_atts(V, [key(Key), dist(DistId, _)]),
term_to_atom(Key, Label),
get_dist_domain(DistId, Domain0),
numbers_to_atoms(Domain0, Domain),
get_extra_vars_info(Vs, VarsInfo).
get_extra_vars_info([_|Vs], VarsInfo) :-
get_extra_vars_info(Vs, VarsInfo).
get_dists_parameters([],[]).
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
get_dist_params(Id, Params),
get_dists_parameters(Ids, DistsInfo).
numbers_to_atoms([], []).
numbers_to_atoms([Atom|L0], [Atom|L]) :-
atom(Atom), !,
numbers_to_atoms(L0, L).
numbers_to_atoms([Number|L0], [Atom|L]) :-
number_atom(Number, Atom),
numbers_to_atoms(L0, L).

View File

@ -63,12 +63,12 @@ BayesBall::constructGraph (FactorGraph* fg) const
const FactorNodes& facNodes = fg_.factorNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
const DAGraphNode* n = dag_.getNode (
facNodes[i]->factor()->argument (0));
facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) {
fg->addFactor (Factor (*(facNodes[i]->factor())));
fg->addFactor (Factor (facNodes[i]->factor()));
} else if (n->hasEvidence() && n->isVisited()) {
VarIds varIds = { facNodes[i]->factor()->argument (0) };
Ranges ranges = { facNodes[i]->factor()->range (0) };
VarIds varIds = { facNodes[i]->factor().argument (0) };
Ranges ranges = { facNodes[i]->factor().range (0) };
Params params (ranges[0], LogAware::noEvidence());
params[n->getEvidence()] = LogAware::withEvidence();
fg->addFactor (Factor (varIds, ranges, params));

View File

@ -84,5 +84,5 @@ class DAGraph
unordered_map<VarId, DAGraphNode*> varMap_;
};
#endif // HORUS_BAYESNET_H

View File

@ -101,7 +101,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
const FactorNodes& factorNodes = vn->neighbors();
for (unsigned i = 0; i < factorNodes.size(); i++) {
if (factorNodes[i]->factor()->contains (jointVarIds)) {
if (factorNodes[i]->factor().contains (jointVarIds)) {
idx = i;
break;
}
@ -109,7 +109,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == -1) {
return getJointByConditioning (jointVarIds);
} else {
Factor res (*factorNodes[idx]->factor());
Factor res (factorNodes[idx]->factor());
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
Factor msg (links[i]->getVariable()->varId(),
@ -316,9 +316,9 @@ BpSolver::maxResidualSchedule (void)
void
BpSolver::calculateFactor2VariableMsg (SpLink* link) const
BpSolver::calculateFactor2VariableMsg (SpLink* link)
{
const FactorNode* src = link->getFactor();
FactorNode* src = link->getFactor();
const VarNode* dst = link->getVariable();
const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
@ -357,10 +357,9 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link) const
}
}
Factor result (src->factor()->arguments(),
src->factor()->ranges(),
msgProduct);
result.multiply (*(src->factor()));
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
result.multiply (src->factor());
if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->params() << endl;

View File

@ -108,7 +108,7 @@ class BpSolver : public Solver
virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*) const;
virtual void calculateFactor2VariableMsg (SpLink*);
virtual Params getVar2FactorMsg (const SpLink*) const;

View File

@ -74,20 +74,20 @@ CFactorGraph::setInitialColors (void)
if (checkForIdenticalFactors) {
unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) {
Factor* f1 = facNodes[i]->factor();
if (f1->distId() != Util::maxUnsigned()) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1->setDistId (groupCount);
f1.setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor* f2 = facNodes[j]->factor();
if (f2->distId() != Util::maxUnsigned()) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1->size() == f2->size() &&
f1->ranges() == f2->ranges() &&
f1->params() == f2->params()) {
f2->setDistId (groupCount);
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
@ -96,7 +96,7 @@ CFactorGraph::setInitialColors (void)
// create the initial factor colors
DistColorMap distColors;
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor()->distId();
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getFreeColor())).first;
@ -211,7 +211,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]);
it ++;
*it = neighs[i]->factor()->indexOf (varNode->varId());
*it = neighs[i]->factor().indexOf (varNode->varId());
it ++;
}
*it = getColor (varNode);
@ -244,7 +244,7 @@ CFactorGraph::getCompressedFactorGraph (void)
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
VarNode* newVar = new VarNode (var);
varClusters_[i]->setRepresentativeVariable (newVar);
fg->addVariable (newVar);
fg->addVarNode (newVar);
}
for (unsigned i = 0; i < facClusters_.size(); i++) {
@ -255,11 +255,10 @@ CFactorGraph::getCompressedFactorGraph (void)
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
myGroundVars.push_back (v);
}
Factor* newFactor = new Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->params());
FactorNode* fn = new FactorNode (newFactor);
FactorNode* fn = new FactorNode (Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->params()));
facClusters_[i]->setRepresentativeFactor (fn);
fg->addFactor (fn);
fg->addFactorNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (fn, static_cast<VarNode*> (myGroundVars[j]));
}
@ -279,7 +278,7 @@ CFactorGraph::getGroundEdgeCount (
VarNode* varNode = vc->getGroundVarNodes()[0];
unsigned count = 0;
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
if (clusterGroundFactors[i]->factor().indexOf (varNode->varId()) != -1) {
count ++;
}
}

View File

@ -18,21 +18,20 @@ bool FactorGraph::orderFactorVariables = false;
FactorGraph::FactorGraph (const FactorGraph& fg)
{
const VarNodes& vars = fg.varNodes();
for (unsigned i = 0; i < vars.size(); i++) {
VarNode* varNode = new VarNode (vars[i]);
addVariable (varNode);
const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i]));
}
const FactorNodes& facs = fg.factorNodes();
for (unsigned i = 0; i < facs.size(); i++) {
FactorNode* facNode = new FactorNode (facs[i]);
addFactor (facNode);
const VarNodes& neighs = facs[i]->neighbors();
const FactorNodes& facNodes = fg.factorNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
FactorNode* facNode = new FactorNode (facNodes[i]->factor());
addFactorNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
}
}
setIndexes();
}
@ -45,77 +44,65 @@ FactorGraph::readFromUaiFormat (const char* fileName)
cerr << "error: cannot read from file " + std::string (fileName) << endl;
abort();
}
ignoreLines (is);
string line;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
getline (is, line);
if (line != "MARKOV") {
cerr << "error: the network must be a MARKOV network " << endl;
abort();
}
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nVars;
is >> nVars;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
vector<int> domainSizes (nVars);
for (unsigned i = 0; i < nVars; i++) {
unsigned ds;
is >> ds;
domainSizes[i] = ds;
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
for (unsigned i = 0; i < nVars; i++) {
addVariable (new VarNode (i, domainSizes[i]));
}
unsigned nFactors;
is >> nFactors;
for (unsigned i = 0; i < nFactors; i++) {
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nFactorVars;
is >> nFactorVars;
Vars neighs;
for (unsigned j = 0; j < nFactorVars; j++) {
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
vector<VarIds> factorVarIds;
vector<Ranges> factorRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
factorVarIds.push_back ({ });
factorRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
VarNode* neigh = getVarNode (vid);
if (!neigh) {
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
if (vid >= ranges.size()) {
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
cerr << endl;
abort();
}
neighs.push_back (neigh);
}
FactorNode* fn = new FactorNode (new Factor (neighs));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<VarNode*> (neighs[j]));
factorVarIds.back().push_back (vid);
factorRanges.back().push_back (ranges[vid]);
}
}
for (unsigned i = 0; i < nFactors; i++) {
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nParams;
is >> nParams;
if (facNodes_[i]->params().size() != nParams) {
cerr << "error: invalid number of parameters for factor " ;
cerr << facNodes_[i]->getLabel() ;
cerr << ", expected: " << facNodes_[i]->params().size();
cerr << ", given: " << nParams << endl;
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::expectedSize (factorRanges[i])) {
cerr << "error: invalid number of parameters for factor nº " << i ;
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
cerr << ", given: " << nrParams << endl;
abort();
}
Params params (nParams);
for (unsigned j = 0; j < nParams; j++) {
double param;
is >> param;
params[j] = param;
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::toLog (params);
}
facNodes_[i]->factor()->setParams (params);
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
}
is.close();
setIndexes();
@ -131,79 +118,51 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
cerr << "error: cannot read from file " + std::string (fileName) << endl;
abort();
}
string line;
unsigned nFactors;
while ((is.peek()) == '#') getline (is, line);
is >> nFactors;
if (is.fail()) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
getline (is, line);
if (is.fail() || line.size() > 0) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
for (unsigned i = 0; i < nFactors; i++) {
unsigned nVars;
while ((is.peek()) == '#') getline (is, line);
is >> nVars;
VarIds vids;
for (unsigned j = 0; j < nVars; j++) {
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
while ((is.peek()) == '#') getline (is, line);
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
Vars neighs;
unsigned nParams = 1;
for (unsigned j = 0; j < nVars; j++) {
unsigned dsize;
while ((is.peek()) == '#') getline (is, line);
is >> dsize;
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = getVarNode (vids[j]);
if (var == 0) {
var = new VarNode (vids[j], dsize);
addVariable (var);
} else {
if (var->range() != dsize) {
if (var != 0 && ranges[j] != var->range()) {
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with different domain sizes" << endl;
cerr << "more factors with a different range" << endl;
}
}
neighs.push_back (var);
nParams *= var->range();
}
Params params (nParams, 0);
// read parameters
ignoreLines (is);
unsigned nNonzeros;
while ((is.peek()) == '#') getline (is, line);
is >> nNonzeros;
Params params (Util::expectedSize (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
double val;
while ((is.peek()) == '#') getline (is, line);
is >> index;
while ((is.peek()) == '#') getline (is, line);
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
reverse (neighs.begin(), neighs.end());
reverse (vids.begin(), vids.end());
if (Globals::logDomain) {
Util::toLog (params);
}
FactorNode* fn = new FactorNode (new Factor (neighs, params));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<VarNode*> (neighs[j]));
}
addFactor (Factor (vids, ranges, params));
}
is.close();
setIndexes();
@ -223,30 +182,11 @@ FactorGraph::~FactorGraph (void)
void
FactorGraph::addVariable (VarNode* vn)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
}
void
FactorGraph::addFactor (FactorNode* fn)
{
facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1);
}
void
FactorGraph::addFactor (const Factor& factor)
{
FactorNode* fn = new FactorNode (factor);
addFactor (fn);
addFactorNode (fn);
const VarIds& vids = factor.arguments();
for (unsigned i = 0; i < vids.size(); i++) {
bool found = false;
@ -258,7 +198,7 @@ FactorGraph::addFactor (const Factor& factor)
}
if (found == false) {
VarNode* vn = new VarNode (vids[i], factor.range (i));
addVariable (vn);
addVarNode (vn);
addEdge (vn, fn);
}
}
@ -266,6 +206,25 @@ FactorGraph::addFactor (const Factor& factor)
void
FactorGraph::addVarNode (VarNode* vn)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
}
void
FactorGraph::addFactorNode (FactorNode* fn)
{
facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1);
}
void
FactorGraph::addEdge (VarNode* vn, FactorNode* fn)
{
@ -301,7 +260,7 @@ FactorGraph::getStructure (void)
structure_.addNode (new DAGraphNode (varNodes_[i]));
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
const VarIds& vids = facNodes_[i]->factor()->arguments();
const VarIds& vids = facNodes_[i]->factor().arguments();
for (unsigned j = 1; j < vids.size(); j++) {
structure_.addEdge (vids[j], vids[0]);
}
@ -340,7 +299,7 @@ FactorGraph::print (void) const
cout << endl << endl;
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->factor()->print();
facNodes_[i]->factor().print();
}
}
@ -446,7 +405,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
out << factorVars[j]->range() << " " ;
}
out << endl;
Params params = facNodes_[i]->factor()->params();
Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) {
Util::fromLog (params);
}
@ -461,6 +420,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
void
FactorGraph::ignoreLines (std::ifstream& is) const
{
string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
}
}
bool
FactorGraph::containsCycle (void) const
{

View File

@ -34,47 +34,28 @@ class VarNode : public Var
class FactorNode
{
public:
FactorNode (const FactorNode* fn)
{
factor_ = new Factor (*fn->factor());
index_ = -1;
}
FactorNode (const Factor& f) : factor_(f), index_(-1) { }
FactorNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
const Factor& factor (void) const { return factor_; }
FactorNode (const Factor& f) : factor_(new Factor (f)), index_(-1) { }
Factor* factor() const { return factor_; }
Factor& factor (void) { return factor_; }
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
const VarNodes& neighbors (void) const { return neighs_; }
int getIndex (void) const
{
assert (index_ != -1);
return index_;
}
int getIndex (void) const { return index_; }
void setIndex (int index)
{
index_ = index;
}
void setIndex (int index) { index_ = index; }
const Params& params (void) const
{
return factor_->params();
}
const Params& params (void) const { return factor_.params(); }
string getLabel (void)
{
return factor_->getLabel();
}
string getLabel (void) { return factor_.getLabel(); }
private:
DISALLOW_COPY_AND_ASSIGN (FactorNode);
Factor* factor_;
Factor factor_;
VarNodes neighs_;
int index_;
};
@ -116,12 +97,12 @@ class FactorGraph
void readFromLibDaiFormat (const char*);
void addVariable (VarNode*);
void addFactor (FactorNode*);
void addFactor (const Factor& factor);
void addVarNode (VarNode*);
void addFactorNode (FactorNode*);
void addEdge (VarNode*, FactorNode*);
void addEdge (FactorNode*, VarNode*);
@ -145,6 +126,8 @@ class FactorGraph
private:
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
void ignoreLines (std::ifstream&) const;
bool containsCycle (void) const;
bool containsCycle (const VarNode*, const FactorNode*,

View File

@ -31,15 +31,16 @@ main (int argc, const char* argv[])
FactorGraph fg;
if (extension == "uai") {
fg.readFromUaiFormat (argv[1]);
processArguments (fg, argc, argv);
} else if (extension == "fg") {
fg.readFromLibDaiFormat (argv[1]);
processArguments (fg, argc, argv);
} else {
cerr << "error: the graphical model must be defined either " ;
cerr << "in a UAI or libDAI file" << endl;
exit (0);
}
fg.print();
assert (false);
processArguments (fg, argc, argv);
return 0;
}

View File

@ -214,8 +214,6 @@ createGroundNetwork (void)
{
FactorGraph* fg = new FactorGraph();;
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
cout << "factors type: '" << factorsType << "'" << endl;
YAP_Term factorList = YAP_ARG2;
while (factorList != YAP_TermNil()) {
YAP_Term factor = YAP_HeadOfTerm (factorList);
@ -240,7 +238,6 @@ createGroundNetwork (void)
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
cout << vid << " == " << ev << endl;
assert (fg->getVarNode (vid));
fg->getVarNode (vid)->setEvidence (ev);
evidenceList = YAP_TailOfTerm (evidenceList);
@ -354,7 +351,7 @@ runGroundSolver (void)
} else {
runBpSolver (fg, tasks, results);
}
cout << "results: " << results << endl;
YAP_Term list = YAP_TermNil();
for (int i = results.size() - 1; i >= 0; i--) {
const Params& beliefs = results[i];
@ -491,24 +488,29 @@ setBayesNetParams (void)
int
setExtraVarsInfo (void)
setVarsInformation (void)
{
Var::clearVariablesInformation();
YAP_Term varsInfoL = YAP_ARG2;
while (varsInfoL != YAP_TermNil()) {
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
YAP_Term statesL = YAP_ArgOfTerm (3, head);
States states;
while (statesL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
states.push_back ((char*) YAP_AtomName (atom));
statesL = YAP_TailOfTerm (statesL);
Var::clearVarsInfo();
YAP_Term labelsL = YAP_ARG1;
vector<string> labels;
while (labelsL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
labels.push_back ((char*) YAP_AtomName (atom));
labelsL = YAP_TailOfTerm (labelsL);
}
Var::addVariableInformation (vid,
(char*) YAP_AtomName (label), states);
varsInfoL = YAP_TailOfTerm (varsInfoL);
unsigned count = 0;
YAP_Term stateNamesL = YAP_ARG2;
while (stateNamesL != YAP_TermNil()) {
States states;
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
while (namesL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
states.push_back ((char*) YAP_AtomName (atom));
namesL = YAP_TailOfTerm (namesL);
}
Var::addVarInfo (count, labels[count], states);
count ++;
stateNamesL = YAP_TailOfTerm (stateNamesL);
}
return TRUE;
}
@ -627,7 +629,7 @@ init_predicates (void)
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);

View File

@ -366,12 +366,12 @@ Statistics::printStatistics (void)
void
Statistics::writeStatisticsToFile (const char* fileName)
Statistics::writeStatistics (const char* fileName)
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Statistics::writeStatisticsToFile()" << endl;
cerr << "Statistics::writeStats()" << endl;
abort();
}
out << getStatisticString();

View File

@ -359,7 +359,7 @@ class Statistics
static void printStatistics (void);
static void writeStatisticsToFile (const char*);
static void writeStatistics (const char*);
static void updateCompressingStatistics (
unsigned, unsigned, unsigned, unsigned, unsigned);

View File

@ -42,7 +42,7 @@ Var::isValidState (int stateIndex)
bool
Var::isValidState (const string& stateName)
{
States states = Var::getVarInformation (varId_).states;
States states = Var::getVarInfo (varId_).states;
return Util::contains (states, stateName);
}
@ -60,7 +60,7 @@ Var::setEvidence (int ev)
void
Var::setEvidence (const string& ev)
{
States states = Var::getVarInformation (varId_).states;
States states = Var::getVarInfo (varId_).states;
for (unsigned i = 0; i < states.size(); i++) {
if (states[i] == ev) {
evidence_ = i;
@ -75,8 +75,8 @@ Var::setEvidence (const string& ev)
string
Var::label (void) const
{
if (Var::variablesHaveInformation()) {
return Var::getVarInformation (varId_).label;
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).label;
}
stringstream ss;
ss << "x" << varId_;
@ -88,8 +88,8 @@ Var::label (void) const
States
Var::states (void) const
{
if (Var::variablesHaveInformation()) {
return Var::getVarInformation (varId_).states;
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).states;
}
States states;
for (unsigned i = 0; i < range_; i++) {

View File

@ -71,25 +71,25 @@ class Var
States states (void) const;
static void addVariableInformation (
static void addVarInfo (
VarId vid, string label, const States& states)
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
static VarInfo getVarInformation (VarId vid)
static VarInfo getVarInfo (VarId vid)
{
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
static bool variablesHaveInformation (void)
static bool varsHaveInfo (void)
{
return varsInfo_.size() != 0;
}
static void clearVariablesInformation (void)
static void clearVarsInfo (void)
{
varsInfo_.clear();
}

View File

@ -60,7 +60,7 @@ VarElimSolver::createFactorList (void)
const FactorNodes& factorNodes = factorGraph_->factorNodes();
factorList_.reserve (factorNodes.size() * 2);
for (unsigned i = 0; i < factorNodes.size(); i++) {
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
factorList_.push_back (new Factor (factorNodes[i]->factor())); // FIXME
const VarNodes& neighs = factorNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
unordered_map<VarId,vector<unsigned> >::iterator it
@ -95,7 +95,6 @@ VarElimSolver::absorveEvidence (void)
}
}
}
printActiveFactors();
}

View File

@ -12,7 +12,7 @@
set_bayes_net_params/2,
run_lifted_solver/3,
run_ground_solver/3,
set_extra_vars_info/2,
set_vars_information/2,
set_horus_flag/2,
free_parfactors/1,
free_ground_network/1