use only solveQuery

This commit is contained in:
Tiago Gomes 2012-05-31 23:06:53 +01:00
parent f91e543d9d
commit b5369db214
9 changed files with 50 additions and 99 deletions

View File

@ -59,29 +59,17 @@ Params
CbpSolver::solveQuery (VarIds queryVids) CbpSolver::solveQuery (VarIds queryVids)
{ {
assert (queryVids.empty() == false); assert (queryVids.empty() == false);
return queryVids.size() == 1 Params res;
? getPosterioriOf (queryVids[0]) if (queryVids.size() == 1) {
: getJointDistributionOf (queryVids); res = solver_->getPosterioriOf (getRepresentative (queryVids[0]));
} } else {
VarIds representatives;
for (size_t i = 0; i < queryVids.size(); i++) {
representatives.push_back (getRepresentative (queryVids[i]));
Params }
CbpSolver::getPosterioriOf (VarId vid) res = solver_->getJointDistributionOf (representatives);
{
return solver_->getPosterioriOf (getRepresentative (vid));
}
Params
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{
VarIds representatives;
for (size_t i = 0; i < jointVids.size(); i++) {
representatives.push_back (getRepresentative (jointVids[i]));
} }
return solver_->getJointDistributionOf (representatives); return res;
} }

View File

@ -58,7 +58,6 @@ struct FacSignHash
}; };
class VarCluster class VarCluster
{ {
public: public:
@ -78,7 +77,6 @@ class VarCluster
}; };
class FacCluster class FacCluster
{ {
public: public:
@ -102,7 +100,6 @@ class FacCluster
}; };
class CbpSolver : public Solver class CbpSolver : public Solver
{ {
public: public:
@ -113,15 +110,10 @@ class CbpSolver : public Solver
void printSolverFlags (void) const; void printSolverFlags (void) const;
Params solveQuery (VarIds); Params solveQuery (VarIds);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
static bool checkForIdenticalFactors; static bool checkForIdenticalFactors;
private: private:
Color getNewColor (void) Color getNewColor (void)
{ {
++ freeColor_; ++ freeColor_;

View File

@ -630,16 +630,9 @@ GroundOperator::getAffectedFormulas (void)
Params Params
FoveSolver::getPosterioriOf (const Ground& query) FoveSolver::solveQuery (const Grounds& query)
{
return getJointDistributionOf ({query});
}
Params
FoveSolver::getJointDistributionOf (const Grounds& query)
{ {
assert (query.empty() == false);
runSolver (query); runSolver (query);
(*pfList_.begin())->normalize(); (*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params(); Params params = (*pfList_.begin())->params();

View File

@ -135,9 +135,7 @@ class FoveSolver
public: public:
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { } FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
Params getPosterioriOf (const Ground&); Params solveQuery (const Grounds&);
Params getJointDistributionOf (const Grounds&);
void printSolverFlags (void) const; void printSolverFlags (void) const;

View File

@ -317,22 +317,14 @@ runLiftedSolver (void)
solver.printSolverFlags(); solver.printSolverFlags();
cout << endl; cout << endl;
} }
if (queryVars.size() == 1) { results.push_back (solver.solveQuery (queryVars));
results.push_back (solver.getPosterioriOf (queryVars[0]));
} else {
results.push_back (solver.getJointDistributionOf (queryVars));
}
} else if (Globals::liftedSolver == LiftedSolvers::LBP) { } else if (Globals::liftedSolver == LiftedSolvers::LBP) {
LiftedBpSolver solver (pfListCopy); LiftedBpSolver solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) { if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags(); solver.printSolverFlags();
cout << endl; cout << endl;
} }
if (queryVars.size() == 1) { results.push_back (solver.solveQuery (queryVars));
results.push_back (solver.getPosterioriOf (queryVars[0]));
} else {
results.push_back (solver.getJointDistributionOf (queryVars));
}
} else { } else {
assert (false); assert (false);
} }

View File

@ -15,23 +15,21 @@ LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList)
Params Params
LiftedBpSolver::getPosterioriOf (const Ground& query) LiftedBpSolver::solveQuery (const Grounds& query)
{
vector<PrvGroup> groups = getQueryGroups ({query});
return solver_->getPosterioriOf (groups[0]);
}
Params
LiftedBpSolver::getJointDistributionOf (const Grounds& query)
{ {
assert (query.empty() == false);
Params res;
vector<PrvGroup> groups = getQueryGroups (query); vector<PrvGroup> groups = getQueryGroups (query);
VarIds queryVids; if (query.size() == 1) {
for (unsigned i = 0; i < groups.size(); i++) { res = solver_->getPosterioriOf (groups[0]);
queryVids.push_back (groups[i]); } else {
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
res = solver_->getJointDistributionOf (queryVids);
} }
return solver_->getJointDistributionOf (queryVids); return res;
} }

