refactorings

This commit is contained in:
Tiago Gomes 2012-04-10 12:53:52 +01:00
parent aa1b2e40ea
commit b52dc99914
10 changed files with 50 additions and 128 deletions

View File

@ -50,7 +50,7 @@
:- use_module(horus,
[create_ground_network/4,
set_bayes_net_params/2,
set_factors_params/2,
run_ground_solver/3,
set_vars_information/2,
free_ground_network/1
@ -80,7 +80,7 @@ call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Output) :-
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
%get_dists_parameters(DistIds, DistsParams),
%set_bayes_net_params(Network, DistsParams),
%set_factors_params(Network, DistsParams),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''),
@ -154,7 +154,7 @@ init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
get_dists_parameters(DistIds, DistsParams),
set_bayes_net_params(Network, DistsParams),
set_factors_params(Network, DistsParams),
vars_to_ids(QueryVars, QueryVarsIds),
run_ground_solver(Network, QueryVarsIds, Solutions).

View File

@ -52,7 +52,6 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
FactorGraph* fg = new FactorGraph();
constructGraph (fg);
return fg;
}

View File

@ -112,8 +112,8 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
Factor res (facNodes[idx]->factor());
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
Factor msg (links[i]->getVariable()->varId(),
links[i]->getVariable()->range(),
Factor msg ({links[i]->getVariable()->varId()},
{links[i]->getVariable()->range()},
getVar2FactorMsg (links[i]));
res.multiply (msg);
}
@ -362,7 +362,7 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
result.multiply (src->factor());
if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->params() << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExcept (dst->varId());

View File

@ -256,14 +256,13 @@ CFactorGraph::getCompressedFactorGraph (void)
myGroundVars.push_back (v);
}
FacNode* fn = new FacNode (Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->params()));
facClusters_[i]->getGroundFactors()[0]->factor().params()));
facClusters_[i]->setRepresentativeFactor (fn);
fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
}
}
fg->setIndexes();
return fg;
}

View File

@ -18,42 +18,16 @@ Factor::Factor (const Factor& g)
Factor::Factor (VarId vid, unsigned range)
{
args_.push_back (vid);
ranges_.push_back (range);
params_.resize (range, 1.0);
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (const Vars& vars)
{
int nrParams = 1;
for (unsigned i = 0; i < vars.size(); i++) {
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->range());
nrParams *= vars[i]->range();
}
double val = 1.0 / nrParams;
params_.resize (nrParams, val);
distId_ = Util::maxUnsigned();
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (
VarId vid,
unsigned range,
const Params& params)
const VarIds& vids,
const Ranges& ranges,
const Params& params,
unsigned distId)
{
args_.push_back (vid);
ranges_.push_back (range);
args_ = vids;
ranges_ = ranges;
params_ = params;
distId_ = Util::maxUnsigned();
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
}
@ -75,21 +49,6 @@ Factor::Factor (
Factor::Factor (
const VarIds& vids,
const Ranges& ranges,
const Params& params,
unsigned distId)
{
args_ = vids;
ranges_ = ranges;
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
}
void
Factor::sumOutAllExcept (VarId vid)
{

View File

@ -33,17 +33,14 @@ class TFactor
void setDistId (unsigned id) { distId_ = id; }
void normalize (void) { LogAware::normalize (params_); }
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_));
}
void normalize (void)
{
LogAware::normalize (params_);
}
int indexOf (const T& t) const
{
int idx = -1;
@ -258,16 +255,10 @@ class Factor : public TFactor<VarId>
Factor (const Factor&);
Factor (VarId, unsigned);
Factor (const Vars&);
Factor (VarId, unsigned, const Params&);
Factor (const Vars&, const Params&,
Factor (const VarIds&, const Ranges&, const Params&,
unsigned = Util::maxUnsigned());
Factor (const VarIds&, const Ranges&, const Params&,
Factor (const Vars&, const Params&,
unsigned = Util::maxUnsigned());
void sumOutAllExcept (VarId);

View File

@ -31,7 +31,6 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
}
}
setIndexes();
}
@ -105,7 +104,6 @@ FactorGraph::readFromUaiFormat (const char* fileName)
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
}
is.close();
setIndexes();
}
@ -165,7 +163,6 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
addFactor (Factor (vids, ranges, params));
}
is.close();
setIndexes();
}
@ -211,7 +208,7 @@ FactorGraph::addVarNode (VarNode* vn)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
varMap_.insert (make_pair (vn->varId(), vn));
}
@ -262,19 +259,6 @@ FactorGraph::getStructure (void)
void
FactorGraph::setIndexes (void)
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
varNodes_[i]->setIndex (i);
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->setIndex (i);
}
}
void
FactorGraph::print (void) const
{
@ -355,7 +339,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
out << endl;
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->params();
Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) {
Util::fromLog (params);
}

