diff --git a/packages/CLPBN/horus/BeliefProp.cpp b/packages/CLPBN/horus/BeliefProp.cpp index 3d2237c49..61c69a7c2 100644 --- a/packages/CLPBN/horus/BeliefProp.cpp +++ b/packages/CLPBN/horus/BeliefProp.cpp @@ -118,19 +118,20 @@ BeliefProp::getJointDistributionOf (const VarIds& jointVarIds) if (idx == facNodes.size()) { return getJointByConditioning (jointVarIds); } - return getFactorJoint (facNodes[idx], jointVarIds); + return getFactorJoint (idx, jointVarIds); } Params BeliefProp::getFactorJoint ( - FacNode* fn, + size_t fnIdx, const VarIds& jointVarIds) { if (runned_ == false) { runSolver(); } + FacNode* fn = fg.facNodes()[fnIdx]; Factor res (fn->factor()); const BpLinks& links = ninf(fn)->getLinks(); for (size_t i = 0; i < links.size(); i++) { diff --git a/packages/CLPBN/horus/BeliefProp.h b/packages/CLPBN/horus/BeliefProp.h index af8da9a23..44c867dbc 100644 --- a/packages/CLPBN/horus/BeliefProp.h +++ b/packages/CLPBN/horus/BeliefProp.h @@ -112,7 +112,7 @@ class BeliefProp : public Solver virtual Params getJointByConditioning (const VarIds&) const; public: - Params getFactorJoint (FacNode*, const VarIds&); + Params getFactorJoint (size_t fnIdx, const VarIds&); protected: SPNodeInfo* ninf (const VarNode* var) const diff --git a/packages/CLPBN/horus/CountingBp.cpp b/packages/CLPBN/horus/CountingBp.cpp index ffd1abc64..1ee3b48f1 100644 --- a/packages/CLPBN/horus/CountingBp.cpp +++ b/packages/CLPBN/horus/CountingBp.cpp @@ -77,13 +77,11 @@ CountingBp::solveQuery (VarIds queryVids) res = Solver::getJointByConditioning ( GroundSolver::CBP, fg, queryVids); } else { - FacNode* reprFn = getRepresentative (facNodes[idx]); - assert (reprFn != 0); VarIds reprArgs; for (size_t i = 0; i < queryVids.size(); i++) { reprArgs.push_back (getRepresentative (queryVids[i])); } - res = solver_->getFactorJoint (reprFn, reprArgs); + res = solver_->getFactorJoint (idx, reprArgs); } } return res; diff --git a/packages/CLPBN/horus/LiftedBp.cpp b/packages/CLPBN/horus/LiftedBp.cpp index b4abfcdee..8fc8573b6 100644 --- a/packages/CLPBN/horus/LiftedBp.cpp +++ b/packages/CLPBN/horus/LiftedBp.cpp @@ -29,11 +29,26 @@ LiftedBp::solveQuery (const Grounds& query) if (query.size() == 1) { res = solver_->getPosterioriOf (groups[0]); } else { - VarIds queryVids; - for (unsigned i = 0; i < groups.size(); i++) { - queryVids.push_back (groups[i]); + ParfactorList::iterator it = pfList_.begin(); + size_t idx = pfList_.size(); + size_t count = 0; + while (it != pfList_.end()) { + if ((*it)->containsGrounds (query)) { + idx = count; + break; + } + ++ it; + ++ count; + } + if (idx == pfList_.size()) { + res = getJointByConditioning (pfList_, query); + } else { + VarIds queryVids; + for (unsigned i = 0; i < groups.size(); i++) { + queryVids.push_back (groups[i]); + } + res = solver_->getFactorJoint (idx, queryVids); } - res = solver_->getJointDistributionOf (queryVids); } return res; } @@ -153,3 +168,65 @@ LiftedBp::getWeights (void) const } + +unsigned +LiftedBp::rangeOfGround (const Ground& gr) +{ + ParfactorList::iterator it = pfList_.begin(); + while (it != pfList_.end()) { + if ((*it)->containsGround (gr)) { + PrvGroup prvGroup = (*it)->findGroup (gr); + return (*it)->range ((*it)->indexOfGroup (prvGroup)); + } + ++ it; + } + return std::numeric_limits::max(); +} + + + +Params +LiftedBp::getJointByConditioning ( + const ParfactorList& pfList, + const Grounds& grounds) +{ + LiftedBp solver (pfList); + Params prevBeliefs = solver.solveQuery ({grounds[0]}); + Grounds obsGrounds = {grounds[0]}; + for (size_t i = 1; i < grounds.size(); i++) { + Params newBeliefs; + vector obsFs; + Ranges obsRanges; + for (size_t j = 0; j < obsGrounds.size(); j++) { + obsFs.push_back (ObservedFormula ( + obsGrounds[j].functor(), 0, obsGrounds[j].args())); + obsRanges.push_back (rangeOfGround (obsGrounds[j])); + } + Indexer indexer (obsRanges, false); + while (indexer.valid()) { + for (size_t j = 0; j < obsFs.size(); j++) { + obsFs[j].setEvidence (indexer[j]); + } + ParfactorList tempPfList (pfList); + LiftedVe::absorveEvidence (tempPfList, obsFs); + LiftedBp solver (tempPfList); + Params beliefs = solver.solveQuery ({grounds[i]}); + for (size_t k = 0; k < beliefs.size(); k++) { + newBeliefs.push_back (beliefs[k]); + } + ++ indexer; + } + int count = -1; + unsigned range = rangeOfGround (grounds[i]); + for (size_t j = 0; j < newBeliefs.size(); j++) { + if (j % range == 0) { + count ++; + } + newBeliefs[j] *= prevBeliefs[count]; + } + prevBeliefs = newBeliefs; + obsGrounds.push_back (grounds[i]); + } + return prevBeliefs; +} + diff --git a/packages/CLPBN/horus/LiftedBp.h b/packages/CLPBN/horus/LiftedBp.h index f4c1ffa44..01e18310e 100644 --- a/packages/CLPBN/horus/LiftedBp.h +++ b/packages/CLPBN/horus/LiftedBp.h @@ -27,9 +27,13 @@ class LiftedBp FactorGraph* getFactorGraph (void); vector> getWeights (void) const; + + unsigned rangeOfGround (const Ground&); - ParfactorList pfList_; - WeightedBp* solver_; + Params getJointByConditioning (const ParfactorList&, const Grounds&); + + ParfactorList pfList_; + WeightedBp* solver_; }; diff --git a/packages/CLPBN/horus/LiftedVe.cpp b/packages/CLPBN/horus/LiftedVe.cpp index a6724c3c2..6ed2c22f7 100644 --- a/packages/CLPBN/horus/LiftedVe.cpp +++ b/packages/CLPBN/horus/LiftedVe.cpp @@ -11,7 +11,7 @@ LiftedOperator::getValidOps ( ParfactorList& pfList, const Grounds& query) { - vector validOps; + vector validOps; vector multOps; multOps = ProductOperator::getValidOps (pfList); diff --git a/packages/CLPBN/horus/Parfactor.cpp b/packages/CLPBN/horus/Parfactor.cpp index cad0fe32f..6eaa32e72 100644 --- a/packages/CLPBN/horus/Parfactor.cpp +++ b/packages/CLPBN/horus/Parfactor.cpp @@ -402,21 +402,32 @@ Parfactor::applySubstitution (const Substitution& theta) -PrvGroup -Parfactor::findGroup (const Ground& ground) const +size_t +Parfactor::indexOfGround (const Ground& ground) const { - PrvGroup group = numeric_limits::max(); + size_t idx = args_.size(); for (size_t i = 0; i < args_.size(); i++) { if (args_[i].functor() == ground.functor() && args_[i].arity() == ground.arity()) { constr_->moveToTop (args_[i].logVars()); if (constr_->containsTuple (ground.args())) { - group = args_[i].group(); + idx = i; break; } } } - return group; + return idx; +} + + + +PrvGroup +Parfactor::findGroup (const Ground& ground) const +{ + size_t idx = indexOfGround (ground); + return idx == args_.size() + ? numeric_limits::max() + : args_[idx].group(); } @@ -429,6 +440,30 @@ Parfactor::containsGround (const Ground& ground) const +bool +Parfactor::containsGrounds (const Grounds& grounds) const +{ + Tuple tuple; + LogVars tupleLvs; + for (size_t i = 0; i < grounds.size(); i++) { + size_t idx = indexOfGround (grounds[i]); + if (idx == args_.size()) { + return false; + } + LogVars lvs = args_[idx].logVars(); + for (size_t j = 0; j < lvs.size(); j++) { + if (Util::contains (tupleLvs, lvs[j]) == false) { + tuple.push_back (grounds[i].args()[j]); + tupleLvs.push_back (lvs[j]); + } + } + } + constr_->moveToTop (tupleLvs); + return constr_->containsTuple (tuple); +} + + + bool Parfactor::containsGroup (PrvGroup group) const { @@ -442,6 +477,19 @@ Parfactor::containsGroup (PrvGroup group) const +bool +Parfactor::containsGroups (vector groups) const +{ + for (size_t i = 0; i < groups.size(); i++) { + if (containsGroup (groups[i]) == false) { + return false; + } + } + return true; +} + + + unsigned Parfactor::nrFormulas (LogVar X) const { diff --git a/packages/CLPBN/horus/Parfactor.h b/packages/CLPBN/horus/Parfactor.h index aded326d3..5f6aec550 100644 --- a/packages/CLPBN/horus/Parfactor.h +++ b/packages/CLPBN/horus/Parfactor.h @@ -64,11 +64,17 @@ class Parfactor : public TFactor void applySubstitution (const Substitution&); + size_t indexOfGround (const Ground&) const; + PrvGroup findGroup (const Ground&) const; bool containsGround (const Ground&) const; + bool containsGrounds (const Grounds&) const; + bool containsGroup (PrvGroup) const; + + bool containsGroups (vector) const; unsigned nrFormulas (LogVar) const; diff --git a/packages/CLPBN/horus/ProbFormula.h b/packages/CLPBN/horus/ProbFormula.h index 7e29e933e..61b016288 100644 --- a/packages/CLPBN/horus/ProbFormula.h +++ b/packages/CLPBN/horus/ProbFormula.h @@ -91,6 +91,8 @@ class ObservedFormula unsigned evidence (void) const { return evidence_; } + void setEvidence (unsigned ev) { evidence_ = ev; } + ConstraintTree& constr (void) { return constr_; } bool isAtom (void) const { return arity_ == 0; }