fix joint calculation with lifted bp

This commit is contained in:
Tiago Gomes
2012-06-14 11:57:00 +01:00
parent 0e83a75b60
commit 384c108e62
9 changed files with 154 additions and 18 deletions

View File

@@ -29,11 +29,26 @@ LiftedBp::solveQuery (const Grounds& query)
if (query.size() == 1) {
res = solver_->getPosterioriOf (groups[0]);
} else {
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
ParfactorList::iterator it = pfList_.begin();
size_t idx = pfList_.size();
size_t count = 0;
while (it != pfList_.end()) {
if ((*it)->containsGrounds (query)) {
idx = count;
break;
}
++ it;
++ count;
}
if (idx == pfList_.size()) {
res = getJointByConditioning (pfList_, query);
} else {
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
res = solver_->getFactorJoint (idx, queryVids);
}
res = solver_->getJointDistributionOf (queryVids);
}
return res;
}
@@ -153,3 +168,65 @@ LiftedBp::getWeights (void) const
}
unsigned
LiftedBp::rangeOfGround (const Ground& gr)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (gr)) {
PrvGroup prvGroup = (*it)->findGroup (gr);
return (*it)->range ((*it)->indexOfGroup (prvGroup));
}
++ it;
}
return std::numeric_limits<unsigned>::max();
}
Params
LiftedBp::getJointByConditioning (
const ParfactorList& pfList,
const Grounds& grounds)
{
LiftedBp solver (pfList);
Params prevBeliefs = solver.solveQuery ({grounds[0]});
Grounds obsGrounds = {grounds[0]};
for (size_t i = 1; i < grounds.size(); i++) {
Params newBeliefs;
vector<ObservedFormula> obsFs;
Ranges obsRanges;
for (size_t j = 0; j < obsGrounds.size(); j++) {
obsFs.push_back (ObservedFormula (
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
}
Indexer indexer (obsRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < obsFs.size(); j++) {
obsFs[j].setEvidence (indexer[j]);
}
ParfactorList tempPfList (pfList);
LiftedVe::absorveEvidence (tempPfList, obsFs);
LiftedBp solver (tempPfList);
Params beliefs = solver.solveQuery ({grounds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
unsigned range = rangeOfGround (grounds[i]);
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % range == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
obsGrounds.push_back (grounds[i]);
}
return prevBeliefs;
}