fix joint calculation with lifted bp
This commit is contained in:
parent
0e83a75b60
commit
384c108e62
@ -118,19 +118,20 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
|
||||
if (idx == facNodes.size()) {
|
||||
return getJointByConditioning (jointVarIds);
|
||||
}
|
||||
return getFactorJoint (facNodes[idx], jointVarIds);
|
||||
return getFactorJoint (idx, jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BeliefProp::getFactorJoint (
|
||||
FacNode* fn,
|
||||
size_t fnIdx,
|
||||
const VarIds& jointVarIds)
|
||||
{
|
||||
if (runned_ == false) {
|
||||
runSolver();
|
||||
}
|
||||
FacNode* fn = fg.facNodes()[fnIdx];
|
||||
Factor res (fn->factor());
|
||||
const BpLinks& links = ninf(fn)->getLinks();
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
|
@ -112,7 +112,7 @@ class BeliefProp : public Solver
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
public:
|
||||
Params getFactorJoint (FacNode*, const VarIds&);
|
||||
Params getFactorJoint (size_t fnIdx, const VarIds&);
|
||||
|
||||
protected:
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
|
@ -77,13 +77,11 @@ CountingBp::solveQuery (VarIds queryVids)
|
||||
res = Solver::getJointByConditioning (
|
||||
GroundSolver::CBP, fg, queryVids);
|
||||
} else {
|
||||
FacNode* reprFn = getRepresentative (facNodes[idx]);
|
||||
assert (reprFn != 0);
|
||||
VarIds reprArgs;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
reprArgs.push_back (getRepresentative (queryVids[i]));
|
||||
}
|
||||
res = solver_->getFactorJoint (reprFn, reprArgs);
|
||||
res = solver_->getFactorJoint (idx, reprArgs);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
|
@ -29,11 +29,26 @@ LiftedBp::solveQuery (const Grounds& query)
|
||||
if (query.size() == 1) {
|
||||
res = solver_->getPosterioriOf (groups[0]);
|
||||
} else {
|
||||
VarIds queryVids;
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
queryVids.push_back (groups[i]);
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
size_t idx = pfList_.size();
|
||||
size_t count = 0;
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGrounds (query)) {
|
||||
idx = count;
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
++ count;
|
||||
}
|
||||
if (idx == pfList_.size()) {
|
||||
res = getJointByConditioning (pfList_, query);
|
||||
} else {
|
||||
VarIds queryVids;
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
queryVids.push_back (groups[i]);
|
||||
}
|
||||
res = solver_->getFactorJoint (idx, queryVids);
|
||||
}
|
||||
res = solver_->getJointDistributionOf (queryVids);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -153,3 +168,65 @@ LiftedBp::getWeights (void) const
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
LiftedBp::rangeOfGround (const Ground& gr)
|
||||
{
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGround (gr)) {
|
||||
PrvGroup prvGroup = (*it)->findGroup (gr);
|
||||
return (*it)->range ((*it)->indexOfGroup (prvGroup));
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
return std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
LiftedBp::getJointByConditioning (
|
||||
const ParfactorList& pfList,
|
||||
const Grounds& grounds)
|
||||
{
|
||||
LiftedBp solver (pfList);
|
||||
Params prevBeliefs = solver.solveQuery ({grounds[0]});
|
||||
Grounds obsGrounds = {grounds[0]};
|
||||
for (size_t i = 1; i < grounds.size(); i++) {
|
||||
Params newBeliefs;
|
||||
vector<ObservedFormula> obsFs;
|
||||
Ranges obsRanges;
|
||||
for (size_t j = 0; j < obsGrounds.size(); j++) {
|
||||
obsFs.push_back (ObservedFormula (
|
||||
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
|
||||
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
|
||||
}
|
||||
Indexer indexer (obsRanges, false);
|
||||
while (indexer.valid()) {
|
||||
for (size_t j = 0; j < obsFs.size(); j++) {
|
||||
obsFs[j].setEvidence (indexer[j]);
|
||||
}
|
||||
ParfactorList tempPfList (pfList);
|
||||
LiftedVe::absorveEvidence (tempPfList, obsFs);
|
||||
LiftedBp solver (tempPfList);
|
||||
Params beliefs = solver.solveQuery ({grounds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
int count = -1;
|
||||
unsigned range = rangeOfGround (grounds[i]);
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % range == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
obsGrounds.push_back (grounds[i]);
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
@ -28,8 +28,12 @@ class LiftedBp
|
||||
|
||||
vector<vector<unsigned>> getWeights (void) const;
|
||||
|
||||
ParfactorList pfList_;
|
||||
WeightedBp* solver_;
|
||||
unsigned rangeOfGround (const Ground&);
|
||||
|
||||
Params getJointByConditioning (const ParfactorList&, const Grounds&);
|
||||
|
||||
ParfactorList pfList_;
|
||||
WeightedBp* solver_;
|
||||
|
||||
};
|
||||
|
||||
|
@ -11,7 +11,7 @@ LiftedOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<ProductOperator*> multOps;
|
||||
|
||||
multOps = ProductOperator::getValidOps (pfList);
|
||||
|
@ -402,21 +402,32 @@ Parfactor::applySubstitution (const Substitution& theta)
|
||||
|
||||
|
||||
|
||||
PrvGroup
|
||||
Parfactor::findGroup (const Ground& ground) const
|
||||
size_t
|
||||
Parfactor::indexOfGround (const Ground& ground) const
|
||||
{
|
||||
PrvGroup group = numeric_limits<PrvGroup>::max();
|
||||
size_t idx = args_.size();
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (args_[i].functor() == ground.functor() &&
|
||||
args_[i].arity() == ground.arity()) {
|
||||
constr_->moveToTop (args_[i].logVars());
|
||||
if (constr_->containsTuple (ground.args())) {
|
||||
group = args_[i].group();
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return group;
|
||||
return idx;
|
||||
}
|
||||
|
||||
|
||||
|
||||
PrvGroup
|
||||
Parfactor::findGroup (const Ground& ground) const
|
||||
{
|
||||
size_t idx = indexOfGround (ground);
|
||||
return idx == args_.size()
|
||||
? numeric_limits<PrvGroup>::max()
|
||||
: args_[idx].group();
|
||||
}
|
||||
|
||||
|
||||
@ -429,6 +440,30 @@ Parfactor::containsGround (const Ground& ground) const
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGrounds (const Grounds& grounds) const
|
||||
{
|
||||
Tuple tuple;
|
||||
LogVars tupleLvs;
|
||||
for (size_t i = 0; i < grounds.size(); i++) {
|
||||
size_t idx = indexOfGround (grounds[i]);
|
||||
if (idx == args_.size()) {
|
||||
return false;
|
||||
}
|
||||
LogVars lvs = args_[idx].logVars();
|
||||
for (size_t j = 0; j < lvs.size(); j++) {
|
||||
if (Util::contains (tupleLvs, lvs[j]) == false) {
|
||||
tuple.push_back (grounds[i].args()[j]);
|
||||
tupleLvs.push_back (lvs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
constr_->moveToTop (tupleLvs);
|
||||
return constr_->containsTuple (tuple);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGroup (PrvGroup group) const
|
||||
{
|
||||
@ -442,6 +477,19 @@ Parfactor::containsGroup (PrvGroup group) const
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGroups (vector<PrvGroup> groups) const
|
||||
{
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (containsGroup (groups[i]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Parfactor::nrFormulas (LogVar X) const
|
||||
{
|
||||
|
@ -64,12 +64,18 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
size_t indexOfGround (const Ground&) const;
|
||||
|
||||
PrvGroup findGroup (const Ground&) const;
|
||||
|
||||
bool containsGround (const Ground&) const;
|
||||
|
||||
bool containsGrounds (const Grounds&) const;
|
||||
|
||||
bool containsGroup (PrvGroup) const;
|
||||
|
||||
bool containsGroups (vector<PrvGroup>) const;
|
||||
|
||||
unsigned nrFormulas (LogVar) const;
|
||||
|
||||
int indexOfLogVar (LogVar) const;
|
||||
|
@ -91,6 +91,8 @@ class ObservedFormula
|
||||
|
||||
unsigned evidence (void) const { return evidence_; }
|
||||
|
||||
void setEvidence (unsigned ev) { evidence_ = ev; }
|
||||
|
||||
ConstraintTree& constr (void) { return constr_; }
|
||||
|
||||
bool isAtom (void) const { return arity_ == 0; }
|
||||
|
Reference in New Issue
Block a user