refactorings

This commit is contained in:
Tiago Gomes 2012-04-10 12:53:52 +01:00
parent aa1b2e40ea
commit b52dc99914
10 changed files with 50 additions and 128 deletions

View File

@ -50,7 +50,7 @@
:- use_module(horus, :- use_module(horus,
[create_ground_network/4, [create_ground_network/4,
set_bayes_net_params/2, set_factors_params/2,
run_ground_solver/3, run_ground_solver/3,
set_vars_information/2, set_vars_information/2,
free_ground_network/1 free_ground_network/1
@ -80,7 +80,7 @@ call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Output) :-
run_solver(ground(Network,Hash), QueryKeys, Solutions) :- run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
%get_dists_parameters(DistIds, DistsParams), %get_dists_parameters(DistIds, DistsParams),
%set_bayes_net_params(Network, DistsParams), %set_factors_params(Network, DistsParams),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds), list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
writeln(queryKeys:QueryKeys), writeln(''), writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''), writeln(queryIds:QueryIds), writeln(''),
@ -154,7 +154,7 @@ init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :- run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
get_dists_parameters(DistIds, DistsParams), get_dists_parameters(DistIds, DistsParams),
set_bayes_net_params(Network, DistsParams), set_factors_params(Network, DistsParams),
vars_to_ids(QueryVars, QueryVarsIds), vars_to_ids(QueryVars, QueryVarsIds),
run_ground_solver(Network, QueryVarsIds, Solutions). run_ground_solver(Network, QueryVarsIds, Solutions).

View File

@ -52,7 +52,6 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
constructGraph (fg); constructGraph (fg);
return fg; return fg;
} }

View File

@ -112,8 +112,8 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
Factor res (facNodes[idx]->factor()); Factor res (facNodes[idx]->factor());
const SpLinkSet& links = ninf(facNodes[idx])->getLinks(); const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) { for (unsigned i = 0; i < links.size(); i++) {
Factor msg (links[i]->getVariable()->varId(), Factor msg ({links[i]->getVariable()->varId()},
links[i]->getVariable()->range(), {links[i]->getVariable()->range()},
getVar2FactorMsg (links[i])); getVar2FactorMsg (links[i]));
res.multiply (msg); res.multiply (msg);
} }
@ -362,7 +362,7 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
result.multiply (src->factor()); result.multiply (src->factor());
if (Constants::DEBUG >= 5) { if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl; cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->params() << endl; cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl; cout << " factor product: " << result.params() << endl;
} }
result.sumOutAllExcept (dst->varId()); result.sumOutAllExcept (dst->varId());

View File

@ -256,14 +256,13 @@ CFactorGraph::getCompressedFactorGraph (void)
myGroundVars.push_back (v); myGroundVars.push_back (v);
} }
FacNode* fn = new FacNode (Factor (myGroundVars, FacNode* fn = new FacNode (Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->params())); facClusters_[i]->getGroundFactors()[0]->factor().params()));
facClusters_[i]->setRepresentativeFactor (fn); facClusters_[i]->setRepresentativeFactor (fn);
fg->addFacNode (fn); fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) { for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn); fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
} }
} }
fg->setIndexes();
return fg; return fg;
} }

View File