View File

@ -12,11 +12,11 @@ using namespace std;
class FacNode;
class VarNode : public Var
{
public:
VarNode (VarId varId, unsigned nrStates) : Var (varId, nrStates) { }
VarNode (VarId varId, unsigned nrStates)
: Var (varId, nrStates) { }
VarNode (const Var* v) : Var (v) { }
@ -48,15 +48,13 @@ class FacNode
void setIndex (int index) { index_ = index; }
const Params& params (void) const { return factor_.params(); }
string getLabel (void) { return factor_.getLabel(); }
private:
DISALLOW_COPY_AND_ASSIGN (FacNode);
Factor factor_;
VarNodes neighs_;
Factor factor_;
int index_;
};
@ -73,7 +71,7 @@ struct CompVarId
class FactorGraph
{
public:
FactorGraph (void) { }
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
FactorGraph (const FactorGraph&);
@ -82,15 +80,13 @@ class FactorGraph
const VarNodes& varNodes (void) const { return varNodes_; }
const FacNodes& facNodes (void) const { return facNodes_; }
void setFromBayesNetwork (void) { fromBayesNet_ = true; }
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
VarNode* getVarNode (VarId vid) const
{
IndexMap::const_iterator it = varMap_.find (vid);
return (it != varMap_.end()) ? varNodes_[it->second] : 0;
VarMap::const_iterator it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
void readFromUaiFormat (const char*);
@ -109,8 +105,6 @@ class FactorGraph
DAGraph& getStructure (void);
void setIndexes (void);
void print (void) const;
void exportToGraphViz (const char*) const;
@ -137,11 +131,11 @@ class FactorGraph
VarNodes varNodes_;
FacNodes facNodes_;
bool fromBayesNet_;
DAGraph structure_;
bool fromBayesNet_;
typedef unordered_map<unsigned, unsigned> IndexMap;
IndexMap varMap_;
typedef unordered_map<unsigned, VarNode*> VarMap;
VarMap varMap_;
};
#endif // HORUS_FACTORGRAPH_H

View File

@ -31,6 +31,13 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term);
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
vector<unsigned>
@ -212,8 +219,9 @@ void readLiftedEvidence (
int
createGroundNetwork (void)
{
FactorGraph* fg = new FactorGraph();;
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
bool fromBayesNet = factorsType == "bayes";
FactorGraph* fg = new FactorGraph (fromBayesNet);
YAP_Term factorList = YAP_ARG2;
while (factorList != YAP_TermNil()) {
YAP_Term factor = YAP_HeadOfTerm (factorList);
@ -228,10 +236,6 @@ createGroundNetwork (void)
fg->addFactor (Factor (varIds, ranges, params, distId));
factorList = YAP_TailOfTerm (factorList);
}
if (factorsType == "bayes") {
fg->setFromBayesNetwork();
}
fg->setIndexes();
YAP_Term evidenceList = YAP_ARG3;
while (evidenceList != YAP_TermNil()) {
@ -326,13 +330,6 @@ runLiftedSolver (void)
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
int
runGroundSolver (void)
{
@ -462,26 +459,25 @@ setParfactorsParams (void)
int
setBayesNetParams (void)
setFactorsParams (void)
{
// TODO FIXME
/*
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
return TRUE; // TODO
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
unordered_map<unsigned, Params> paramsMap;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList);
}
const BnNodeSet& nodes = bn->getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
assert (Util::contains (paramsMap, nodes[i]->distId()));
nodes[i]->setParams (paramsMap[nodes[i]->distId()]);
const FacNodes& facNodes = fg->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
assert (Util::contains (paramsMap, distId));
facNodes[i]->factor().setParams (paramsMap[distId]);
}
*/
return TRUE;
}
@ -628,7 +624,7 @@ init_predicates (void)
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);

View File

@ -9,7 +9,7 @@
[create_lifted_network/3,
create_ground_network/4,
set_parfactors_params/2,
set_bayes_net_params/2,
set_factors_params/2,
run_lifted_solver/3,
run_ground_solver/3,
set_vars_information/2,