Merge branch 'master' of ssh://yap.git.sourceforge.net/gitroot/yap/yap-6.3
This commit is contained in:
commit
2d373a28a6
@ -118,20 +118,19 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
if (idx == facNodes.size()) {
|
if (idx == facNodes.size()) {
|
||||||
return getJointByConditioning (jointVarIds);
|
return getJointByConditioning (jointVarIds);
|
||||||
}
|
}
|
||||||
return getFactorJoint (idx, jointVarIds);
|
return getFactorJoint (facNodes[idx], jointVarIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BeliefProp::getFactorJoint (
|
BeliefProp::getFactorJoint (
|
||||||
size_t fnIdx,
|
FacNode* fn,
|
||||||
const VarIds& jointVarIds)
|
const VarIds& jointVarIds)
|
||||||
{
|
{
|
||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
FacNode* fn = fg.facNodes()[fnIdx];
|
|
||||||
Factor res (fn->factor());
|
Factor res (fn->factor());
|
||||||
const BpLinks& links = ninf(fn)->getLinks();
|
const BpLinks& links = ninf(fn)->getLinks();
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
@ -350,20 +349,22 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
|||||||
const BpLinks& links = ninf (src)->getLinks();
|
const BpLinks& links = ninf (src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (it = links.begin(); it != links.end(); ++it) {
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
msg += (*it)->message();
|
if (*it != link) {
|
||||||
|
msg += (*it)->message();
|
||||||
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << (*it)->message();
|
cout << " x " << (*it)->message();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg -= link->message();
|
|
||||||
} else {
|
} else {
|
||||||
for (it = links.begin(); it != links.end(); ++it) {
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
msg *= (*it)->message();
|
if (*it != link) {
|
||||||
|
msg *= (*it)->message();
|
||||||
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " x " << (*it)->message();
|
cout << " x " << (*it)->message();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg /= link->message();
|
|
||||||
}
|
}
|
||||||
if (Constants::SHOW_BP_CALCS) {
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
cout << " = " << msg;
|
cout << " = " << msg;
|
||||||
|
@ -112,7 +112,7 @@ class BeliefProp : public Solver
|
|||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Params getFactorJoint (size_t fnIdx, const VarIds&);
|
Params getFactorJoint (FacNode* fn, const VarIds&);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
SPNodeInfo* ninf (const VarNode* var) const
|
SPNodeInfo* ninf (const VarNode* var) const
|
||||||
|
@ -81,7 +81,9 @@ CountingBp::solveQuery (VarIds queryVids)
|
|||||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
reprArgs.push_back (getRepresentative (queryVids[i]));
|
reprArgs.push_back (getRepresentative (queryVids[i]));
|
||||||
}
|
}
|
||||||
res = solver_->getFactorJoint (idx, reprArgs);
|
FacNode* reprFac = getRepresentative (facNodes[idx]);
|
||||||
|
assert (reprFac != 0);
|
||||||
|
res = solver_->getFactorJoint (reprFac, reprArgs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
|
@ -8,7 +8,8 @@ LiftedBp::LiftedBp (const ParfactorList& pfList)
|
|||||||
: pfList_(pfList)
|
: pfList_(pfList)
|
||||||
{
|
{
|
||||||
refineParfactors();
|
refineParfactors();
|
||||||
solver_ = new WeightedBp (*getFactorGraph(), getWeights());
|
createFactorGraph();
|
||||||
|
solver_ = new WeightedBp (*fg_, getWeights());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ LiftedBp::LiftedBp (const ParfactorList& pfList)
|
|||||||
LiftedBp::~LiftedBp (void)
|
LiftedBp::~LiftedBp (void)
|
||||||
{
|
{
|
||||||
delete solver_;
|
delete solver_;
|
||||||
|
delete fg_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -47,7 +49,7 @@ LiftedBp::solveQuery (const Grounds& query)
|
|||||||
for (unsigned i = 0; i < groups.size(); i++) {
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
queryVids.push_back (groups[i]);
|
queryVids.push_back (groups[i]);
|
||||||
}
|
}
|
||||||
res = solver_->getFactorJoint (idx, queryVids);
|
res = solver_->getFactorJoint (fg_->facNodes()[idx], queryVids);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
@ -131,10 +133,10 @@ LiftedBp::getQueryGroups (const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
void
|
||||||
LiftedBp::getFactorGraph (void)
|
LiftedBp::createFactorGraph (void)
|
||||||
{
|
{
|
||||||
FactorGraph* fg = new FactorGraph();
|
fg_ = new FactorGraph();
|
||||||
ParfactorList::const_iterator it = pfList_.begin();
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
for (; it != pfList_.end(); ++it) {
|
for (; it != pfList_.end(); ++it) {
|
||||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||||
@ -142,9 +144,8 @@ LiftedBp::getFactorGraph (void)
|
|||||||
for (size_t i = 0; i < groups.size(); i++) {
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
varIds.push_back (groups[i]);
|
varIds.push_back (groups[i]);
|
||||||
}
|
}
|
||||||
fg->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
|
fg_->addFactor (Factor (varIds, (*it)->ranges(), (*it)->params()));
|
||||||
}
|
}
|
||||||
return fg;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ class LiftedBp
|
|||||||
|
|
||||||
vector<PrvGroup> getQueryGroups (const Grounds&);
|
vector<PrvGroup> getQueryGroups (const Grounds&);
|
||||||
|
|
||||||
FactorGraph* getFactorGraph (void);
|
void createFactorGraph (void);
|
||||||
|
|
||||||
vector<vector<unsigned>> getWeights (void) const;
|
vector<vector<unsigned>> getWeights (void) const;
|
||||||
|
|
||||||
@ -34,6 +34,7 @@ class LiftedBp
|
|||||||
|
|
||||||
ParfactorList pfList_;
|
ParfactorList pfList_;
|
||||||
WeightedBp* solver_;
|
WeightedBp* solver_;
|
||||||
|
FactorGraph* fg_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -275,7 +275,7 @@ void
|
|||||||
WeightedBp::printLinkInformation (void) const
|
WeightedBp::printLinkInformation (void) const
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < links_.size(); i++) {
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
|
WeightedLink* l = static_cast<WeightedLink*> (links_[i]);
|
||||||
cout << l->toString() << ":" << endl;
|
cout << l->toString() << ":" << endl;
|
||||||
cout << " curr msg = " << l->message() << endl;
|
cout << " curr msg = " << l->message() << endl;
|
||||||
cout << " next msg = " << l->nextMessage() << endl;
|
cout << " next msg = " << l->nextMessage() << endl;
|
||||||
|
Reference in New Issue
Block a user