@ -18,42 +18,16 @@ Factor::Factor (const Factor& g)
Factor::Factor (VarId vid, unsigned range)
{
args_.push_back (vid);
ranges_.push_back (range);
params_.resize (range, 1.0);
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (const Vars& vars)
{
int nrParams = 1;
for (unsigned i = 0; i < vars.size(); i++) {
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->range());
nrParams *= vars[i]->range();
}
double val = 1.0 / nrParams;
params_.resize (nrParams, val);
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor ( Factor::Factor (
VarId vid, const VarIds& vids,
unsigned range, const Ranges& ranges,
const Params& params) const Params& params,
unsigned distId)
{ {
args_.push_back (vid); args_ = vids;
ranges_.push_back (range); ranges_ = ranges;
params_ = params; params_ = params;
distId_ = Util::maxUnsigned(); distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_)); assert (params_.size() == Util::expectedSize (ranges_));
} }
@ -75,21 +49,6 @@ Factor::Factor (
Factor::Factor (
const VarIds& vids,
const Ranges& ranges,
const Params& params,
unsigned distId)
{
args_ = vids;
ranges_ = ranges;
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
}
void void
Factor::sumOutAllExcept (VarId vid) Factor::sumOutAllExcept (VarId vid)
{ {

View File

@ -33,17 +33,14 @@ class TFactor
void setDistId (unsigned id) { distId_ = id; } void setDistId (unsigned id) { distId_ = id; }
void normalize (void) { LogAware::normalize (params_); }
void setParams (const Params& newParams) void setParams (const Params& newParams)
{ {
params_ = newParams; params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_)); assert (params_.size() == Util::expectedSize (ranges_));
} }
void normalize (void)
{
LogAware::normalize (params_);
}
int indexOf (const T& t) const int indexOf (const T& t) const
{ {
int idx = -1; int idx = -1;
@ -258,16 +255,10 @@ class Factor : public TFactor<VarId>
Factor (const Factor&); Factor (const Factor&);
Factor (VarId, unsigned); Factor (const VarIds&, const Ranges&, const Params&,
Factor (const Vars&);
Factor (VarId, unsigned, const Params&);
Factor (const Vars&, const Params&,
unsigned = Util::maxUnsigned()); unsigned = Util::maxUnsigned());
Factor (const VarIds&, const Ranges&, const Params&, Factor (const Vars&, const Params&,
unsigned = Util::maxUnsigned()); unsigned = Util::maxUnsigned());
void sumOutAllExcept (VarId); void sumOutAllExcept (VarId);

View File

@ -31,7 +31,6 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
addEdge (varNodes_[neighs[j]->getIndex()], facNode); addEdge (varNodes_[neighs[j]->getIndex()], facNode);
} }
} }
setIndexes();
} }
@ -105,7 +104,6 @@ FactorGraph::readFromUaiFormat (const char* fileName)
addFactor (Factor (factorVarIds[i], factorRanges[i], params)); addFactor (Factor (factorVarIds[i], factorRanges[i], params));
} }
is.close(); is.close();
setIndexes();
} }
@ -165,7 +163,6 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
addFactor (Factor (vids, ranges, params)); addFactor (Factor (vids, ranges, params));
} }
is.close(); is.close();
setIndexes();
} }
@ -211,7 +208,7 @@ FactorGraph::addVarNode (VarNode* vn)
{ {
varNodes_.push_back (vn); varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1); vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1)); varMap_.insert (make_pair (vn->varId(), vn));
} }
@ -262,19 +259,6 @@ FactorGraph::getStructure (void)
void
FactorGraph::setIndexes (void)
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
varNodes_[i]->setIndex (i);
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->setIndex (i);
}
}
void void
FactorGraph::print (void) const FactorGraph::print (void) const
{ {
@ -355,7 +339,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
out << endl; out << endl;
} }
for (unsigned i = 0; i < facNodes_.size(); i++) { for (unsigned i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->params(); Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::fromLog (params); Util::fromLog (params);
} }

View File

@ -12,11 +12,11 @@ using namespace std;
class FacNode; class FacNode;
class VarNode : public Var class VarNode : public Var
{ {
public: public:
VarNode (VarId varId, unsigned nrStates) : Var (varId, nrStates) { } VarNode (VarId varId, unsigned nrStates)
: Var (varId, nrStates) { }
VarNode (const Var* v) : Var (v) { } VarNode (const Var* v) : Var (v) { }
@ -48,15 +48,13 @@ class FacNode
void setIndex (int index) { index_ = index; } void setIndex (int index) { index_ = index; }
const Params& params (void) const { return factor_.params(); }
string getLabel (void) { return factor_.getLabel(); } string getLabel (void) { return factor_.getLabel(); }
private: private:
DISALLOW_COPY_AND_ASSIGN (FacNode); DISALLOW_COPY_AND_ASSIGN (FacNode);
Factor factor_;
VarNodes neighs_; VarNodes neighs_;
Factor factor_;
int index_; int index_;
}; };
@ -73,7 +71,7 @@ struct CompVarId
class FactorGraph class FactorGraph
{ {
public: public:
FactorGraph (void) { } FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
FactorGraph (const FactorGraph&); FactorGraph (const FactorGraph&);
@ -82,15 +80,13 @@ class FactorGraph
const VarNodes& varNodes (void) const { return varNodes_; } const VarNodes& varNodes (void) const { return varNodes_; }
const FacNodes& facNodes (void) const { return facNodes_; } const FacNodes& facNodes (void) const { return facNodes_; }
void setFromBayesNetwork (void) { fromBayesNet_ = true; }
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; } bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
VarNode* getVarNode (VarId vid) const VarNode* getVarNode (VarId vid) const
{ {
IndexMap::const_iterator it = varMap_.find (vid); VarMap::const_iterator it = varMap_.find (vid);
return (it != varMap_.end()) ? varNodes_[it->second] : 0; return it != varMap_.end() ? it->second : 0;
} }
void readFromUaiFormat (const char*); void readFromUaiFormat (const char*);
@ -109,8 +105,6 @@ class FactorGraph
DAGraph& getStructure (void); DAGraph& getStructure (void);
void setIndexes (void);
void print (void) const; void print (void) const;
void exportToGraphViz (const char*) const; void exportToGraphViz (const char*) const;
@ -137,11 +131,11 @@ class FactorGraph
VarNodes varNodes_; VarNodes varNodes_;
FacNodes facNodes_; FacNodes facNodes_;
bool fromBayesNet_;
DAGraph structure_; DAGraph structure_;
bool fromBayesNet_;
typedef unordered_map<unsigned, unsigned> IndexMap; typedef unordered_map<unsigned, VarNode*> VarMap;
IndexMap varMap_; VarMap varMap_;
}; };
#endif // HORUS_FACTORGRAPH_H #endif // HORUS_FACTORGRAPH_H

