fix joint calculation with lifted bp

This commit is contained in:
Tiago Gomes 2012-06-14 11:57:00 +01:00
parent 0e83a75b60
commit 384c108e62
9 changed files with 154 additions and 18 deletions

View File

@ -118,19 +118,20 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == facNodes.size()) { if (idx == facNodes.size()) {
return getJointByConditioning (jointVarIds); return getJointByConditioning (jointVarIds);
} }
return getFactorJoint (facNodes[idx], jointVarIds); return getFactorJoint (idx, jointVarIds);
} }
Params Params
BeliefProp::getFactorJoint ( BeliefProp::getFactorJoint (
FacNode* fn, size_t fnIdx,
const VarIds& jointVarIds) const VarIds& jointVarIds)
{ {
if (runned_ == false) { if (runned_ == false) {
runSolver(); runSolver();
} }
FacNode* fn = fg.facNodes()[fnIdx];
Factor res (fn->factor()); Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks(); const BpLinks& links = ninf(fn)->getLinks();
for (size_t i = 0; i < links.size(); i++) { for (size_t i = 0; i < links.size(); i++) {

View File

@ -112,7 +112,7 @@ class BeliefProp : public Solver
virtual Params getJointByConditioning (const VarIds&) const; virtual Params getJointByConditioning (const VarIds&) const;
public: public:
Params getFactorJoint (FacNode*, const VarIds&); Params getFactorJoint (size_t fnIdx, const VarIds&);
protected: protected:
SPNodeInfo* ninf (const VarNode* var) const SPNodeInfo* ninf (const VarNode* var) const

View File

@ -77,13 +77,11 @@ CountingBp::solveQuery (VarIds queryVids)
res = Solver::getJointByConditioning ( res = Solver::getJointByConditioning (
GroundSolver::CBP, fg, queryVids); GroundSolver::CBP, fg, queryVids);
} else { } else {
FacNode* reprFn = getRepresentative (facNodes[idx]);
assert (reprFn != 0);
VarIds reprArgs; VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {
reprArgs.push_back (getRepresentative (queryVids[i])); reprArgs.push_back (getRepresentative (queryVids[i]));
} }
res = solver_->getFactorJoint (reprFn, reprArgs); res = solver_->getFactorJoint (idx, reprArgs);
} }
} }
return res; return res;

View File

