fix joint calculation with lifted bp
This commit is contained in:
parent
0e83a75b60
commit
384c108e62
@ -118,19 +118,20 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
if (idx == facNodes.size()) {
|
if (idx == facNodes.size()) {
|
||||||
return getJointByConditioning (jointVarIds);
|
return getJointByConditioning (jointVarIds);
|
||||||
}
|
}
|
||||||
return getFactorJoint (facNodes[idx], jointVarIds);
|
return getFactorJoint (idx, jointVarIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
BeliefProp::getFactorJoint (
|
BeliefProp::getFactorJoint (
|
||||||
FacNode* fn,
|
size_t fnIdx,
|
||||||
const VarIds& jointVarIds)
|
const VarIds& jointVarIds)
|
||||||
{
|
{
|
||||||
if (runned_ == false) {
|
if (runned_ == false) {
|
||||||
runSolver();
|
runSolver();
|
||||||
}
|
}
|
||||||
|
FacNode* fn = fg.facNodes()[fnIdx];
|
||||||
Factor res (fn->factor());
|
Factor res (fn->factor());
|
||||||
const BpLinks& links = ninf(fn)->getLinks();
|
const BpLinks& links = ninf(fn)->getLinks();
|
||||||
for (size_t i = 0; i < links.size(); i++) {
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
@ -112,7 +112,7 @@ class BeliefProp : public Solver
|
|||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Params getFactorJoint (FacNode*, const VarIds&);
|
Params getFactorJoint (size_t fnIdx, const VarIds&);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
SPNodeInfo* ninf (const VarNode* var) const
|
SPNodeInfo* ninf (const VarNode* var) const
|
||||||
|
@ -77,13 +77,11 @@ CountingBp::solveQuery (VarIds queryVids)
|
|||||||
res = Solver::getJointByConditioning (
|
res = Solver::getJointByConditioning (
|
||||||
GroundSolver::CBP, fg, queryVids);
|
GroundSolver::CBP, fg, queryVids);
|
||||||
} else {
|
} else {
|
||||||
FacNode* reprFn = getRepresentative (facNodes[idx]);
|
|
||||||
assert (reprFn != 0);
|
|
||||||
VarIds reprArgs;
|
VarIds reprArgs;
|
||||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
reprArgs.push_back (getRepresentative (queryVids[i]));
|
reprArgs.push_back (getRepresentative (queryVids[i]));
|
||||||
}
|
}
|
||||||
res = solver_->getFactorJoint (reprFn, reprArgs);
|
res = solver_->getFactorJoint (idx, reprArgs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
|
@ -29,11 +29,26 @@ LiftedBp::solveQuery (const Grounds& query)
|
|||||||
if (query.size() == 1) {
|
if (query.size() == 1) {
|
||||||
res = solver_->getPosterioriOf (groups[0]);
|
res = solver_->getPosterioriOf (groups[0]);
|
||||||
} else {
|
} else {
|
||||||
VarIds queryVids;
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
for (unsigned i = 0; i < groups.size(); i++) {
|
size_t idx = pfList_.size();
|
||||||
queryVids.push_back (groups[i]);
|
size_t count = 0;
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if ((*it)->containsGrounds (query)) {
|
||||||
|
idx = count;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
++ count;
|
||||||
|
}
|
||||||
|
if (idx == pfList_.size()) {
|
||||||
|
res = getJointByConditioning (pfList_, query);
|
||||||
|
} else {
|
||||||
|
VarIds queryVids;
|
||||||
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
|
queryVids.push_back (groups[i]);
|
||||||
|
}
|
||||||
|
res = solver_->getFactorJoint (idx, queryVids);
|
||||||
}
|
}
|
||||||
res = solver_->getJointDistributionOf (queryVids);
|
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -153,3 +168,65 @@ LiftedBp::getWeights (void) const
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
LiftedBp::rangeOfGround (const Ground& gr)
|
||||||
|
{
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if ((*it)->containsGround (gr)) {
|
||||||
|
PrvGroup prvGroup = (*it)->findGroup (gr);
|
||||||
|
return (*it)->range ((*it)->indexOfGroup (prvGroup));
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return std::numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
LiftedBp::getJointByConditioning (
|
||||||
|
const ParfactorList& pfList,
|
||||||
|
const Grounds& grounds)
|
||||||
|
{
|
||||||
|
LiftedBp solver (pfList);
|
||||||
|
Params prevBeliefs = solver.solveQuery ({grounds[0]});
|
||||||
|
Grounds obsGrounds = {grounds[0]};
|
||||||
|
for (size_t i = 1; i < grounds.size(); i++) {
|
||||||
|
Params newBeliefs;
|
||||||
|
vector<ObservedFormula> obsFs;
|
||||||
|
Ranges obsRanges;
|
||||||
|
for (size_t j = 0; j < obsGrounds.size(); j++) {
|
||||||
|
obsFs.push_back (ObservedFormula (
|
||||||
|
obsGrounds[j].functor(), 0, obsGrounds[j].args()));
|
||||||
|
obsRanges.push_back (rangeOfGround (obsGrounds[j]));
|
||||||
|
}
|
||||||
|
Indexer indexer (obsRanges, false);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
for (size_t j = 0; j < obsFs.size(); j++) {
|
||||||
|
obsFs[j].setEvidence (indexer[j]);
|
||||||
|
}
|
||||||
|
ParfactorList tempPfList (pfList);
|
||||||
|
LiftedVe::absorveEvidence (tempPfList, obsFs);
|
||||||
|
LiftedBp solver (tempPfList);
|
||||||
|
Params beliefs = solver.solveQuery ({grounds[i]});
|
||||||
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
|
newBeliefs.push_back (beliefs[k]);
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
int count = -1;
|
||||||
|
unsigned range = rangeOfGround (grounds[i]);
|
||||||
|
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||||
|
if (j % range == 0) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
|
}
|
||||||
|
prevBeliefs = newBeliefs;
|
||||||
|
obsGrounds.push_back (grounds[i]);
|
||||||
|
}
|
||||||
|
return prevBeliefs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -27,9 +27,13 @@ class LiftedBp
|
|||||||
FactorGraph* getFactorGraph (void);
|
FactorGraph* getFactorGraph (void);
|
||||||
|
|
||||||
vector<vector<unsigned>> getWeights (void) const;
|
vector<vector<unsigned>> getWeights (void) const;
|
||||||
|
|
||||||
|
unsigned rangeOfGround (const Ground&);
|
||||||
|
|
||||||
ParfactorList pfList_;
|
Params getJointByConditioning (const ParfactorList&, const Grounds&);
|
||||||
WeightedBp* solver_;
|
|
||||||
|
ParfactorList pfList_;
|
||||||
|
WeightedBp* solver_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ LiftedOperator::getValidOps (
|
|||||||
ParfactorList& pfList,
|
ParfactorList& pfList,
|
||||||
const Grounds& query)
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
vector<LiftedOperator*> validOps;
|
vector<LiftedOperator*> validOps;
|
||||||
vector<ProductOperator*> multOps;
|
vector<ProductOperator*> multOps;
|
||||||
|
|
||||||
multOps = ProductOperator::getValidOps (pfList);
|
multOps = ProductOperator::getValidOps (pfList);
|
||||||
|
@ -402,21 +402,32 @@ Parfactor::applySubstitution (const Substitution& theta)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
PrvGroup
|
size_t
|
||||||
Parfactor::findGroup (const Ground& ground) const
|
Parfactor::indexOfGround (const Ground& ground) const
|
||||||
{
|
{
|
||||||
PrvGroup group = numeric_limits<PrvGroup>::max();
|
size_t idx = args_.size();
|
||||||
for (size_t i = 0; i < args_.size(); i++) {
|
for (size_t i = 0; i < args_.size(); i++) {
|
||||||
if (args_[i].functor() == ground.functor() &&
|
if (args_[i].functor() == ground.functor() &&
|
||||||
args_[i].arity() == ground.arity()) {
|
args_[i].arity() == ground.arity()) {
|
||||||
constr_->moveToTop (args_[i].logVars());
|
constr_->moveToTop (args_[i].logVars());
|
||||||
if (constr_->containsTuple (ground.args())) {
|
if (constr_->containsTuple (ground.args())) {
|
||||||
group = args_[i].group();
|
idx = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return group;
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
PrvGroup
|
||||||
|
Parfactor::findGroup (const Ground& ground) const
|
||||||
|
{
|
||||||
|
size_t idx = indexOfGround (ground);
|
||||||
|
return idx == args_.size()
|
||||||
|
? numeric_limits<PrvGroup>::max()
|
||||||
|
: args_[idx].group();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -429,6 +440,30 @@ Parfactor::containsGround (const Ground& ground) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::containsGrounds (const Grounds& grounds) const
|
||||||
|
{
|
||||||
|
Tuple tuple;
|
||||||
|
LogVars tupleLvs;
|
||||||
|
for (size_t i = 0; i < grounds.size(); i++) {
|
||||||
|
size_t idx = indexOfGround (grounds[i]);
|
||||||
|
if (idx == args_.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
LogVars lvs = args_[idx].logVars();
|
||||||
|
for (size_t j = 0; j < lvs.size(); j++) {
|
||||||
|
if (Util::contains (tupleLvs, lvs[j]) == false) {
|
||||||
|
tuple.push_back (grounds[i].args()[j]);
|
||||||
|
tupleLvs.push_back (lvs[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
constr_->moveToTop (tupleLvs);
|
||||||
|
return constr_->containsTuple (tuple);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
Parfactor::containsGroup (PrvGroup group) const
|
Parfactor::containsGroup (PrvGroup group) const
|
||||||
{
|
{
|
||||||
@ -442,6 +477,19 @@ Parfactor::containsGroup (PrvGroup group) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::containsGroups (vector<PrvGroup> groups) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < groups.size(); i++) {
|
||||||
|
if (containsGroup (groups[i]) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
Parfactor::nrFormulas (LogVar X) const
|
Parfactor::nrFormulas (LogVar X) const
|
||||||
{
|
{
|
||||||
|
@ -64,11 +64,17 @@ class Parfactor : public TFactor<ProbFormula>
|
|||||||
|
|
||||||
void applySubstitution (const Substitution&);
|
void applySubstitution (const Substitution&);
|
||||||
|
|
||||||
|
size_t indexOfGround (const Ground&) const;
|
||||||
|
|
||||||
PrvGroup findGroup (const Ground&) const;
|
PrvGroup findGroup (const Ground&) const;
|
||||||
|
|
||||||
bool containsGround (const Ground&) const;
|
bool containsGround (const Ground&) const;
|
||||||
|
|
||||||
|
bool containsGrounds (const Grounds&) const;
|
||||||
|
|
||||||
bool containsGroup (PrvGroup) const;
|
bool containsGroup (PrvGroup) const;
|
||||||
|
|
||||||
|
bool containsGroups (vector<PrvGroup>) const;
|
||||||
|
|
||||||
unsigned nrFormulas (LogVar) const;
|
unsigned nrFormulas (LogVar) const;
|
||||||
|
|
||||||
|
@ -91,6 +91,8 @@ class ObservedFormula
|
|||||||
|
|
||||||
unsigned evidence (void) const { return evidence_; }
|
unsigned evidence (void) const { return evidence_; }
|
||||||
|
|
||||||
|
void setEvidence (unsigned ev) { evidence_ = ev; }
|
||||||
|
|
||||||
ConstraintTree& constr (void) { return constr_; }
|
ConstraintTree& constr (void) { return constr_; }
|
||||||
|
|
||||||
bool isAtom (void) const { return arity_ == 0; }
|
bool isAtom (void) const { return arity_ == 0; }
|
||||||
|
Reference in New Issue
Block a user