View File

@ -31,6 +31,13 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term); Parfactor* readParfactor (YAP_Term);
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
vector<unsigned> vector<unsigned>
@ -212,8 +219,9 @@ void readLiftedEvidence (
int int
createGroundNetwork (void) createGroundNetwork (void)
{ {
FactorGraph* fg = new FactorGraph();;
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1))); string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
bool fromBayesNet = factorsType == "bayes";
FactorGraph* fg = new FactorGraph (fromBayesNet);
YAP_Term factorList = YAP_ARG2; YAP_Term factorList = YAP_ARG2;
while (factorList != YAP_TermNil()) { while (factorList != YAP_TermNil()) {
YAP_Term factor = YAP_HeadOfTerm (factorList); YAP_Term factor = YAP_HeadOfTerm (factorList);
@ -228,10 +236,6 @@ createGroundNetwork (void)
fg->addFactor (Factor (varIds, ranges, params, distId)); fg->addFactor (Factor (varIds, ranges, params, distId));
factorList = YAP_TailOfTerm (factorList); factorList = YAP_TailOfTerm (factorList);
} }
if (factorsType == "bayes") {
fg->setFromBayesNetwork();
}
fg->setIndexes();
YAP_Term evidenceList = YAP_ARG3; YAP_Term evidenceList = YAP_ARG3;
while (evidenceList != YAP_TermNil()) { while (evidenceList != YAP_TermNil()) {
@ -326,13 +330,6 @@ runLiftedSolver (void)
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
int int
runGroundSolver (void) runGroundSolver (void)
{ {
@ -462,26 +459,25 @@ setParfactorsParams (void)
int int
setBayesNetParams (void) setFactorsParams (void)
{ {
// TODO FIXME return TRUE; // TODO
/* FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2; YAP_Term distList = YAP_ARG2;
unordered_map<unsigned, Params> paramsMap; unordered_map<unsigned, Params> paramsMap;
while (distList != YAP_TermNil()) { while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList); YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist)); unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false); assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist)); paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList); distList = YAP_TailOfTerm (distList);
} }
const BnNodeSet& nodes = bn->getBayesNodes(); const FacNodes& facNodes = fg->facNodes();
for (unsigned i = 0; i < nodes.size(); i++) { for (unsigned i = 0; i < facNodes.size(); i++) {
assert (Util::contains (paramsMap, nodes[i]->distId())); unsigned distId = facNodes[i]->factor().distId();
nodes[i]->setParams (paramsMap[nodes[i]->distId()]); assert (Util::contains (paramsMap, distId));
facNodes[i]->factor().setParams (paramsMap[distId]);
} }
*/
return TRUE; return TRUE;
} }
@ -628,7 +624,7 @@ init_predicates (void)
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3); YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3); YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2); YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2); YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2); YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2); YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1); YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);

View File

@ -9,7 +9,7 @@
[create_lifted_network/3, [create_lifted_network/3,
create_ground_network/4, create_ground_network/4,
set_parfactors_params/2, set_parfactors_params/2,
set_bayes_net_params/2, set_factors_params/2,
run_lifted_solver/3, run_lifted_solver/3,
run_ground_solver/3, run_ground_solver/3,
set_vars_information/2, set_vars_information/2,