drop Solver suffix and rename some files and classes

This commit is contained in:
Tiago Gomes 2012-06-12 16:29:57 +01:00
parent 919116763f
commit d4f63b3942
15 changed files with 169 additions and 170 deletions

View File

@ -5,21 +5,21 @@
#include <iostream> #include <iostream>
#include "BpSolver.h" #include "BeliefProp.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Factor.h" #include "Factor.h"
#include "Indexer.h" #include "Indexer.h"
#include "Horus.h" #include "Horus.h"
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg) BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg)
{ {
runned_ = false; runned_ = false;
} }
BpSolver::~BpSolver (void) BeliefProp::~BeliefProp (void)
{ {
for (size_t i = 0; i < varsI_.size(); i++) { for (size_t i = 0; i < varsI_.size(); i++) {
delete varsI_[i]; delete varsI_[i];
@ -35,7 +35,7 @@ BpSolver::~BpSolver (void)
Params Params
BpSolver::solveQuery (VarIds queryVids) BeliefProp::solveQuery (VarIds queryVids)
{ {
assert (queryVids.empty() == false); assert (queryVids.empty() == false);
return queryVids.size() == 1 return queryVids.size() == 1
@ -46,7 +46,7 @@ BpSolver::solveQuery (VarIds queryVids)
void void
BpSolver::printSolverFlags (void) const BeliefProp::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "belief propagation [" ; ss << "belief propagation [" ;
@ -68,7 +68,7 @@ BpSolver::printSolverFlags (void) const
Params Params
BpSolver::getPosterioriOf (VarId vid) BeliefProp::getPosterioriOf (VarId vid)
{ {
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
@ -101,7 +101,7 @@ BpSolver::getPosterioriOf (VarId vid)
Params Params
BpSolver::getJointDistributionOf (const VarIds& jointVarIds) BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
{ {
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
@ -140,7 +140,7 @@ BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
void void
BpSolver::runSolver (void) BeliefProp::runSolver (void)
{ {
initializeSolver(); initializeSolver();
nIters_ = 0; nIters_ = 0;
@ -187,7 +187,7 @@ BpSolver::runSolver (void)
void void
BpSolver::createLinks (void) BeliefProp::createLinks (void)
{ {
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
@ -201,7 +201,7 @@ BpSolver::createLinks (void)
void void
BpSolver::maxResidualSchedule (void) BeliefProp::maxResidualSchedule (void)
{ {
if (nIters_ == 1) { if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
@ -256,7 +256,7 @@ BpSolver::maxResidualSchedule (void)
void void
BpSolver::calcFactorToVarMsg (BpLink* link) BeliefProp::calcFactorToVarMsg (BpLink* link)
{ {
FacNode* src = link->facNode(); FacNode* src = link->facNode();
const VarNode* dst = link->varNode(); const VarNode* dst = link->varNode();
@ -320,7 +320,7 @@ BpSolver::calcFactorToVarMsg (BpLink* link)
Params Params
BpSolver::getVarToFactorMsg (const BpLink* link) const BeliefProp::getVarToFactorMsg (const BpLink* link) const
{ {
const VarNode* src = link->varNode(); const VarNode* src = link->varNode();
Params msg; Params msg;
@ -361,7 +361,7 @@ BpSolver::getVarToFactorMsg (const BpLink* link) const
Params Params
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{ {
VarNodes jointVars; VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) { for (size_t i = 0; i < jointVarIds.size(); i++) {
@ -370,7 +370,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
} }
FactorGraph* tempFg = new FactorGraph (fg); FactorGraph* tempFg = new FactorGraph (fg);
BpSolver solver (*tempFg); BeliefProp solver (*tempFg);
solver.runSolver(); solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]); Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
@ -390,7 +390,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
for (size_t j = 0; j < observedVars.size(); j++) { for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]); observedVars[j]->setEvidence (indexer[j]);
} }
BpSolver solver (*tempFg); BeliefProp solver (*tempFg);
solver.runSolver(); solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]); Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) { for (size_t k = 0; k < beliefs.size(); k++) {
@ -415,7 +415,7 @@ BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
void void
BpSolver::initializeSolver (void) BeliefProp::initializeSolver (void)
{ {
const VarNodes& varNodes = fg.varNodes(); const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size()); varsI_.reserve (varNodes.size());
@ -439,7 +439,7 @@ BpSolver::initializeSolver (void)
bool bool
BpSolver::converged (void) BeliefProp::converged (void)
{ {
if (links_.size() == 0) { if (links_.size() == 0) {
return true; return true;
@ -487,7 +487,7 @@ BpSolver::converged (void)
void void
BpSolver::printLinkInformation (void) const BeliefProp::printLinkInformation (void) const
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
BpLink* l = links_[i]; BpLink* l = links_[i];

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BPSOLVER_H #ifndef HORUS_BELIEFPROP_H
#define HORUS_BPSOLVER_H #define HORUS_BELIEFPROP_H
#include <set> #include <set>
#include <vector> #include <vector>
@ -83,12 +83,12 @@ class SPNodeInfo
}; };
class BpSolver : public Solver class BeliefProp : public Solver
{ {
public: public:
BpSolver (const FactorGraph&); BeliefProp (const FactorGraph&);
virtual ~BpSolver (void); virtual ~BeliefProp (void);
Params solveQuery (VarIds); Params solveQuery (VarIds);
@ -180,5 +180,5 @@ class BpSolver : public Solver
virtual void printLinkInformation (void) const; virtual void printLinkInformation (void) const;
}; };
#endif // HORUS_BPSOLVER_H #endif // HORUS_BELIEFPROP_H

View File

@ -1,23 +1,23 @@
#include "CbpSolver.h" #include "CountingBp.h"
#include "WeightedBpSolver.h" #include "WeightedBp.h"
bool CbpSolver::checkForIdenticalFactors = true; bool CountingBp::checkForIdenticalFactors = true;
CbpSolver::CbpSolver (const FactorGraph& fg) CountingBp::CountingBp (const FactorGraph& fg)
: Solver (fg), freeColor_(0) : Solver (fg), freeColor_(0)
{ {
findIdenticalFactors(); findIdenticalFactors();
setInitialColors(); setInitialColors();
createGroups(); createGroups();
compressedFg_ = getCompressedFactorGraph(); compressedFg_ = getCompressedFactorGraph();
solver_ = new WeightedBpSolver (*compressedFg_, getWeights()); solver_ = new WeightedBp (*compressedFg_, getWeights());
} }
CbpSolver::~CbpSolver (void) CountingBp::~CountingBp (void)
{ {
delete solver_; delete solver_;
delete compressedFg_; delete compressedFg_;
@ -32,7 +32,7 @@ CbpSolver::~CbpSolver (void)
void void
CbpSolver::printSolverFlags (void) const CountingBp::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "counting bp [" ; ss << "counting bp [" ;
@ -48,7 +48,7 @@ CbpSolver::printSolverFlags (void) const
ss << ",accuracy=" << BpOptions::accuracy; ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain); ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" << ss << ",chkif=" <<
Util::toString (CbpSolver::checkForIdenticalFactors); Util::toString (CountingBp::checkForIdenticalFactors);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;
} }
@ -56,7 +56,7 @@ CbpSolver::printSolverFlags (void) const
Params Params
CbpSolver::solveQuery (VarIds queryVids) CountingBp::solveQuery (VarIds queryVids)
{ {
assert (queryVids.empty() == false); assert (queryVids.empty() == false);
Params res; Params res;
@ -91,7 +91,7 @@ CbpSolver::solveQuery (VarIds queryVids)
void void
CbpSolver::findIdenticalFactors() CountingBp::findIdenticalFactors()
{ {
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
if (checkForIdenticalFactors == false || if (checkForIdenticalFactors == false ||
@ -126,7 +126,7 @@ CbpSolver::findIdenticalFactors()
void void
CbpSolver::setInitialColors (void) CountingBp::setInitialColors (void)
{ {
varColors_.resize (fg.nrVarNodes()); varColors_.resize (fg.nrVarNodes());
facColors_.resize (fg.nrFacNodes()); facColors_.resize (fg.nrFacNodes());
@ -165,7 +165,7 @@ CbpSolver::setInitialColors (void)
void void
CbpSolver::createGroups (void) CountingBp::createGroups (void)
{ {
VarSignMap varGroups; VarSignMap varGroups;
FacSignMap facGroups; FacSignMap facGroups;
@ -227,7 +227,7 @@ CbpSolver::createGroups (void)
void void
CbpSolver::createClusters ( CountingBp::createClusters (
const VarSignMap& varGroups, const VarSignMap& varGroups,
const FacSignMap& facGroups) const FacSignMap& facGroups)
{ {
@ -260,7 +260,7 @@ CbpSolver::createClusters (
VarSignature VarSignature
CbpSolver::getSignature (const VarNode* varNode) CountingBp::getSignature (const VarNode* varNode)
{ {
const FacNodes& neighs = varNode->neighbors(); const FacNodes& neighs = varNode->neighbors();
VarSignature sign; VarSignature sign;
@ -278,7 +278,7 @@ CbpSolver::getSignature (const VarNode* varNode)
FacSignature FacSignature
CbpSolver::getSignature (const FacNode* facNode) CountingBp::getSignature (const FacNode* facNode)
{ {
const VarNodes& neighs = facNode->neighbors(); const VarNodes& neighs = facNode->neighbors();
FacSignature sign; FacSignature sign;
@ -293,7 +293,7 @@ CbpSolver::getSignature (const FacNode* facNode)
FactorGraph* FactorGraph*
CbpSolver::getCompressedFactorGraph (void) CountingBp::getCompressedFactorGraph (void)
{ {
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) { for (size_t i = 0; i < varClusters_.size(); i++) {
@ -322,7 +322,7 @@ CbpSolver::getCompressedFactorGraph (void)
vector<vector<unsigned>> vector<vector<unsigned>>
CbpSolver::getWeights (void) const CountingBp::getWeights (void) const
{ {
vector<vector<unsigned>> weights; vector<vector<unsigned>> weights;
weights.reserve (facClusters_.size()); weights.reserve (facClusters_.size());
@ -341,7 +341,7 @@ CbpSolver::getWeights (void) const
unsigned unsigned
CbpSolver::getWeight ( CountingBp::getWeight (
const FacCluster* fc, const FacCluster* fc,
const VarCluster* vc, const VarCluster* vc,
size_t index) const size_t index) const
@ -364,7 +364,7 @@ CbpSolver::getWeight (
void void
CbpSolver::printGroups ( CountingBp::printGroups (
const VarSignMap& varGroups, const VarSignMap& varGroups,
const FacSignMap& facGroups) const const FacSignMap& facGroups) const
{ {

View File

@ -1,5 +1,5 @@
#ifndef HORUS_CBPSOLVER_H #ifndef HORUS_COUNTINGBP_H
#define HORUS_CBPSOLVER_H #define HORUS_COUNTINGBP_H
#include <unordered_map> #include <unordered_map>
@ -12,7 +12,7 @@ class VarCluster;
class FacCluster; class FacCluster;
class VarSignHash; class VarSignHash;
class FacSignHash; class FacSignHash;
class WeightedBpSolver; class WeightedBp;
typedef long Color; typedef long Color;
typedef vector<Color> Colors; typedef vector<Color> Colors;
@ -100,12 +100,12 @@ class FacCluster
}; };
class CbpSolver : public Solver class CountingBp : public Solver
{ {
public: public:
CbpSolver (const FactorGraph& fg); CountingBp (const FactorGraph& fg);
~CbpSolver (void); ~CountingBp (void);
void printSolverFlags (void) const; void printSolverFlags (void) const;
@ -176,8 +176,8 @@ class CbpSolver : public Solver
FacClusters facClusters_; FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_; VarId2VarCluster vid2VarCluster_;
const FactorGraph* compressedFg_; const FactorGraph* compressedFg_;
WeightedBpSolver* solver_; WeightedBp* solver_;
}; };
#endif // HORUS_CBPSOLVER_H #endif // HORUS_COUNTINGBP_H

View File

@ -4,9 +4,9 @@
#include <sstream> #include <sstream>
#include "FactorGraph.h" #include "FactorGraph.h"
#include "VarElimSolver.h" #include "VarElim.h"
#include "BpSolver.h" #include "BeliefProp.h"
#include "CbpSolver.h" #include "CountingBp.h"
using namespace std; using namespace std;
@ -163,13 +163,13 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
Solver* solver = 0; Solver* solver = 0;
switch (Globals::groundSolver) { switch (Globals::groundSolver) {
case GroundSolvers::VE: case GroundSolvers::VE:
solver = new VarElimSolver (fg); solver = new VarElim (fg);
break; break;
case GroundSolvers::BP: case GroundSolvers::BP:
solver = new BpSolver (fg); solver = new BeliefProp (fg);
break; break;
case GroundSolvers::CBP: case GroundSolvers::CBP:
solver = new CbpSolver (fg); solver = new CountingBp (fg);
break; break;
default: default:
assert (false); assert (false);

View File

@ -9,11 +9,11 @@
#include "ParfactorList.h" #include "ParfactorList.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "FoveSolver.h" #include "LiftedVe.h"
#include "VarElimSolver.h" #include "VarElim.h"
#include "LiftedBpSolver.h" #include "LiftedBp.h"
#include "BpSolver.h" #include "BeliefProp.h"
#include "CbpSolver.h" #include "CountingBp.h"
#include "ElimGraph.h" #include "ElimGraph.h"
#include "BayesBall.h" #include "BayesBall.h"
@ -35,7 +35,7 @@ Parfactor* readParfactor (YAP_Term);
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks, void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results); vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks, void runBeliefProp (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results); vector<Params>& results);
@ -285,7 +285,7 @@ runLiftedSolver (void)
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
vector<Params> results; vector<Params> results;
ParfactorList pfListCopy (*network->first); ParfactorList pfListCopy (*network->first);
FoveSolver::absorveEvidence (pfListCopy, *network->second); LiftedVe::absorveEvidence (pfListCopy, *network->second);
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
Grounds queryVars; Grounds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList); YAP_Term jointList = YAP_HeadOfTerm (taskList);
@ -312,14 +312,14 @@ runLiftedSolver (void)
jointList = YAP_TailOfTerm (jointList); jointList = YAP_TailOfTerm (jointList);
} }
if (Globals::liftedSolver == LiftedSolvers::FOVE) { if (Globals::liftedSolver == LiftedSolvers::FOVE) {
FoveSolver solver (pfListCopy); LiftedVe solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) { if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags(); solver.printSolverFlags();
cout << endl; cout << endl;
} }
results.push_back (solver.solveQuery (queryVars)); results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolvers::LBP) { } else if (Globals::liftedSolver == LiftedSolvers::LBP) {
LiftedBpSolver solver (pfListCopy); LiftedBp solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) { if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags(); solver.printSolverFlags();
cout << endl; cout << endl;
@ -365,7 +365,7 @@ runGroundSolver (void)
if (Globals::groundSolver == GroundSolvers::VE) { if (Globals::groundSolver == GroundSolvers::VE) {
runVeSolver (fg, tasks, results); runVeSolver (fg, tasks, results);
} else { } else {
runBpSolver (fg, tasks, results); runBeliefProp (fg, tasks, results);
} }
YAP_Term list = YAP_TermNil(); YAP_Term list = YAP_TermNil();
@ -397,8 +397,8 @@ void runVeSolver (
if (fg->bayesianFactors()) { if (fg->bayesianFactors()) {
// mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]); // mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
} }
// VarElimSolver solver (*mfg); // VarElim solver (*mfg);
VarElimSolver solver (*fg); //FIXME VarElim solver (*fg); //FIXME
if (Globals::verbosity > 0 && i == 0) { if (Globals::verbosity > 0 && i == 0) {
solver.printSolverFlags(); solver.printSolverFlags();
cout << endl; cout << endl;
@ -412,7 +412,7 @@ void runVeSolver (
void runBpSolver ( void runBeliefProp (
FactorGraph* fg, FactorGraph* fg,
const vector<VarIds>& tasks, const vector<VarIds>& tasks,
vector<Params>& results) vector<Params>& results)
@ -428,10 +428,10 @@ void runBpSolver (
// *fg, VarIds (vids.begin(),vids.end())); // *fg, VarIds (vids.begin(),vids.end()));
} }
if (Globals::groundSolver == GroundSolvers::BP) { if (Globals::groundSolver == GroundSolvers::BP) {
solver = new BpSolver (*fg); // FIXME solver = new BeliefProp (*fg); // FIXME
} else if (Globals::groundSolver == GroundSolvers::CBP) { } else if (Globals::groundSolver == GroundSolvers::CBP) {
CbpSolver::checkForIdenticalFactors = false; CountingBp::checkForIdenticalFactors = false;
solver = new CbpSolver (*fg); // FIXME solver = new CountingBp (*fg); // FIXME
} else { } else {
cerr << "error: unknow solver" << endl; cerr << "error: unknow solver" << endl;
abort(); abort();

View File

@ -1,20 +1,20 @@
#include "LiftedBpSolver.h" #include "LiftedBp.h"
#include "WeightedBpSolver.h" #include "WeightedBp.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "FoveSolver.h" #include "LiftedVe.h"
LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList) LiftedBp::LiftedBp (const ParfactorList& pfList)
: pfList_(pfList) : pfList_(pfList)
{ {
refineParfactors(); refineParfactors();
solver_ = new WeightedBpSolver (*getFactorGraph(), getWeights()); solver_ = new WeightedBp (*getFactorGraph(), getWeights());
} }
Params Params
LiftedBpSolver::solveQuery (const Grounds& query) LiftedBp::solveQuery (const Grounds& query)
{ {
assert (query.empty() == false); assert (query.empty() == false);
Params res; Params res;
@ -34,7 +34,7 @@ LiftedBpSolver::solveQuery (const Grounds& query)
void void
LiftedBpSolver::printSolverFlags (void) const LiftedBp::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "lifted bp [" ; ss << "lifted bp [" ;
@ -56,7 +56,7 @@ LiftedBpSolver::printSolverFlags (void) const
void void
LiftedBpSolver::refineParfactors (void) LiftedBp::refineParfactors (void)
{ {
while (iterate() == false); while (iterate() == false);
@ -69,7 +69,7 @@ LiftedBpSolver::refineParfactors (void)
bool bool
LiftedBpSolver::iterate (void) LiftedBp::iterate (void)
{ {
ParfactorList::iterator it = pfList_.begin(); ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) { while (it != pfList_.end()) {
@ -77,7 +77,7 @@ LiftedBpSolver::iterate (void)
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars(); LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) { if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = FoveSolver::countNormalize (*it, lvs); Parfactors pfs = LiftedVe::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it); it = pfList_.removeAndDelete (it);
pfList_.add (pfs); pfList_.add (pfs);
return false; return false;
@ -91,7 +91,7 @@ LiftedBpSolver::iterate (void)
vector<PrvGroup> vector<PrvGroup>
LiftedBpSolver::getQueryGroups (const Grounds& query) LiftedBp::getQueryGroups (const Grounds& query)
{ {
vector<PrvGroup> queryGroups; vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) { for (unsigned i = 0; i < query.size(); i++) {
@ -110,7 +110,7 @@ LiftedBpSolver::getQueryGroups (const Grounds& query)
FactorGraph* FactorGraph*
LiftedBpSolver::getFactorGraph (void) LiftedBp::getFactorGraph (void)
{ {
FactorGraph* fg = new FactorGraph(); FactorGraph* fg = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin(); ParfactorList::const_iterator it = pfList_.begin();
@ -128,7 +128,7 @@ LiftedBpSolver::getFactorGraph (void)
vector<vector<unsigned>> vector<vector<unsigned>>
LiftedBpSolver::getWeights (void) const LiftedBp::getWeights (void) const
{ {
vector<vector<unsigned>> weights; vector<vector<unsigned>> weights;
weights.reserve (pfList_.size()); weights.reserve (pfList_.size());

View File

@ -1,15 +1,15 @@
#ifndef HORUS_LIFTEDBPSOLVER_H #ifndef HORUS_LIFTEDBP_H
#define HORUS_LIFTEDBPSOLVER_H #define HORUS_LIFTEDBP_H
#include "ParfactorList.h" #include "ParfactorList.h"
class FactorGraph; class FactorGraph;
class WeightedBpSolver; class WeightedBp;
class LiftedBpSolver class LiftedBp
{ {
public: public:
LiftedBpSolver (const ParfactorList& pfList); LiftedBp (const ParfactorList& pfList);
Params solveQuery (const Grounds&); Params solveQuery (const Grounds&);
@ -27,8 +27,8 @@ class LiftedBpSolver
vector<vector<unsigned>> getWeights (void) const; vector<vector<unsigned>> getWeights (void) const;
ParfactorList pfList_; ParfactorList pfList_;
WeightedBpSolver* solver_; WeightedBp* solver_;
}; };
#endif // HORUS_LIFTEDBPSOLVER_H #endif // HORUS_LIFTEDBP_H

View File

@ -1,8 +1,7 @@
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include "FoveSolver.h" #include "LiftedVe.h"
#include "Histogram.h" #include "Histogram.h"
#include "Util.h" #include "Util.h"
@ -222,7 +221,7 @@ SumOutOperator::apply (void)
product->sumOutIndex (fIdx); product->sumOutIndex (fIdx);
pfList_.addShattered (product); pfList_.addShattered (product);
} else { } else {
Parfactors pfs = FoveSolver::countNormalize (product, excl); Parfactors pfs = LiftedVe::countNormalize (product, excl);
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
pfs[i]->sumOutIndex (fIdx); pfs[i]->sumOutIndex (fIdx);
pfList_.add (pfs[i]); pfList_.add (pfs[i]);
@ -376,7 +375,7 @@ CountingOperator::apply (void)
} else { } else {
Parfactor* pf = *pfIter_; Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_); pfList_.remove (pfIter_);
Parfactors pfs = FoveSolver::countNormalize (pf, X_); Parfactors pfs = LiftedVe::countNormalize (pf, X_);
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_); unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCartesianProduct ( bool cartProduct = pfs[i]->constr()->isCartesianProduct (
@ -420,7 +419,7 @@ CountingOperator::toString (void)
ss << "count convert " << X_ << " in " ; ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel(); ss << (*pfIter_)->getLabel();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_); Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getLabel() << endl; ss << " º " << pfs[i]->getLabel() << endl;
@ -630,7 +629,7 @@ GroundOperator::getAffectedFormulas (void)
Params Params
FoveSolver::solveQuery (const Grounds& query) LiftedVe::solveQuery (const Grounds& query)
{ {
assert (query.empty() == false); assert (query.empty() == false);
runSolver (query); runSolver (query);
@ -645,7 +644,7 @@ FoveSolver::solveQuery (const Grounds& query)
void void
FoveSolver::printSolverFlags (void) const LiftedVe::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "fove [" ; ss << "fove [" ;
@ -657,7 +656,7 @@ FoveSolver::printSolverFlags (void) const
void void
FoveSolver::absorveEvidence ( LiftedVe::absorveEvidence (
ParfactorList& pfList, ParfactorList& pfList,
ObservedFormulas& obsFormulas) ObservedFormulas& obsFormulas)
{ {
@ -696,7 +695,7 @@ FoveSolver::absorveEvidence (
Parfactors Parfactors
FoveSolver::countNormalize ( LiftedVe::countNormalize (
Parfactor* g, Parfactor* g,
const LogVarSet& set) const LogVarSet& set)
{ {
@ -715,7 +714,7 @@ FoveSolver::countNormalize (
Parfactor Parfactor
FoveSolver::calcGroundMultiplication (Parfactor pf) LiftedVe::calcGroundMultiplication (Parfactor pf)
{ {
LogVarSet lvs = pf.constr()->logVarSet(); LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons(); lvs -= pf.constr()->singletons();
@ -748,7 +747,7 @@ FoveSolver::calcGroundMultiplication (Parfactor pf)
void void
FoveSolver::runSolver (const Grounds& query) LiftedVe::runSolver (const Grounds& query)
{ {
largestCost_ = std::log (0); largestCost_ = std::log (0);
shatterAgainstQuery (query); shatterAgainstQuery (query);
@ -794,7 +793,7 @@ FoveSolver::runSolver (const Grounds& query)
LiftedOperator* LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query) LiftedVe::getBestOperation (const Grounds& query)
{ {
double bestCost = 0.0; double bestCost = 0.0;
LiftedOperator* bestOp = 0; LiftedOperator* bestOp = 0;
@ -821,7 +820,7 @@ FoveSolver::getBestOperation (const Grounds& query)
void void
FoveSolver::runWeakBayesBall (const Grounds& query) LiftedVe::runWeakBayesBall (const Grounds& query)
{ {
queue<PrvGroup> todo; // groups to process queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue set<PrvGroup> done; // processed or in queue
@ -880,7 +879,7 @@ FoveSolver::runWeakBayesBall (const Grounds& query)
void void
FoveSolver::shatterAgainstQuery (const Grounds& query) LiftedVe::shatterAgainstQuery (const Grounds& query)
{ {
for (size_t i = 0; i < query.size(); i++) { for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) { if (query[i].isAtom()) {
@ -931,7 +930,7 @@ FoveSolver::shatterAgainstQuery (const Grounds& query)
Parfactors Parfactors
FoveSolver::absorve ( LiftedVe::absorve (
ObservedFormula& obsFormula, ObservedFormula& obsFormula,
Parfactor* g) Parfactor* g)
{ {

View File

@ -1,5 +1,5 @@
#ifndef HORUS_FOVESOLVER_H #ifndef HORUS_LIFTEDVE_H
#define HORUS_FOVESOLVER_H #define HORUS_LIFTEDVE_H
#include "ParfactorList.h" #include "ParfactorList.h"
@ -130,10 +130,10 @@ class GroundOperator : public LiftedOperator
class FoveSolver class LiftedVe
{ {
public: public:
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { } LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { }
Params solveQuery (const Grounds&); Params solveQuery (const Grounds&);
@ -162,5 +162,5 @@ class FoveSolver
double largestCost_; double largestCost_;
}; };
#endif // HORUS_FOVESOLVER_H #endif // HORUS_LIFTEDVE_H

View File

@ -23,10 +23,10 @@ CC=@CC@
CXX=@CXX@ CXX=@CXX@
# normal # normal
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
# debug # debug
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra #CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
# #
@ -52,10 +52,10 @@ HEADERS = \
$(srcdir)/Factor.h \ $(srcdir)/Factor.h \
$(srcdir)/ConstraintTree.h \ $(srcdir)/ConstraintTree.h \
$(srcdir)/Solver.h \ $(srcdir)/Solver.h \
$(srcdir)/VarElimSolver.h \ $(srcdir)/VarElim.h \
$(srcdir)/BpSolver.h \ $(srcdir)/BeliefProp.h \
$(srcdir)/CbpSolver.h \ $(srcdir)/CountingBp.h \
$(srcdir)/FoveSolver.h \ $(srcdir)/LiftedVe.h \
$(srcdir)/Var.h \ $(srcdir)/Var.h \
$(srcdir)/Indexer.h \ $(srcdir)/Indexer.h \
$(srcdir)/Parfactor.h \ $(srcdir)/Parfactor.h \
@ -64,8 +64,8 @@ HEADERS = \
$(srcdir)/ParfactorList.h \ $(srcdir)/ParfactorList.h \
$(srcdir)/LiftedUtils.h \ $(srcdir)/LiftedUtils.h \
$(srcdir)/TinySet.h \ $(srcdir)/TinySet.h \
$(srcdir)/LiftedBpSolver.h \ $(srcdir)/LiftedBp.h \
$(srcdir)/WeightedBpSolver.h \ $(srcdir)/WeightedBp.h \
$(srcdir)/Util.h \ $(srcdir)/Util.h \
$(srcdir)/Horus.h $(srcdir)/Horus.h
@ -78,18 +78,18 @@ CPP_SOURCES = \
$(srcdir)/ConstraintTree.cpp \ $(srcdir)/ConstraintTree.cpp \
$(srcdir)/Var.cpp \ $(srcdir)/Var.cpp \
$(srcdir)/Solver.cpp \ $(srcdir)/Solver.cpp \
$(srcdir)/VarElimSolver.cpp \ $(srcdir)/VarElim.cpp \
$(srcdir)/BpSolver.cpp \ $(srcdir)/BeliefProp.cpp \
$(srcdir)/CbpSolver.cpp \ $(srcdir)/CountingBp.cpp \
$(srcdir)/FoveSolver.cpp \ $(srcdir)/LiftedVe.cpp \
$(srcdir)/Parfactor.cpp \ $(srcdir)/Parfactor.cpp \
$(srcdir)/ProbFormula.cpp \ $(srcdir)/ProbFormula.cpp \
$(srcdir)/Histogram.cpp \ $(srcdir)/Histogram.cpp \
$(srcdir)/ParfactorList.cpp \ $(srcdir)/ParfactorList.cpp \
$(srcdir)/LiftedUtils.cpp \ $(srcdir)/LiftedUtils.cpp \
$(srcdir)/Util.cpp \ $(srcdir)/Util.cpp \
$(srcdir)/LiftedBpSolver.cpp \ $(srcdir)/LiftedBp.cpp \
$(srcdir)/WeightedBpSolver.cpp \ $(srcdir)/WeightedBp.cpp \
$(srcdir)/HorusYap.cpp \ $(srcdir)/HorusYap.cpp \
$(srcdir)/HorusCli.cpp $(srcdir)/HorusCli.cpp
@ -102,18 +102,18 @@ OBJS = \
ConstraintTree.o \ ConstraintTree.o \
Var.o \ Var.o \
Solver.o \ Solver.o \
VarElimSolver.o \ VarElim.o \
BpSolver.o \ BeliefProp.o \
CbpSolver.o \ CountingBp.o \
FoveSolver.o \ LiftedVe.o \
Parfactor.o \ Parfactor.o \
ProbFormula.o \ ProbFormula.o \
Histogram.o \ Histogram.o \
ParfactorList.o \ ParfactorList.o \
LiftedUtils.o \ LiftedUtils.o \
Util.o \ Util.o \
LiftedBpSolver.o \ LiftedBp.o \
WeightedBpSolver.o \ WeightedBp.o \
HorusYap.o HorusYap.o
HCLI_OBJS = \ HCLI_OBJS = \
@ -125,15 +125,15 @@ HCLI_OBJS = \
ConstraintTree.o \ ConstraintTree.o \
Var.o \ Var.o \
Solver.o \ Solver.o \
VarElimSolver.o \ VarElim.o \
BpSolver.o \ BeliefProp.o \
CbpSolver.o \ CountingBp.o \
FoveSolver.o \ LiftedVe.o \
Parfactor.o \ Parfactor.o \
ProbFormula.o \ ProbFormula.o \
Histogram.o \ Histogram.o \
ParfactorList.o \ ParfactorList.o \
WeightedBpSolver.o \ WeightedBp.o \
LiftedUtils.o \ LiftedUtils.o \
Util.o \ Util.o \
HorusCli.o HorusCli.o

View File

@ -1,12 +1,12 @@
#include <algorithm> #include <algorithm>
#include "VarElimSolver.h" #include "VarElim.h"
#include "ElimGraph.h" #include "ElimGraph.h"
#include "Factor.h" #include "Factor.h"
#include "Util.h" #include "Util.h"
VarElimSolver::~VarElimSolver (void) VarElim::~VarElim (void)
{ {
delete factorList_.back(); delete factorList_.back();
} }
@ -14,7 +14,7 @@ VarElimSolver::~VarElimSolver (void)
Params Params
VarElimSolver::solveQuery (VarIds queryVids) VarElim::solveQuery (VarIds queryVids)
{ {
if (Globals::verbosity > 1) { if (Globals::verbosity > 1) {
cout << "Solving query on " ; cout << "Solving query on " ;
@ -41,7 +41,7 @@ VarElimSolver::solveQuery (VarIds queryVids)
void void
VarElimSolver::printSolverFlags (void) const VarElim::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "variable elimination [" ; ss << "variable elimination [" ;
@ -62,7 +62,7 @@ VarElimSolver::printSolverFlags (void) const
void void
VarElimSolver::createFactorList (void) VarElim::createFactorList (void)
{ {
const FacNodes& facNodes = fg.facNodes(); const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2); factorList_.reserve (facNodes.size() * 2);
@ -84,7 +84,7 @@ VarElimSolver::createFactorList (void)
void void
VarElimSolver::absorveEvidence (void) VarElim::absorveEvidence (void)
{ {
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
Util::printDashedLine(); Util::printDashedLine();
@ -117,7 +117,7 @@ VarElimSolver::absorveEvidence (void)
void void
VarElimSolver::findEliminationOrder (const VarIds& vids) VarElim::findEliminationOrder (const VarIds& vids)
{ {
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids); elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
} }
@ -125,7 +125,7 @@ VarElimSolver::findEliminationOrder (const VarIds& vids)
void void
VarElimSolver::processFactorList (const VarIds& vids) VarElim::processFactorList (const VarIds& vids)
{ {
totalFactorSize_ = 0; totalFactorSize_ = 0;
largestFactorSize_ = 0; largestFactorSize_ = 0;
@ -170,7 +170,7 @@ VarElimSolver::processFactorList (const VarIds& vids)
void void
VarElimSolver::eliminate (VarId elimVar) VarElim::eliminate (VarId elimVar)
{ {
Factor* result = 0; Factor* result = 0;
vector<size_t>& idxs = varFactors_.find (elimVar)->second; vector<size_t>& idxs = varFactors_.find (elimVar)->second;
@ -205,7 +205,7 @@ VarElimSolver::eliminate (VarId elimVar)
void void
VarElimSolver::printActiveFactors (void) VarElim::printActiveFactors (void)
{ {
for (size_t i = 0; i < factorList_.size(); i++) { for (size_t i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) { if (factorList_[i] != 0) {

View File

@ -1,5 +1,5 @@
#ifndef HORUS_VARELIMSOLVER_H #ifndef HORUS_VARELIM_H
#define HORUS_VARELIMSOLVER_H #define HORUS_VARELIM_H
#include "unordered_map" #include "unordered_map"
@ -11,12 +11,12 @@
using namespace std; using namespace std;
class VarElimSolver : public Solver class VarElim : public Solver
{ {
public: public:
VarElimSolver (const FactorGraph& fg) : Solver (fg) { } VarElim (const FactorGraph& fg) : Solver (fg) { }
~VarElimSolver (void); ~VarElim (void);
Params solveQuery (VarIds); Params solveQuery (VarIds);
@ -42,5 +42,5 @@ class VarElimSolver : public Solver
unordered_map<VarId, vector<size_t>> varFactors_; unordered_map<VarId, vector<size_t>> varFactors_;
}; };
#endif // HORUS_VARELIMSOLVER_H #endif // HORUS_VARELIM_H

View File

@ -1,7 +1,7 @@
#include "WeightedBpSolver.h" #include "WeightedBp.h"
WeightedBpSolver::~WeightedBpSolver (void) WeightedBp::~WeightedBp (void)
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
delete links_[i]; delete links_[i];
@ -12,7 +12,7 @@ WeightedBpSolver::~WeightedBpSolver (void)
Params Params
WeightedBpSolver::getPosterioriOf (VarId vid) WeightedBp::getPosterioriOf (VarId vid)
{ {
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
@ -47,7 +47,7 @@ WeightedBpSolver::getPosterioriOf (VarId vid)
void void
WeightedBpSolver::createLinks (void) WeightedBp::createLinks (void)
{ {
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
cout << "compressed factor graph contains " ; cout << "compressed factor graph contains " ;
@ -78,7 +78,7 @@ WeightedBpSolver::createLinks (void)
void void
WeightedBpSolver::maxResidualSchedule (void) WeightedBp::maxResidualSchedule (void)
{ {
if (nIters_ == 1) { if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
@ -151,7 +151,7 @@ WeightedBpSolver::maxResidualSchedule (void)
void void
WeightedBpSolver::calcFactorToVarMsg (BpLink* _link) WeightedBp::calcFactorToVarMsg (BpLink* _link)
{ {
WeightedLink* link = static_cast<WeightedLink*> (_link); WeightedLink* link = static_cast<WeightedLink*> (_link);
FacNode* src = link->facNode(); FacNode* src = link->facNode();
@ -223,7 +223,7 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
Params Params
WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const WeightedBp::getVarToFactorMsg (const BpLink* _link) const
{ {
const WeightedLink* link = static_cast<const WeightedLink*> (_link); const WeightedLink* link = static_cast<const WeightedLink*> (_link);
const VarNode* src = link->varNode(); const VarNode* src = link->varNode();
@ -272,7 +272,7 @@ WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
void void
WeightedBpSolver::printLinkInformation (void) const WeightedBp::printLinkInformation (void) const
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
WeightedLink* l = static_cast<WeightedLink*> (links_[i]); WeightedLink* l = static_cast<WeightedLink*> (links_[i]);

View File

@ -1,7 +1,7 @@
#ifndef HORUS_WEIGHTEDBPSOLVER_H #ifndef HORUS_WEIGHTEDBP_H
#define HORUS_WEIGHTEDBPSOLVER_H #define HORUS_WEIGHTEDBP_H
#include "BpSolver.h" #include "BeliefProp.h"
class WeightedLink : public BpLink class WeightedLink : public BpLink
{ {
@ -31,14 +31,14 @@ class WeightedLink : public BpLink
class WeightedBpSolver : public BpSolver class WeightedBp : public BeliefProp
{ {
public: public:
WeightedBpSolver (const FactorGraph& fg, WeightedBp (const FactorGraph& fg,
const vector<vector<unsigned>>& weights) const vector<vector<unsigned>>& weights)
: BpSolver (fg), weights_(weights) { } : BeliefProp (fg), weights_(weights) { }
~WeightedBpSolver (void); ~WeightedBp (void);
Params getPosterioriOf (VarId); Params getPosterioriOf (VarId);
@ -57,5 +57,5 @@ class WeightedBpSolver : public BpSolver
vector<vector<unsigned>> weights_; vector<vector<unsigned>> weights_;
}; };
#endif // HORUS_WEIGHTEDBPSOLVER_H #endif // HORUS_WEIGHTEDBP_H