use only solveQuery
This commit is contained in:
parent
f91e543d9d
commit
b5369db214
@ -59,29 +59,17 @@ Params
|
|||||||
CbpSolver::solveQuery (VarIds queryVids)
|
CbpSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
assert (queryVids.empty() == false);
|
assert (queryVids.empty() == false);
|
||||||
return queryVids.size() == 1
|
Params res;
|
||||||
? getPosterioriOf (queryVids[0])
|
if (queryVids.size() == 1) {
|
||||||
: getJointDistributionOf (queryVids);
|
res = solver_->getPosterioriOf (getRepresentative (queryVids[0]));
|
||||||
}
|
} else {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
CbpSolver::getPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
return solver_->getPosterioriOf (getRepresentative (vid));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
|
||||||
{
|
|
||||||
VarIds representatives;
|
VarIds representatives;
|
||||||
for (size_t i = 0; i < jointVids.size(); i++) {
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
representatives.push_back (getRepresentative (jointVids[i]));
|
representatives.push_back (getRepresentative (queryVids[i]));
|
||||||
}
|
}
|
||||||
return solver_->getJointDistributionOf (representatives);
|
res = solver_->getJointDistributionOf (representatives);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,7 +58,6 @@ struct FacSignHash
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VarCluster
|
class VarCluster
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -78,7 +77,6 @@ class VarCluster
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FacCluster
|
class FacCluster
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -102,7 +100,6 @@ class FacCluster
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CbpSolver : public Solver
|
class CbpSolver : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -114,14 +111,9 @@ class CbpSolver : public Solver
|
|||||||
|
|
||||||
Params solveQuery (VarIds);
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
|
||||||
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
static bool checkForIdenticalFactors;
|
static bool checkForIdenticalFactors;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
Color getNewColor (void)
|
Color getNewColor (void)
|
||||||
{
|
{
|
||||||
++ freeColor_;
|
++ freeColor_;
|
||||||
|
@ -630,16 +630,9 @@ GroundOperator::getAffectedFormulas (void)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FoveSolver::getPosterioriOf (const Ground& query)
|
FoveSolver::solveQuery (const Grounds& query)
|
||||||
{
|
|
||||||
return getJointDistributionOf ({query});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
FoveSolver::getJointDistributionOf (const Grounds& query)
|
|
||||||
{
|
{
|
||||||
|
assert (query.empty() == false);
|
||||||
runSolver (query);
|
runSolver (query);
|
||||||
(*pfList_.begin())->normalize();
|
(*pfList_.begin())->normalize();
|
||||||
Params params = (*pfList_.begin())->params();
|
Params params = (*pfList_.begin())->params();
|
||||||
|
@ -135,9 +135,7 @@ class FoveSolver
|
|||||||
public:
|
public:
|
||||||
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||||
|
|
||||||
Params getPosterioriOf (const Ground&);
|
Params solveQuery (const Grounds&);
|
||||||
|
|
||||||
Params getJointDistributionOf (const Grounds&);
|
|
||||||
|
|
||||||
void printSolverFlags (void) const;
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
@ -317,22 +317,14 @@ runLiftedSolver (void)
|
|||||||
solver.printSolverFlags();
|
solver.printSolverFlags();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
if (queryVars.size() == 1) {
|
results.push_back (solver.solveQuery (queryVars));
|
||||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
|
||||||
} else {
|
|
||||||
results.push_back (solver.getJointDistributionOf (queryVars));
|
|
||||||
}
|
|
||||||
} else if (Globals::liftedSolver == LiftedSolvers::LBP) {
|
} else if (Globals::liftedSolver == LiftedSolvers::LBP) {
|
||||||
LiftedBpSolver solver (pfListCopy);
|
LiftedBpSolver solver (pfListCopy);
|
||||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||||
solver.printSolverFlags();
|
solver.printSolverFlags();
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
if (queryVars.size() == 1) {
|
results.push_back (solver.solveQuery (queryVars));
|
||||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
|
||||||
} else {
|
|
||||||
results.push_back (solver.getJointDistributionOf (queryVars));
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
|
@ -15,23 +15,21 @@ LiftedBpSolver::LiftedBpSolver (const ParfactorList& pfList)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
LiftedBpSolver::getPosterioriOf (const Ground& query)
|
LiftedBpSolver::solveQuery (const Grounds& query)
|
||||||
{
|
|
||||||
vector<PrvGroup> groups = getQueryGroups ({query});
|
|
||||||
return solver_->getPosterioriOf (groups[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
LiftedBpSolver::getJointDistributionOf (const Grounds& query)
|
|
||||||
{
|
{
|
||||||
|
assert (query.empty() == false);
|
||||||
|
Params res;
|
||||||
vector<PrvGroup> groups = getQueryGroups (query);
|
vector<PrvGroup> groups = getQueryGroups (query);
|
||||||
|
if (query.size() == 1) {
|
||||||
|
res = solver_->getPosterioriOf (groups[0]);
|
||||||
|
} else {
|
||||||
VarIds queryVids;
|
VarIds queryVids;
|
||||||
for (unsigned i = 0; i < groups.size(); i++) {
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
queryVids.push_back (groups[i]);
|
queryVids.push_back (groups[i]);
|
||||||
}
|
}
|
||||||
return solver_->getJointDistributionOf (queryVids);
|
res = solver_->getJointDistributionOf (queryVids);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,9 +12,7 @@ class LiftedBpSolver
|
|||||||
public:
|
public:
|
||||||
LiftedBpSolver (const ParfactorList& pfList);
|
LiftedBpSolver (const ParfactorList& pfList);
|
||||||
|
|
||||||
Params getPosterioriOf (const Ground&);
|
Params solveQuery (const Grounds&);
|
||||||
|
|
||||||
Params getJointDistributionOf (const Grounds&);
|
|
||||||
|
|
||||||
void printSolverFlags (void) const;
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
@ -1,14 +1,6 @@
|
|||||||
#include "WeightedBpSolver.h"
|
#include "WeightedBpSolver.h"
|
||||||
|
|
||||||
|
|
||||||
WeightedBpSolver::WeightedBpSolver (
|
|
||||||
const FactorGraph& fg, const vector<vector<unsigned>>& weights)
|
|
||||||
: BpSolver (fg), weights_(weights)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
WeightedBpSolver::~WeightedBpSolver (void)
|
WeightedBpSolver::~WeightedBpSolver (void)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
@ -172,8 +164,8 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
|
|||||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
const WeightedLink* cl = static_cast<const WeightedLink*> (links[i]);
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
if ( ! (cl->varNode() == dst && cl->index() == link->index())) {
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message from " << links[i]->varNode()->label();
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
@ -188,8 +180,8 @@ WeightedBpSolver::calcFactorToVarMsg (BpLink* _link)
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = links.size(); i-- > 0; ) {
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
const WeightedLink* cl = static_cast<const WeightedLink*> (links[i]);
|
const WeightedLink* l = static_cast<const WeightedLink*> (links[i]);
|
||||||
if ( ! (cl->varNode() == dst && cl->index() == link->index())) {
|
if ( ! (l->varNode() == dst && l->index() == link->index())) {
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " message from " << links[i]->varNode()->label();
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
cout << ": " ;
|
cout << ": " ;
|
||||||
@ -255,19 +247,18 @@ WeightedBpSolver::getVarToFactorMsg (const BpLink* _link) const
|
|||||||
const BpLinks& links = ninf(src)->getLinks();
|
const BpLinks& links = ninf(src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
WeightedLink* cl = static_cast<WeightedLink*> (links[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
if ( ! (cl->facNode() == dst && cl->index() == link->index())) {
|
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||||
WeightedLink* cl = static_cast<WeightedLink*> (links[i]);
|
msg += l->powMessage();
|
||||||
msg += cl->powMessage();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
WeightedLink* cl = static_cast<WeightedLink*> (links[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links[i]);
|
||||||
if ( ! (cl->facNode() == dst && cl->index() == link->index())) {
|
if ( ! (l->facNode() == dst && l->index() == link->index())) {
|
||||||
msg *= cl->powMessage();
|
msg *= l->powMessage();
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << cl->nextMessage() << "^" << link->weight();
|
cout << " x " << l->nextMessage() << "^" << link->weight();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -284,14 +275,14 @@ void
|
|||||||
WeightedBpSolver::printLinkInformation (void) const
|
WeightedBpSolver::printLinkInformation (void) const
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
WeightedLink* cl = static_cast<WeightedLink*> (links_[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
|
||||||
cout << cl->toString() << ":" << endl;
|
cout << l->toString() << ":" << endl;
|
||||||
cout << " curr msg = " << cl->message() << endl;
|
cout << " curr msg = " << l->message() << endl;
|
||||||
cout << " next msg = " << cl->nextMessage() << endl;
|
cout << " next msg = " << l->nextMessage() << endl;
|
||||||
cout << " index = " << cl->index() << endl;
|
cout << " pow msg = " << l->powMessage() << endl;
|
||||||
cout << " weight = " << cl->weight() << endl;
|
cout << " index = " << l->index() << endl;
|
||||||
cout << " powered = " << cl->powMessage() << endl;
|
cout << " weight = " << l->weight() << endl;
|
||||||
cout << " residual = " << cl->residual() << endl;
|
cout << " residual = " << l->residual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,7 +35,8 @@ class WeightedBpSolver : public BpSolver
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
WeightedBpSolver (const FactorGraph& fg,
|
WeightedBpSolver (const FactorGraph& fg,
|
||||||
const vector<vector<unsigned>>&);
|
const vector<vector<unsigned>>& weights)
|
||||||
|
: BpSolver (fg), weights_(weights) { }
|
||||||
|
|
||||||
~WeightedBpSolver (void);
|
~WeightedBpSolver (void);
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user