fix joint calculation with lifted bp
This commit is contained in:
@@ -29,11 +29,26 @@ LiftedBp::solveQuery (const Grounds& query)
|
||||
if (query.size() == 1) {
|
||||
res = solver_->getPosterioriOf (groups[0]);
|
||||
} else {
|
||||
VarIds queryVids;
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
queryVids.push_back (groups[i]);
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
size_t idx = pfList_.size();
|
||||
size_t count = 0;
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGrounds (query)) {
|
||||
idx = count;
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
++ count;
|
||||
}
|
||||
if (idx == pfList_.size()) {
|
||||
res = getJointByConditioning (pfList_, query);
|
||||
} else {
|
||||
VarIds queryVids;
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
queryVids.push_back (groups[i]);
|
||||
}
|
||||
res = solver_->getFactorJoint (idx, queryVids);
|
||||
}
|
||||
res = solver_->getJointDistributionOf (queryVids);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -153,3 +168,65 @@ LiftedBp::getWeights (void) const
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
LiftedBp::rangeOfGround (const Ground& gr)
|
||||
{
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGround (gr)) {
|
||||
PrvGroup prvGroup = (*it)->findGroup (gr);
|
||||
return (*it)->range ((*it)->indexOfGroup (prvGroup));
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
return std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
LiftedBp::getJointByConditioning (
|
||||
const ParfactorList& pfList,
|
||||
const Grounds& grounds)
|
||||
{
|
||||
LiftedBp solver (pfList);
|
||||
Params prevBeliefs = solver.solveQuery ({grounds[0]});
|
||||
Grounds obsGrounds = {grounds[0]};
|
||||
for (size_t i = 1; i < grounds.size(); i++) {
|
||||
Params newBeliefs;
|
||||
vector<ObservedFormula> obsFs;
|
||||
Ranges obsRanges;
|
||||
for (size_t j = 0; j < obsGrounds.size(); j++) {
|
||||
obsFs.push_back (ObservedFormula (
|
||||
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
|
||||
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
|
||||
}
|
||||
Indexer indexer (obsRanges, false);
|
||||
while (indexer.valid()) {
|
||||
for (size_t j = 0; j < obsFs.size(); j++) {
|
||||
obsFs[j].setEvidence (indexer[j]);
|
||||
}
|
||||
ParfactorList tempPfList (pfList);
|
||||
LiftedVe::absorveEvidence (tempPfList, obsFs);
|
||||
LiftedBp solver (tempPfList);
|
||||
Params beliefs = solver.solveQuery ({grounds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
int count = -1;
|
||||
unsigned range = rangeOfGround (grounds[i]);
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % range == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
obsGrounds.push_back (grounds[i]);
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user