View File

@ -12,9 +12,7 @@ class LiftedBpSolver
public: public:
LiftedBpSolver (const ParfactorList& pfList); LiftedBpSolver (const ParfactorList& pfList);
Params getPosterioriOf (const Ground&); Params solveQuery (const Grounds&);
Params getJointDistributionOf (const Grounds&);
void printSolverFlags (void) const; void printSolverFlags (void) const;

View File

@ -1,14 +1,6 @@
#include "WeightedBpSolver.h" #include "WeightedBpSolver.h"
WeightedBpSolver::WeightedBpSolver (
const FactorGraph& fg, const vector<vector<unsigned>>& weights)
: BpSolver (fg), weights_(weights)
{
}
WeightedBpSolver::~WeightedBpSolver (void) WeightedBpSolver::~WeightedBpSolver (void)
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
@ -172,8 +164,8 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
Params msgProduct (msgSize, LogAware::multIdenty()); Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* cl = static_cast<const WeightedLink*> (links[i]); const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (cl->varNode() == dst && cl->index() == link->index())) { if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label(); cout << " message from " << links[i]->varNode()->label();
cout << ": " ; cout << ": " ;
@ -188,8 +180,8 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
} }
} else { } else {
for (size_t i = links.size(); i-- > 0; ) { for (size_t i = links.size(); i-- > 0; ) {
const WeightedLink* cl = static_cast<const WeightedLink*> (links[i]); const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
if ( ! (cl->varNode() == dst && cl->index() == link->index())) { if ( ! (l->varNode() == dst && l->index() == link->index())) {
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label(); cout << " message from " << links[i]->varNode()->label();
cout << ": " ; cout << ": " ;
@ -255,19 +247,18 @@ WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
const BpLinks& links = ninf(src)->getLinks(); const BpLinks& links = ninf(src)->getLinks();
if (Globals::logDomain) { if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
WeightedLink* cl = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
if ( ! (cl->facNode() == dst && cl->index() == link->index())) { if ( ! (l->facNode() == dst && l->index() == link->index())) {
WeightedLink* cl = static_cast<WeightedLink*> (links[i]); msg += l->powMessage();
msg += cl->powMessage();
} }
} }
} else { } else {
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {
WeightedLink* cl = static_cast<WeightedLink*> (links[i]); WeightedLink* l = static_cast<WeightedLink*> (links[i]);
if ( ! (cl->facNode() == dst && cl->index() == link->index())) { if ( ! (l->facNode() == dst && l->index() == link->index())) {
msg *= cl->powMessage(); msg *= l->powMessage();
if (Constants::SHOW_BP_CALCS) { if (Constants::SHOW_BP_CALCS) {
cout << " x " << cl->nextMessage() << "^" << link->weight(); cout << " x " << l->nextMessage() << "^" << link->weight();
} }
} }
} }
@ -284,14 +275,14 @@ void
WeightedBpSolver::printLinkInformation (void) const WeightedBpSolver::printLinkInformation (void) const
{ {
for (size_t i = 0; i < links_.size(); i++) { for (size_t i = 0; i < links_.size(); i++) {
WeightedLink* cl = static_cast<WeightedLink*> (links_[i]); WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
cout << cl->toString() << ":" << endl; cout << l->toString() << ":" << endl;
cout << " curr msg = " << cl->message() << endl; cout << " curr msg = " << l->message() << endl;
cout << " next msg = " << cl->nextMessage() << endl; cout << " next msg = " << l->nextMessage() << endl;
cout << " index = " << cl->index() << endl; cout << " pow msg = " << l->powMessage() << endl;
cout << " weight = " << cl->weight() << endl; cout << " index = " << l->index() << endl;
cout << " powered = " << cl->powMessage() << endl; cout << " weight = " << l->weight() << endl;
cout << " residual = " << cl->residual() << endl; cout << " residual = " << l->residual() << endl;
} }
} }

View File

@ -35,7 +35,8 @@ class WeightedBpSolver : public BpSolver
{ {
public: public:
WeightedBpSolver (const FactorGraph& fg, WeightedBpSolver (const FactorGraph& fg,
const vector<vector<unsigned>>&); const vector<vector<unsigned>>& weights)
: BpSolver (fg), weights_(weights) { }
~WeightedBpSolver (void); ~WeightedBpSolver (void);