@ -29,11 +29,26 @@ LiftedBp::solveQuery (const Grounds& query)
if (query.size() == 1) { if (query.size() == 1) {
res = solver_->getPosterioriOf (groups[0]); res = solver_->getPosterioriOf (groups[0]);
} else { } else {
VarIds queryVids; ParfactorList::iterator it = pfList_.begin();
for (unsigned i = 0; i < groups.size(); i++) { size_t idx = pfList_.size();
queryVids.push_back (groups[i]); size_t count = 0;
while (it != pfList_.end()) {
if ((*it)->containsGrounds (query)) {
idx = count;
break;
}
++ it;
++ count;
}
if (idx == pfList_.size()) {
res = getJointByConditioning (pfList_, query);
} else {
VarIds queryVids;
for (unsigned i = 0; i < groups.size(); i++) {
queryVids.push_back (groups[i]);
}
res = solver_->getFactorJoint (idx, queryVids);
} }
res = solver_->getJointDistributionOf (queryVids);
} }
return res; return res;
} }
@ -153,3 +168,65 @@ LiftedBp::getWeights (void) const
} }
unsigned
LiftedBp::rangeOfGround (const Ground& gr)
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (gr)) {
PrvGroup prvGroup = (*it)->findGroup (gr);
return (*it)->range ((*it)->indexOfGroup (prvGroup));
}
++ it;
}
return std::numeric_limits<unsigned>::max();
}
Params
LiftedBp::getJointByConditioning (
const ParfactorList& pfList,
const Grounds& grounds)
{
LiftedBp solver (pfList);
Params prevBeliefs = solver.solveQuery ({grounds[0]});
Grounds obsGrounds = {grounds[0]};
for (size_t i = 1; i < grounds.size(); i++) {
Params newBeliefs;
vector<ObservedFormula> obsFs;
Ranges obsRanges;
for (size_t j = 0; j < obsGrounds.size(); j++) {
obsFs.push_back (ObservedFormula (
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
}
Indexer indexer (obsRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < obsFs.size(); j++) {
obsFs[j].setEvidence (indexer[j]);
}
ParfactorList tempPfList (pfList);
LiftedVe::absorveEvidence (tempPfList, obsFs);
LiftedBp solver (tempPfList);
Params beliefs = solver.solveQuery ({grounds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
unsigned range = rangeOfGround (grounds[i]);
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % range == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
obsGrounds.push_back (grounds[i]);
}
return prevBeliefs;
}

View File

@ -27,9 +27,13 @@ class LiftedBp
FactorGraph* getFactorGraph (void); FactorGraph* getFactorGraph (void);
vector<vector<unsigned>> getWeights (void) const; vector<vector<unsigned>> getWeights (void) const;
unsigned rangeOfGround (const Ground&);
ParfactorList pfList_; Params getJointByConditioning (const ParfactorList&, const Grounds&);
WeightedBp* solver_;
ParfactorList pfList_;
WeightedBp* solver_;
}; };

View File

@ -11,7 +11,7 @@ LiftedOperator::getValidOps (
ParfactorList& pfList, ParfactorList& pfList,
const Grounds& query) const Grounds& query)
{ {
vector<LiftedOperator*> validOps; vector<LiftedOperator*> validOps;
vector<ProductOperator*> multOps; vector<ProductOperator*> multOps;
multOps = ProductOperator::getValidOps (pfList); multOps = ProductOperator::getValidOps (pfList);

View File

@ -402,21 +402,32 @@ Parfactor::applySubstitution (const Substitution& theta)
PrvGroup size_t
Parfactor::findGroup (const Ground& ground) const Parfactor::indexOfGround (const Ground& ground) const
{ {
PrvGroup group = numeric_limits<PrvGroup>::max(); size_t idx = args_.size();
for (size_t i = 0; i < args_.size(); i++) { for (size_t i = 0; i < args_.size(); i++) {
if (args_[i].functor() == ground.functor() && if (args_[i].functor() == ground.functor() &&
args_[i].arity() == ground.arity()) { args_[i].arity() == ground.arity()) {
constr_->moveToTop (args_[i].logVars()); constr_->moveToTop (args_[i].logVars());
if (constr_->containsTuple (ground.args())) { if (constr_->containsTuple (ground.args())) {
group = args_[i].group(); idx = i;
break; break;
} }
} }
} }
return group; return idx;
}
PrvGroup
Parfactor::findGroup (const Ground& ground) const
{
size_t idx = indexOfGround (ground);
return idx == args_.size()
? numeric_limits<PrvGroup>::max()
: args_[idx].group();
} }
@ -429,6 +440,30 @@ Parfactor::containsGround (const Ground& ground) const
bool
Parfactor::containsGrounds (const Grounds& grounds) const
{
Tuple tuple;
LogVars tupleLvs;
for (size_t i = 0; i < grounds.size(); i++) {
size_t idx = indexOfGround (grounds[i]);
if (idx == args_.size()) {
return false;
}
LogVars lvs = args_[idx].logVars();
for (size_t j = 0; j < lvs.size(); j++) {
if (Util::contains (tupleLvs, lvs[j]) == false) {
tuple.push_back (grounds[i].args()[j]);
tupleLvs.push_back (lvs[j]);
}
}
}
constr_->moveToTop (tupleLvs);
return constr_->containsTuple (tuple);
}
bool bool
Parfactor::containsGroup (PrvGroup group) const Parfactor::containsGroup (PrvGroup group) const
{ {
@ -442,6 +477,19 @@ Parfactor::containsGroup (PrvGroup group) const
bool
Parfactor::containsGroups (vector<PrvGroup> groups) const
{
for (size_t i = 0; i < groups.size(); i++) {
if (containsGroup (groups[i]) == false) {
return false;
}
}
return true;
}
unsigned unsigned
Parfactor::nrFormulas (LogVar X) const Parfactor::nrFormulas (LogVar X) const
{ {

View File

@ -64,11 +64,17 @@ class Parfactor : public TFactor<ProbFormula>
void applySubstitution (const Substitution&); void applySubstitution (const Substitution&);
size_t indexOfGround (const Ground&) const;
PrvGroup findGroup (const Ground&) const; PrvGroup findGroup (const Ground&) const;
bool containsGround (const Ground&) const; bool containsGround (const Ground&) const;
bool containsGrounds (const Grounds&) const;
bool containsGroup (PrvGroup) const; bool containsGroup (PrvGroup) const;
bool containsGroups (vector<PrvGroup>) const;
unsigned nrFormulas (LogVar) const; unsigned nrFormulas (LogVar) const;

View File

@ -91,6 +91,8 @@ class ObservedFormula
unsigned evidence (void) const { return evidence_; } unsigned evidence (void) const { return evidence_; }
void setEvidence (unsigned ev) { evidence_ = ev; }
ConstraintTree& constr (void) { return constr_; } ConstraintTree& constr (void) { return constr_; }
bool isAtom (void) const { return arity_ == 0; } bool isAtom (void) const { return arity_ == 0; }