Improvements
Factor nodes now contain a factor object instead of a pointer. Refactor the way .fg and .uai formats are readed.
This commit is contained in:
parent
f1d52c0389
commit
6986e8c0d7
@ -40,7 +40,6 @@
|
||||
get_pfl_parameters/2
|
||||
]).
|
||||
|
||||
% :- use_module(library(clpbn/horus)).
|
||||
|
||||
:- use_module(library(lists)).
|
||||
|
||||
@ -53,45 +52,42 @@
|
||||
[create_ground_network/4,
|
||||
set_bayes_net_params/2,
|
||||
run_ground_solver/3,
|
||||
set_extra_vars_info/2,
|
||||
set_vars_information/2,
|
||||
free_ground_network/1
|
||||
]).
|
||||
|
||||
|
||||
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Solutions) :-
|
||||
call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
||||
b_hash_new(Hash0),
|
||||
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
||||
%InvMap =.. [view|AllKeys],
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
evidence_to_ids(Evidence, Hash, EvIds),
|
||||
factors_to_ids(Factors, Hash, FactorIds),
|
||||
get_factors_type(Factors, Type),
|
||||
evidence_to_ids(Evidence, Hash, EvidenceIds),
|
||||
factors_to_ids(Factors, Hash, FactorIds),
|
||||
writeln(type:Type), writeln(''),
|
||||
writeln(allKeys:AllKeys), writeln(''),
|
||||
%writeln(allKeysIds:Hash), writeln(''),
|
||||
writeln(queryKeys:QueryKeys), writeln(''),
|
||||
writeln(queryIds:QueryIds), writeln(''),
|
||||
writeln(factors:Factors), writeln(''),
|
||||
writeln(factorIds:FactorIds), writeln(''),
|
||||
writeln(evidence:Evidence), writeln(''),
|
||||
writeln(evIds:EvIds),
|
||||
create_ground_network(Type, FactorIds, EvIds, Network),
|
||||
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions).
|
||||
%free_graphical_model(Network).
|
||||
writeln(evidenceIds:EvidenceIds), writeln(''),
|
||||
create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
||||
%get_vars_information(AllKeys, StatesNames),
|
||||
%set_vars_information(AllKeys, StatesNames),
|
||||
run_solver(ground(Network,Hash), QueryKeys, Solutions),
|
||||
writeln(answer:Solutions),
|
||||
%clpbn_bind_vals([QueryKeys], Solutions, Output).
|
||||
free_ground_network(Network).
|
||||
|
||||
|
||||
run_ground_fixme_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||
%get_dists_parameters(DistIds, DistsParams),
|
||||
%set_bayes_net_params(Network, DistsParams),
|
||||
%vars_to_ids(QueryVars, QueryVarsIds),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
writeln(queryKeys:QueryKeys), writeln(''),
|
||||
writeln(queryIds:QueryIds), writeln(''),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
run_ground_solver(Network, [QueryIds], Solutions).
|
||||
|
||||
|
||||
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
|
||||
get_factors_type([f(markov, _, _)|_], markov) :- ! .
|
||||
|
||||
|
||||
keys_to_ids([], _, Hash, Hash).
|
||||
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
|
||||
b_hash_insert(Hash0, Key, I0, HashI),
|
||||
@ -99,27 +95,16 @@ keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
|
||||
keys_to_ids(AllKeys, I, HashI, Hash).
|
||||
|
||||
|
||||
get_factors_type([f(bayes, _, _)|_], bayes) :- ! .
|
||||
get_factors_type([f(markov, _, _)|_], markov) :- ! .
|
||||
|
||||
|
||||
list_of_keys_to_ids([], _, []).
|
||||
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
|
||||
b_hash_lookup(Key, Id, Hash),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
|
||||
|
||||
|
||||
evidence_to_ids([], _, []).
|
||||
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
|
||||
b_hash_lookup(Key, Id, Hash),
|
||||
evidence_to_ids(QueryKeys, Hash, QueryIds).
|
||||
|
||||
|
||||
|
||||
%evidence_to_ids([], _, [], []).
|
||||
%evidence_to_ids([Key=V|QueryKeys], Hash, [Id=V|QueryIds], [Id=Name|QueryNames]) :-
|
||||
% b_hash_lookup(Key, Id, Hash),
|
||||
% pfl:skolem(Key,Dom),
|
||||
% nth0(V, Dom, Name),
|
||||
% evidence_to_ids(QueryKeys, Hash, QueryIds, QueryNames).
|
||||
|
||||
|
||||
factors_to_ids([], _, []).
|
||||
factors_to_ids([f(_, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
|
||||
list_of_keys_to_ids(Keys, Hash, Ids),
|
||||
@ -135,6 +120,22 @@ get_ranges(K.Ks, Range.Rs) :- !,
|
||||
get_ranges(Ks, Rs).
|
||||
|
||||
|
||||
evidence_to_ids([], _, []).
|
||||
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
|
||||
b_hash_lookup(Key, Id, Hash),
|
||||
evidence_to_ids(QueryKeys, Hash, QueryIds).
|
||||
|
||||
|
||||
get_vars_information([], []).
|
||||
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
|
||||
pfl:skolem(Key, Domain),
|
||||
get_vars_information(QueryKeys, StatesNames).
|
||||
|
||||
|
||||
finalize_bp_solver(bp(Network, _)) :-
|
||||
free_ground_network(Network).
|
||||
|
||||
|
||||
bp([[]],_,_) :- !.
|
||||
bp([QueryVars], AllVars, Output) :-
|
||||
init_bp_solver(_, AllVars, _, Network),
|
||||
@ -144,53 +145,22 @@ bp([QueryVars], AllVars, Output) :-
|
||||
|
||||
|
||||
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||
%writeln('init_bp_solver'),
|
||||
%check_for_agg_vars(AllVars0, AllVars),
|
||||
%writeln('clpbn_vars:'), print_clpbn_vars(AllVars),
|
||||
get_vars_info(AllVars, VarsInfo, DistIds0),
|
||||
sort(DistIds0, DistIds),
|
||||
create_ground_network(VarsInfo, BayesNet),
|
||||
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
||||
%set_extra_vars_info(BayesNet, ExtraVarsInfo),
|
||||
%writeln(extravarsinfo:ExtraVarsInfo),
|
||||
true.
|
||||
|
||||
|
||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||
%writeln('-> run_bp_solver'),
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
set_bayes_net_params(Network, DistsParams),
|
||||
vars_to_ids(QueryVars, QueryVarsIds),
|
||||
run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||
|
||||
|
||||
finalize_bp_solver(bp(Network, _)) :-
|
||||
free_ground_network(Network).
|
||||
|
||||
|
||||
get_extra_vars_info([], []).
|
||||
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
|
||||
get_atts(V, [id(VarId)]), !,
|
||||
clpbn:get_atts(V, [key(Key), dist(DistId, _)]),
|
||||
term_to_atom(Key, Label),
|
||||
get_dist_domain(DistId, Domain0),
|
||||
numbers_to_atoms(Domain0, Domain),
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
get_extra_vars_info([_|Vs], VarsInfo) :-
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
|
||||
|
||||
get_dists_parameters([],[]).
|
||||
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
||||
get_dist_params(Id, Params),
|
||||
get_dists_parameters(Ids, DistsInfo).
|
||||
|
||||
|
||||
numbers_to_atoms([], []).
|
||||
numbers_to_atoms([Atom|L0], [Atom|L]) :-
|
||||
atom(Atom), !,
|
||||
numbers_to_atoms(L0, L).
|
||||
numbers_to_atoms([Number|L0], [Atom|L]) :-
|
||||
number_atom(Number, Atom),
|
||||
numbers_to_atoms(L0, L).
|
||||
|
||||
|
@ -63,12 +63,12 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
||||
const FactorNodes& facNodes = fg_.factorNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const DAGraphNode* n = dag_.getNode (
|
||||
facNodes[i]->factor()->argument (0));
|
||||
facNodes[i]->factor().argument (0));
|
||||
if (n->isMarkedOnTop()) {
|
||||
fg->addFactor (Factor (*(facNodes[i]->factor())));
|
||||
fg->addFactor (Factor (facNodes[i]->factor()));
|
||||
} else if (n->hasEvidence() && n->isVisited()) {
|
||||
VarIds varIds = { facNodes[i]->factor()->argument (0) };
|
||||
Ranges ranges = { facNodes[i]->factor()->range (0) };
|
||||
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||
Ranges ranges = { facNodes[i]->factor().range (0) };
|
||||
Params params (ranges[0], LogAware::noEvidence());
|
||||
params[n->getEvidence()] = LogAware::withEvidence();
|
||||
fg->addFactor (Factor (varIds, ranges, params));
|
||||
|
@ -84,5 +84,5 @@ class DAGraph
|
||||
unordered_map<VarId, DAGraphNode*> varMap_;
|
||||
};
|
||||
|
||||
|
||||
#endif // HORUS_BAYESNET_H
|
||||
|
||||
|
@ -101,7 +101,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
VarNode* vn = factorGraph_->getVarNode (jointVarIds[0]);
|
||||
const FactorNodes& factorNodes = vn->neighbors();
|
||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
||||
if (factorNodes[i]->factor().contains (jointVarIds)) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
@ -109,7 +109,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
if (idx == -1) {
|
||||
return getJointByConditioning (jointVarIds);
|
||||
} else {
|
||||
Factor res (*factorNodes[idx]->factor());
|
||||
Factor res (factorNodes[idx]->factor());
|
||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Factor msg (links[i]->getVariable()->varId(),
|
||||
@ -316,9 +316,9 @@ BpSolver::maxResidualSchedule (void)
|
||||
|
||||
|
||||
void
|
||||
BpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||
{
|
||||
const FactorNode* src = link->getFactor();
|
||||
FactorNode* src = link->getFactor();
|
||||
const VarNode* dst = link->getVariable();
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
@ -357,10 +357,9 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
}
|
||||
}
|
||||
|
||||
Factor result (src->factor()->arguments(),
|
||||
src->factor()->ranges(),
|
||||
msgProduct);
|
||||
result.multiply (*(src->factor()));
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
result.multiply (src->factor());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->params() << endl;
|
||||
|
@ -108,7 +108,7 @@ class BpSolver : public Solver
|
||||
|
||||
virtual void maxResidualSchedule (void);
|
||||
|
||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
||||
virtual void calculateFactor2VariableMsg (SpLink*);
|
||||
|
||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||
|
||||
|
@ -74,20 +74,20 @@ CFactorGraph::setInitialColors (void)
|
||||
if (checkForIdenticalFactors) {
|
||||
unsigned groupCount = 1;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
Factor* f1 = facNodes[i]->factor();
|
||||
if (f1->distId() != Util::maxUnsigned()) {
|
||||
Factor& f1 = facNodes[i]->factor();
|
||||
if (f1.distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
f1->setDistId (groupCount);
|
||||
f1.setDistId (groupCount);
|
||||
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
||||
Factor* f2 = facNodes[j]->factor();
|
||||
if (f2->distId() != Util::maxUnsigned()) {
|
||||
Factor& f2 = facNodes[j]->factor();
|
||||
if (f2.distId() != Util::maxUnsigned()) {
|
||||
continue;
|
||||
}
|
||||
if (f1->size() == f2->size() &&
|
||||
f1->ranges() == f2->ranges() &&
|
||||
f1->params() == f2->params()) {
|
||||
f2->setDistId (groupCount);
|
||||
if (f1.size() == f2.size() &&
|
||||
f1.ranges() == f2.ranges() &&
|
||||
f1.params() == f2.params()) {
|
||||
f2.setDistId (groupCount);
|
||||
}
|
||||
}
|
||||
groupCount ++;
|
||||
@ -96,7 +96,7 @@ CFactorGraph::setInitialColors (void)
|
||||
// create the initial factor colors
|
||||
DistColorMap distColors;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned distId = facNodes[i]->factor()->distId();
|
||||
unsigned distId = facNodes[i]->factor().distId();
|
||||
DistColorMap::iterator it = distColors.find (distId);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
||||
@ -211,7 +211,7 @@ CFactorGraph::getSignature (const VarNode* varNode)
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
*it = neighs[i]->factor()->indexOf (varNode->varId());
|
||||
*it = neighs[i]->factor().indexOf (varNode->varId());
|
||||
it ++;
|
||||
}
|
||||
*it = getColor (varNode);
|
||||
@ -244,7 +244,7 @@ CFactorGraph::getCompressedFactorGraph (void)
|
||||
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
||||
VarNode* newVar = new VarNode (var);
|
||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||
fg->addVariable (newVar);
|
||||
fg->addVarNode (newVar);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||
@ -255,11 +255,10 @@ CFactorGraph::getCompressedFactorGraph (void)
|
||||
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||
myGroundVars.push_back (v);
|
||||
}
|
||||
Factor* newFactor = new Factor (myGroundVars,
|
||||
facClusters_[i]->getGroundFactors()[0]->params());
|
||||
FactorNode* fn = new FactorNode (newFactor);
|
||||
FactorNode* fn = new FactorNode (Factor (myGroundVars,
|
||||
facClusters_[i]->getGroundFactors()[0]->params()));
|
||||
facClusters_[i]->setRepresentativeFactor (fn);
|
||||
fg->addFactor (fn);
|
||||
fg->addFactorNode (fn);
|
||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||
fg->addEdge (fn, static_cast<VarNode*> (myGroundVars[j]));
|
||||
}
|
||||
@ -279,7 +278,7 @@ CFactorGraph::getGroundEdgeCount (
|
||||
VarNode* varNode = vc->getGroundVarNodes()[0];
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
|
||||
if (clusterGroundFactors[i]->factor().indexOf (varNode->varId()) != -1) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
|
@ -18,21 +18,20 @@ bool FactorGraph::orderFactorVariables = false;
|
||||
|
||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
const VarNodes& vars = fg.varNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
VarNode* varNode = new VarNode (vars[i]);
|
||||
addVariable (varNode);
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
|
||||
const FactorNodes& facs = fg.factorNodes();
|
||||
for (unsigned i = 0; i < facs.size(); i++) {
|
||||
FactorNode* facNode = new FactorNode (facs[i]);
|
||||
addFactor (facNode);
|
||||
const VarNodes& neighs = facs[i]->neighbors();
|
||||
const FactorNodes& facNodes = fg.factorNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
FactorNode* facNode = new FactorNode (facNodes[i]->factor());
|
||||
addFactorNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
||||
}
|
||||
}
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@ -40,82 +39,70 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
void
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
ifstream is (fileName);
|
||||
ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
ignoreLines (is);
|
||||
string line;
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
getline (is, line);
|
||||
if (line != "MARKOV") {
|
||||
cerr << "error: the network must be a MARKOV network " << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nVars;
|
||||
is >> nVars;
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
vector<int> domainSizes (nVars);
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
unsigned ds;
|
||||
is >> ds;
|
||||
domainSizes[i] = ds;
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
addVariable (new VarNode (i, domainSizes[i]));
|
||||
}
|
||||
|
||||
unsigned nFactors;
|
||||
is >> nFactors;
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nFactorVars;
|
||||
is >> nFactorVars;
|
||||
Vars neighs;
|
||||
for (unsigned j = 0; j < nFactorVars; j++) {
|
||||
unsigned vid;
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
vector<VarIds> factorVarIds;
|
||||
vector<Ranges> factorRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
factorVarIds.push_back ({ });
|
||||
factorRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
VarNode* neigh = getVarNode (vid);
|
||||
if (!neigh) {
|
||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
||||
if (vid >= ranges.size()) {
|
||||
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
|
||||
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||
cerr << endl;
|
||||
abort();
|
||||
}
|
||||
neighs.push_back (neigh);
|
||||
}
|
||||
FactorNode* fn = new FactorNode (new Factor (neighs));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
||||
factorVarIds.back().push_back (vid);
|
||||
factorRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
unsigned nParams;
|
||||
is >> nParams;
|
||||
if (facNodes_[i]->params().size() != nParams) {
|
||||
cerr << "error: invalid number of parameters for factor " ;
|
||||
cerr << facNodes_[i]->getLabel() ;
|
||||
cerr << ", expected: " << facNodes_[i]->params().size();
|
||||
cerr << ", given: " << nParams << endl;
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::expectedSize (factorRanges[i])) {
|
||||
cerr << "error: invalid number of parameters for factor nº " << i ;
|
||||
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
|
||||
cerr << ", given: " << nrParams << endl;
|
||||
abort();
|
||||
}
|
||||
Params params (nParams);
|
||||
for (unsigned j = 0; j < nParams; j++) {
|
||||
double param;
|
||||
is >> param;
|
||||
params[j] = param;
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
facNodes_[i]->factor()->setParams (params);
|
||||
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
@ -131,79 +118,51 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
string line;
|
||||
unsigned nFactors;
|
||||
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> nFactors;
|
||||
|
||||
if (is.fail()) {
|
||||
cerr << "error: cannot read the number of factors" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
getline (is, line);
|
||||
if (is.fail() || line.size() > 0) {
|
||||
cerr << "error: cannot read the number of factors" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
unsigned nVars;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
|
||||
is >> nVars;
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
VarId vid;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
|
||||
Vars neighs;
|
||||
unsigned nParams = 1;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
unsigned dsize;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> dsize;
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = getVarNode (vids[j]);
|
||||
if (var == 0) {
|
||||
var = new VarNode (vids[j], dsize);
|
||||
addVariable (var);
|
||||
} else {
|
||||
if (var->range() != dsize) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with different domain sizes" << endl;
|
||||
}
|
||||
if (var != 0 && ranges[j] != var->range()) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with a different range" << endl;
|
||||
}
|
||||
neighs.push_back (var);
|
||||
nParams *= var->range();
|
||||
}
|
||||
Params params (nParams, 0);
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> nNonzeros;
|
||||
|
||||
Params params (Util::expectedSize (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
double val;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> index;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
reverse (neighs.begin(), neighs.end());
|
||||
reverse (vids.begin(), vids.end());
|
||||
if (Globals::logDomain) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
FactorNode* fn = new FactorNode (new Factor (neighs, params));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<VarNode*> (neighs[j]));
|
||||
}
|
||||
addFactor (Factor (vids, ranges, params));
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
@ -223,30 +182,11 @@ FactorGraph::~FactorGraph (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addVariable (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (FactorNode* fn)
|
||||
{
|
||||
facNodes_.push_back (fn);
|
||||
fn->setIndex (facNodes_.size() - 1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
FactorNode* fn = new FactorNode (factor);
|
||||
addFactor (fn);
|
||||
addFactorNode (fn);
|
||||
const VarIds& vids = factor.arguments();
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
bool found = false;
|
||||
@ -258,7 +198,7 @@ FactorGraph::addFactor (const Factor& factor)
|
||||
}
|
||||
if (found == false) {
|
||||
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
||||
addVariable (vn);
|
||||
addVarNode (vn);
|
||||
addEdge (vn, fn);
|
||||
}
|
||||
}
|
||||
@ -266,6 +206,25 @@ FactorGraph::addFactor (const Factor& factor)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addVarNode (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactorNode (FactorNode* fn)
|
||||
{
|
||||
facNodes_.push_back (fn);
|
||||
fn->setIndex (facNodes_.size() - 1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (VarNode* vn, FactorNode* fn)
|
||||
{
|
||||
@ -301,7 +260,7 @@ FactorGraph::getStructure (void)
|
||||
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const VarIds& vids = facNodes_[i]->factor()->arguments();
|
||||
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||
for (unsigned j = 1; j < vids.size(); j++) {
|
||||
structure_.addEdge (vids[j], vids[0]);
|
||||
}
|
||||
@ -340,7 +299,7 @@ FactorGraph::print (void) const
|
||||
cout << endl << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->factor()->print();
|
||||
facNodes_[i]->factor().print();
|
||||
}
|
||||
}
|
||||
|
||||
@ -446,7 +405,7 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
out << factorVars[j]->range() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
Params params = facNodes_[i]->factor()->params();
|
||||
Params params = facNodes_[i]->factor().params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
@ -461,6 +420,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||
{
|
||||
string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (void) const
|
||||
{
|
||||
|
@ -34,47 +34,28 @@ class VarNode : public Var
|
||||
class FactorNode
|
||||
{
|
||||
public:
|
||||
FactorNode (const FactorNode* fn)
|
||||
{
|
||||
factor_ = new Factor (*fn->factor());
|
||||
index_ = -1;
|
||||
}
|
||||
FactorNode (const Factor& f) : factor_(f), index_(-1) { }
|
||||
|
||||
FactorNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
||||
const Factor& factor (void) const { return factor_; }
|
||||
|
||||
FactorNode (const Factor& f) : factor_(new Factor (f)), index_(-1) { }
|
||||
|
||||
Factor* factor() const { return factor_; }
|
||||
Factor& factor (void) { return factor_; }
|
||||
|
||||
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||
|
||||
const VarNodes& neighbors (void) const { return neighs_; }
|
||||
const VarNodes& neighbors (void) const { return neighs_; }
|
||||
|
||||
int getIndex (void) const
|
||||
{
|
||||
assert (index_ != -1);
|
||||
return index_;
|
||||
}
|
||||
int getIndex (void) const { return index_; }
|
||||
|
||||
void setIndex (int index)
|
||||
{
|
||||
index_ = index;
|
||||
}
|
||||
void setIndex (int index) { index_ = index; }
|
||||
|
||||
const Params& params (void) const
|
||||
{
|
||||
return factor_->params();
|
||||
}
|
||||
const Params& params (void) const { return factor_.params(); }
|
||||
|
||||
string getLabel (void)
|
||||
{
|
||||
return factor_->getLabel();
|
||||
}
|
||||
string getLabel (void) { return factor_.getLabel(); }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FactorNode);
|
||||
|
||||
Factor* factor_;
|
||||
Factor factor_;
|
||||
VarNodes neighs_;
|
||||
int index_;
|
||||
};
|
||||
@ -116,12 +97,12 @@ class FactorGraph
|
||||
|
||||
void readFromLibDaiFormat (const char*);
|
||||
|
||||
void addVariable (VarNode*);
|
||||
|
||||
void addFactor (FactorNode*);
|
||||
|
||||
void addFactor (const Factor& factor);
|
||||
|
||||
void addVarNode (VarNode*);
|
||||
|
||||
void addFactorNode (FactorNode*);
|
||||
|
||||
void addEdge (VarNode*, FactorNode*);
|
||||
|
||||
void addEdge (FactorNode*, VarNode*);
|
||||
@ -145,6 +126,8 @@ class FactorGraph
|
||||
private:
|
||||
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||
|
||||
void ignoreLines (std::ifstream&) const;
|
||||
|
||||
bool containsCycle (void) const;
|
||||
|
||||
bool containsCycle (const VarNode*, const FactorNode*,
|
||||
@ -153,7 +136,7 @@ class FactorGraph
|
||||
bool containsCycle (const FactorNode*, const VarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
VarNodes varNodes_;
|
||||
VarNodes varNodes_;
|
||||
FactorNodes facNodes_;
|
||||
|
||||
bool fromBayesNet_;
|
||||
|
@ -31,15 +31,16 @@ main (int argc, const char* argv[])
|
||||
FactorGraph fg;
|
||||
if (extension == "uai") {
|
||||
fg.readFromUaiFormat (argv[1]);
|
||||
processArguments (fg, argc, argv);
|
||||
} else if (extension == "fg") {
|
||||
fg.readFromLibDaiFormat (argv[1]);
|
||||
processArguments (fg, argc, argv);
|
||||
} else {
|
||||
cerr << "error: the graphical model must be defined either " ;
|
||||
cerr << "in a UAI or libDAI file" << endl;
|
||||
exit (0);
|
||||
}
|
||||
fg.print();
|
||||
assert (false);
|
||||
processArguments (fg, argc, argv);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -214,8 +214,6 @@ createGroundNetwork (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();;
|
||||
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
cout << "factors type: '" << factorsType << "'" << endl;
|
||||
|
||||
YAP_Term factorList = YAP_ARG2;
|
||||
while (factorList != YAP_TermNil()) {
|
||||
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||
@ -240,7 +238,6 @@ createGroundNetwork (void)
|
||||
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
|
||||
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||
cout << vid << " == " << ev << endl;
|
||||
assert (fg->getVarNode (vid));
|
||||
fg->getVarNode (vid)->setEvidence (ev);
|
||||
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||
@ -354,7 +351,7 @@ runGroundSolver (void)
|
||||
} else {
|
||||
runBpSolver (fg, tasks, results);
|
||||
}
|
||||
cout << "results: " << results << endl;
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (int i = results.size() - 1; i >= 0; i--) {
|
||||
const Params& beliefs = results[i];
|
||||
@ -491,24 +488,29 @@ setBayesNetParams (void)
|
||||
|
||||
|
||||
int
|
||||
setExtraVarsInfo (void)
|
||||
setVarsInformation (void)
|
||||
{
|
||||
Var::clearVariablesInformation();
|
||||
YAP_Term varsInfoL = YAP_ARG2;
|
||||
while (varsInfoL != YAP_TermNil()) {
|
||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
||||
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
||||
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
||||
YAP_Term statesL = YAP_ArgOfTerm (3, head);
|
||||
Var::clearVarsInfo();
|
||||
YAP_Term labelsL = YAP_ARG1;
|
||||
vector<string> labels;
|
||||
while (labelsL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
|
||||
labels.push_back ((char*) YAP_AtomName (atom));
|
||||
labelsL = YAP_TailOfTerm (labelsL);
|
||||
}
|
||||
unsigned count = 0;
|
||||
YAP_Term stateNamesL = YAP_ARG2;
|
||||
while (stateNamesL != YAP_TermNil()) {
|
||||
States states;
|
||||
while (statesL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
|
||||
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
|
||||
while (namesL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
|
||||
states.push_back ((char*) YAP_AtomName (atom));
|
||||
statesL = YAP_TailOfTerm (statesL);
|
||||
namesL = YAP_TailOfTerm (namesL);
|
||||
}
|
||||
Var::addVariableInformation (vid,
|
||||
(char*) YAP_AtomName (label), states);
|
||||
varsInfoL = YAP_TailOfTerm (varsInfoL);
|
||||
Var::addVarInfo (count, labels[count], states);
|
||||
count ++;
|
||||
stateNamesL = YAP_TailOfTerm (stateNamesL);
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
@ -627,7 +629,7 @@ init_predicates (void)
|
||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
|
||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
|
||||
|
@ -366,12 +366,12 @@ Statistics::printStatistics (void)
|
||||
|
||||
|
||||
void
|
||||
Statistics::writeStatisticsToFile (const char* fileName)
|
||||
Statistics::writeStatistics (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Statistics::writeStatisticsToFile()" << endl;
|
||||
cerr << "Statistics::writeStats()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << getStatisticString();
|
||||
|
@ -359,7 +359,7 @@ class Statistics
|
||||
|
||||
static void printStatistics (void);
|
||||
|
||||
static void writeStatisticsToFile (const char*);
|
||||
static void writeStatistics (const char*);
|
||||
|
||||
static void updateCompressingStatistics (
|
||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
||||
|
@ -42,7 +42,7 @@ Var::isValidState (int stateIndex)
|
||||
bool
|
||||
Var::isValidState (const string& stateName)
|
||||
{
|
||||
States states = Var::getVarInformation (varId_).states;
|
||||
States states = Var::getVarInfo (varId_).states;
|
||||
return Util::contains (states, stateName);
|
||||
}
|
||||
|
||||
@ -60,7 +60,7 @@ Var::setEvidence (int ev)
|
||||
void
|
||||
Var::setEvidence (const string& ev)
|
||||
{
|
||||
States states = Var::getVarInformation (varId_).states;
|
||||
States states = Var::getVarInfo (varId_).states;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
if (states[i] == ev) {
|
||||
evidence_ = i;
|
||||
@ -75,8 +75,8 @@ Var::setEvidence (const string& ev)
|
||||
string
|
||||
Var::label (void) const
|
||||
{
|
||||
if (Var::variablesHaveInformation()) {
|
||||
return Var::getVarInformation (varId_).label;
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).label;
|
||||
}
|
||||
stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
@ -88,8 +88,8 @@ Var::label (void) const
|
||||
States
|
||||
Var::states (void) const
|
||||
{
|
||||
if (Var::variablesHaveInformation()) {
|
||||
return Var::getVarInformation (varId_).states;
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).states;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < range_; i++) {
|
||||
|
@ -71,25 +71,25 @@ class Var
|
||||
|
||||
States states (void) const;
|
||||
|
||||
static void addVariableInformation (
|
||||
static void addVarInfo (
|
||||
VarId vid, string label, const States& states)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
|
||||
static VarInfo getVarInformation (VarId vid)
|
||||
static VarInfo getVarInfo (VarId vid)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid));
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
|
||||
static bool variablesHaveInformation (void)
|
||||
static bool varsHaveInfo (void)
|
||||
{
|
||||
return varsInfo_.size() != 0;
|
||||
}
|
||||
|
||||
static void clearVariablesInformation (void)
|
||||
static void clearVarsInfo (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ VarElimSolver::createFactorList (void)
|
||||
const FactorNodes& factorNodes = factorGraph_->factorNodes();
|
||||
factorList_.reserve (factorNodes.size() * 2);
|
||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
|
||||
factorList_.push_back (new Factor (factorNodes[i]->factor())); // FIXME
|
||||
const VarNodes& neighs = factorNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||
@ -95,7 +95,6 @@ VarElimSolver::absorveEvidence (void)
|
||||
}
|
||||
}
|
||||
}
|
||||
printActiveFactors();
|
||||
}
|
||||
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
set_bayes_net_params/2,
|
||||
run_lifted_solver/3,
|
||||
run_ground_solver/3,
|
||||
set_extra_vars_info/2,
|
||||
set_vars_information/2,
|
||||
set_horus_flag/2,
|
||||
free_parfactors/1,
|
||||
free_ground_network/1
|
||||
|
Reference in New Issue
Block a user