refactorings
This commit is contained in:
parent
aa1b2e40ea
commit
b52dc99914
@ -50,7 +50,7 @@
|
||||
|
||||
:- use_module(horus,
|
||||
[create_ground_network/4,
|
||||
set_bayes_net_params/2,
|
||||
set_factors_params/2,
|
||||
run_ground_solver/3,
|
||||
set_vars_information/2,
|
||||
free_ground_network/1
|
||||
@ -80,7 +80,7 @@ call_bp_ground(QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
||||
|
||||
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||
%get_dists_parameters(DistIds, DistsParams),
|
||||
%set_bayes_net_params(Network, DistsParams),
|
||||
%set_factors_params(Network, DistsParams),
|
||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||
writeln(queryKeys:QueryKeys), writeln(''),
|
||||
writeln(queryIds:QueryIds), writeln(''),
|
||||
@ -154,7 +154,7 @@ init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||
|
||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
set_bayes_net_params(Network, DistsParams),
|
||||
set_factors_params(Network, DistsParams),
|
||||
vars_to_ids(QueryVars, QueryVarsIds),
|
||||
run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||
|
||||
|
@ -52,7 +52,6 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
constructGraph (fg);
|
||||
return fg;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -112,8 +112,8 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
Factor res (facNodes[idx]->factor());
|
||||
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Factor msg (links[i]->getVariable()->varId(),
|
||||
links[i]->getVariable()->range(),
|
||||
Factor msg ({links[i]->getVariable()->varId()},
|
||||
{links[i]->getVariable()->range()},
|
||||
getVar2FactorMsg (links[i]));
|
||||
res.multiply (msg);
|
||||
}
|
||||
@ -362,7 +362,7 @@ BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||
result.multiply (src->factor());
|
||||
if (Constants::DEBUG >= 5) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->params() << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
|
@ -256,14 +256,13 @@ CFactorGraph::getCompressedFactorGraph (void)
|
||||
myGroundVars.push_back (v);
|
||||
}
|
||||
FacNode* fn = new FacNode (Factor (myGroundVars,
|
||||
facClusters_[i]->getGroundFactors()[0]->params()));
|
||||
facClusters_[i]->getGroundFactors()[0]->factor().params()));
|
||||
facClusters_[i]->setRepresentativeFactor (fn);
|
||||
fg->addFacNode (fn);
|
||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
|
||||
}
|
||||
}
|
||||
fg->setIndexes();
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
@ -18,42 +18,16 @@ Factor::Factor (const Factor& g)
|
||||
|
||||
|
||||
|
||||
Factor::Factor (VarId vid, unsigned range)
|
||||
{
|
||||
args_.push_back (vid);
|
||||
ranges_.push_back (range);
|
||||
params_.resize (range, 1.0);
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const Vars& vars)
|
||||
{
|
||||
int nrParams = 1;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
args_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->range());
|
||||
nrParams *= vars[i]->range();
|
||||
}
|
||||
double val = 1.0 / nrParams;
|
||||
params_.resize (nrParams, val);
|
||||
distId_ = Util::maxUnsigned();
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (
|
||||
VarId vid,
|
||||
unsigned range,
|
||||
const Params& params)
|
||||
const VarIds& vids,
|
||||
const Ranges& ranges,
|
||||
const Params& params,
|
||||
unsigned distId)
|
||||
{
|
||||
args_.push_back (vid);
|
||||
ranges_.push_back (range);
|
||||
args_ = vids;
|
||||
ranges_ = ranges;
|
||||
params_ = params;
|
||||
distId_ = Util::maxUnsigned();
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
@ -75,21 +49,6 @@ Factor::Factor (
|
||||
|
||||
|
||||
|
||||
Factor::Factor (
|
||||
const VarIds& vids,
|
||||
const Ranges& ranges,
|
||||
const Params& params,
|
||||
unsigned distId)
|
||||
{
|
||||
args_ = vids;
|
||||
ranges_ = ranges;
|
||||
params_ = params;
|
||||
distId_ = distId;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutAllExcept (VarId vid)
|
||||
{
|
||||
|
@ -33,17 +33,14 @@ class TFactor
|
||||
|
||||
void setDistId (unsigned id) { distId_ = id; }
|
||||
|
||||
void normalize (void) { LogAware::normalize (params_); }
|
||||
|
||||
void setParams (const Params& newParams)
|
||||
{
|
||||
params_ = newParams;
|
||||
assert (params_.size() == Util::expectedSize (ranges_));
|
||||
}
|
||||
|
||||
void normalize (void)
|
||||
{
|
||||
LogAware::normalize (params_);
|
||||
}
|
||||
|
||||
int indexOf (const T& t) const
|
||||
{
|
||||
int idx = -1;
|
||||
@ -258,16 +255,10 @@ class Factor : public TFactor<VarId>
|
||||
|
||||
Factor (const Factor&);
|
||||
|
||||
Factor (VarId, unsigned);
|
||||
|
||||
Factor (const Vars&);
|
||||
|
||||
Factor (VarId, unsigned, const Params&);
|
||||
|
||||
Factor (const Vars&, const Params&,
|
||||
Factor (const VarIds&, const Ranges&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
|
||||
Factor (const VarIds&, const Ranges&, const Params&,
|
||||
Factor (const Vars&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
|
||||
void sumOutAllExcept (VarId);
|
||||
|
@ -31,7 +31,6 @@ FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||
}
|
||||
}
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@ -105,7 +104,6 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@ -165,7 +163,6 @@ FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
addFactor (Factor (vids, ranges, params));
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@ -211,7 +208,7 @@ FactorGraph::addVarNode (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||
varMap_.insert (make_pair (vn->varId(), vn));
|
||||
}
|
||||
|
||||
|
||||
@ -262,19 +259,6 @@ FactorGraph::getStructure (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::setIndexes (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
varNodes_[i]->setIndex (i);
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->setIndex (i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::print (void) const
|
||||
{
|
||||
@ -355,7 +339,7 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
out << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
Params params = facNodes_[i]->params();
|
||||
Params params = facNodes_[i]->factor().params();
|
||||
if (Globals::logDomain) {
|
||||
Util::fromLog (params);
|
||||
}
|
||||
|
@ -12,11 +12,11 @@ using namespace std;
|
||||
|
||||
class FacNode;
|
||||
|
||||
|
||||
class VarNode : public Var
|
||||
{
|
||||
public:
|
||||
VarNode (VarId varId, unsigned nrStates) : Var (varId, nrStates) { }
|
||||
VarNode (VarId varId, unsigned nrStates)
|
||||
: Var (varId, nrStates) { }
|
||||
|
||||
VarNode (const Var* v) : Var (v) { }
|
||||
|
||||
@ -48,15 +48,13 @@ class FacNode
|
||||
|
||||
void setIndex (int index) { index_ = index; }
|
||||
|
||||
const Params& params (void) const { return factor_.params(); }
|
||||
|
||||
string getLabel (void) { return factor_.getLabel(); }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FacNode);
|
||||
|
||||
Factor factor_;
|
||||
VarNodes neighs_;
|
||||
Factor factor_;
|
||||
int index_;
|
||||
};
|
||||
|
||||
@ -73,7 +71,7 @@ struct CompVarId
|
||||
class FactorGraph
|
||||
{
|
||||
public:
|
||||
FactorGraph (void) { }
|
||||
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
|
||||
|
||||
FactorGraph (const FactorGraph&);
|
||||
|
||||
@ -82,15 +80,13 @@ class FactorGraph
|
||||
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||
|
||||
const FacNodes& facNodes (void) const { return facNodes_; }
|
||||
|
||||
void setFromBayesNetwork (void) { fromBayesNet_ = true; }
|
||||
|
||||
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
||||
|
||||
VarNode* getVarNode (VarId vid) const
|
||||
{
|
||||
IndexMap::const_iterator it = varMap_.find (vid);
|
||||
return (it != varMap_.end()) ? varNodes_[it->second] : 0;
|
||||
VarMap::const_iterator it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
@ -109,8 +105,6 @@ class FactorGraph
|
||||
|
||||
DAGraph& getStructure (void);
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void print (void) const;
|
||||
|
||||
void exportToGraphViz (const char*) const;
|
||||
@ -137,11 +131,11 @@ class FactorGraph
|
||||
VarNodes varNodes_;
|
||||
FacNodes facNodes_;
|
||||
|
||||
bool fromBayesNet_;
|
||||
DAGraph structure_;
|
||||
bool fromBayesNet_;
|
||||
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
IndexMap varMap_;
|
||||
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||
VarMap varMap_;
|
||||
};
|
||||
|
||||
#endif // HORUS_FACTORGRAPH_H
|
||||
|
@ -31,6 +31,13 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
|
||||
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
@ -212,8 +219,9 @@ void readLiftedEvidence (
|
||||
int
|
||||
createGroundNetwork (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();;
|
||||
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
bool fromBayesNet = factorsType == "bayes";
|
||||
FactorGraph* fg = new FactorGraph (fromBayesNet);
|
||||
YAP_Term factorList = YAP_ARG2;
|
||||
while (factorList != YAP_TermNil()) {
|
||||
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||
@ -228,10 +236,6 @@ createGroundNetwork (void)
|
||||
fg->addFactor (Factor (varIds, ranges, params, distId));
|
||||
factorList = YAP_TailOfTerm (factorList);
|
||||
}
|
||||
if (factorsType == "bayes") {
|
||||
fg->setFromBayesNetwork();
|
||||
}
|
||||
fg->setIndexes();
|
||||
|
||||
YAP_Term evidenceList = YAP_ARG3;
|
||||
while (evidenceList != YAP_TermNil()) {
|
||||
@ -326,13 +330,6 @@ runLiftedSolver (void)
|
||||
|
||||
|
||||
|
||||
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||
vector<Params>& results);
|
||||
|
||||
|
||||
|
||||
int
|
||||
runGroundSolver (void)
|
||||
{
|
||||
@ -462,26 +459,25 @@ setParfactorsParams (void)
|
||||
|
||||
|
||||
int
|
||||
setBayesNetParams (void)
|
||||
setFactorsParams (void)
|
||||
{
|
||||
// TODO FIXME
|
||||
/*
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
return TRUE; // TODO
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distList = YAP_ARG2;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
while (distList != YAP_TermNil()) {
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||
assert (Util::contains (paramsMap, distId) == false);
|
||||
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
assert (Util::contains (paramsMap, nodes[i]->distId()));
|
||||
nodes[i]->setParams (paramsMap[nodes[i]->distId()]);
|
||||
const FacNodes& facNodes = fg->facNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned distId = facNodes[i]->factor().distId();
|
||||
assert (Util::contains (paramsMap, distId));
|
||||
facNodes[i]->factor().setParams (paramsMap[distId]);
|
||||
}
|
||||
*/
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
@ -628,7 +624,7 @@ init_predicates (void)
|
||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
||||
YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
|
||||
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
|
||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||
|
@ -9,7 +9,7 @@
|
||||
[create_lifted_network/3,
|
||||
create_ground_network/4,
|
||||
set_parfactors_params/2,
|
||||
set_bayes_net_params/2,
|
||||
set_factors_params/2,
|
||||
run_lifted_solver/3,
|
||||
run_ground_solver/3,
|
||||
set_vars_information/2,
|
||||
|
Reference in New Issue
Block a user