Improve variable elimination

This commit is contained in:
Tiago Gomes 2012-12-26 22:55:48 +00:00
parent 188f359496
commit 8bdcb65907
2 changed files with 59 additions and 80 deletions

View File

@ -6,13 +6,6 @@
#include "Util.h" #include "Util.h"
VarElim::~VarElim (void)
{
delete factorList_.back();
}
Params Params
VarElim::solveQuery (VarIds queryVids) VarElim::solveQuery (VarIds queryVids)
{ {
@ -24,14 +17,13 @@ VarElim::solveQuery (VarIds queryVids)
} }
cout << endl; cout << endl;
} }
totalFactorSize_ = 0;
largestFactorSize_ = 0;
factorList_.clear(); factorList_.clear();
varFactors_.clear(); varMap_.clear();
elimOrder_.clear();
createFactorList(); createFactorList();
absorveEvidence(); absorveEvidence();
findEliminationOrder (queryVids); Params params = processFactorList (queryVids);
processFactorList (queryVids);
Params params = factorList_.back()->params();
if (Globals::logDomain) { if (Globals::logDomain) {
Util::exp (params); Util::exp (params);
} }
@ -68,15 +60,15 @@ VarElim::createFactorList (void)
factorList_.reserve (facNodes.size() * 2); factorList_.reserve (facNodes.size() * 2);
for (size_t i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
factorList_.push_back (new Factor (facNodes[i]->factor())); factorList_.push_back (new Factor (facNodes[i]->factor()));
const VarNodes& neighs = facNodes[i]->neighbors(); const VarIds& args = facNodes[i]->factor().arguments();
for (size_t j = 0; j < neighs.size(); j++) { for (size_t j = 0; j < args.size(); j++) {
unordered_map<VarId, vector<size_t>>::iterator it unordered_map<VarId, vector<size_t>>::iterator it;
= varFactors_.find (neighs[j]->varId()); it = varMap_.find (args[j]);
if (it == varFactors_.end()) { if (it != varMap_.end()) {
it = varFactors_.insert (make_pair ( it->second.push_back (i);
neighs[j]->varId(), vector<size_t>())).first; } else {
varMap_[args[j]] = { i };
} }
it->second.push_back (i);
} }
} }
} }
@ -99,15 +91,15 @@ VarElim::absorveEvidence (void)
cout << varNodes[i]->label() << " = " ; cout << varNodes[i]->label() << " = " ;
cout << varNodes[i]->getEvidence() << endl; cout << varNodes[i]->getEvidence() << endl;
} }
const vector<size_t>& idxs = const vector<size_t>& indices = varMap_[varNodes[i]->varId()];
varFactors_.find (varNodes[i]->varId())->second; for (size_t j = 0; j < indices.size(); j++) {
for (size_t j = 0; j < idxs.size(); j++) { size_t idx = indices[j];
Factor* factor = factorList_[idxs[j]]; if (factorList_[idx]->nrArguments() > 1) {
if (factor->nrArguments() == 1) { factorList_[idx]->absorveEvidence (
factorList_[idxs[j]] = 0;
} else {
factorList_[idxs[j]]->absorveEvidence (
varNodes[i]->varId(), varNodes[i]->getEvidence()); varNodes[i]->varId(), varNodes[i]->getEvidence());
} else {
delete factorList_[idx];
factorList_[idx] = 0;
} }
} }
} }
@ -116,72 +108,60 @@ VarElim::absorveEvidence (void)
void Params
VarElim::findEliminationOrder (const VarIds& vids) VarElim::processFactorList (const VarIds& queryVids)
{ {
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids); VarIds elimOrder = ElimGraph::getEliminationOrder (
} factorList_, queryVids);
for (size_t i = 0; i < elimOrder.size(); i++) {
void
VarElim::processFactorList (const VarIds& vids)
{
totalFactorSize_ = 0;
largestFactorSize_ = 0;
for (size_t i = 0; i < elimOrder_.size(); i++) {
if (Globals::verbosity >= 2) { if (Globals::verbosity >= 2) {
if (Globals::verbosity >= 3) { if (Globals::verbosity >= 3) {
Util::printDashedLine(); Util::printDashedLine();
printActiveFactors(); printActiveFactors();
} }
cout << "-> summing out " ; cout << "-> summing out " ;
cout << fg.getVarNode (elimOrder_[i])->label() << endl; cout << fg.getVarNode (elimOrder[i])->label() << endl;
} }
eliminate (elimOrder_[i]); eliminate (elimOrder[i]);
} }
Factor* finalFactor = new Factor(); Factor result;
for (size_t i = 0; i < factorList_.size(); i++) { for (size_t i = 0; i < factorList_.size(); i++) {
if (factorList_[i]) { if (factorList_[i]) {
finalFactor->multiply (*factorList_[i]); result.multiply (*factorList_[i]);
delete factorList_[i]; delete factorList_[i];
factorList_[i] = 0; factorList_[i] = 0;
} }
} }
VarIds unobservedVids; VarIds unobservedVids;
for (size_t i = 0; i < vids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {
if (fg.getVarNode (vids[i])->hasEvidence() == false) { if (fg.getVarNode (queryVids[i])->hasEvidence() == false) {
unobservedVids.push_back (vids[i]); unobservedVids.push_back (queryVids[i]);
} }
} }
finalFactor->reorderArguments (unobservedVids); result.reorderArguments (unobservedVids);
finalFactor->normalize(); result.normalize();
factorList_.push_back (finalFactor);
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
cout << "total factor size: " << totalFactorSize_ << endl; cout << "total factor size: " << totalFactorSize_ << endl;
cout << "largest factor size: " << largestFactorSize_ << endl; cout << "largest factor size: " << largestFactorSize_ << endl;
cout << endl; cout << endl;
} }
return result.params();
} }
void void
VarElim::eliminate (VarId elimVar) VarElim::eliminate (VarId vid)
{ {
Factor* result = 0; Factor* result = new Factor();
vector<size_t>& idxs = varFactors_.find (elimVar)->second; const vector<size_t>& indices = varMap_[vid];
for (size_t i = 0; i < idxs.size(); i++) { for (size_t i = 0; i < indices.size(); i++) {
size_t idx = idxs[i]; size_t idx = indices[i];
if (factorList_[idx]) { if (factorList_[idx]) {
if (result == 0) { result->multiply (*factorList_[idx]);
result = new Factor (*factorList_[idx]);
} else {
result->multiply (*factorList_[idx]);
}
delete factorList_[idx]; delete factorList_[idx];
factorList_[idx] = 0; factorList_[idx] = 0;
} }
@ -190,15 +170,16 @@ VarElim::eliminate (VarId elimVar)
if (result->size() > largestFactorSize_) { if (result->size() > largestFactorSize_) {
largestFactorSize_ = result->size(); largestFactorSize_ = result->size();
} }
if (result != 0 && result->nrArguments() != 1) { if (result->nrArguments() > 1) {
result->sumOut (elimVar); result->sumOut (vid);
factorList_.push_back (result); const VarIds& args = result->arguments();
const VarIds& resultVarIds = result->arguments(); for (size_t i = 0; i < args.size(); i++) {
for (size_t i = 0; i < resultVarIds.size(); i++) { vector<size_t>& indices2 = varMap_[args[i]];
vector<size_t>& idxs = indices2.push_back (factorList_.size());
varFactors_.find (resultVarIds[i])->second;
idxs.push_back (factorList_.size() - 1);
} }
factorList_.push_back (result);
} else {
delete result;
} }
} }
@ -208,9 +189,10 @@ void
VarElim::printActiveFactors (void) VarElim::printActiveFactors (void)
{ {
for (size_t i = 0; i < factorList_.size(); i++) { for (size_t i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) { if (factorList_[i]) {
cout << factorList_[i]->getLabel() << " " ; cout << factorList_[i]->getLabel() << " " ;
cout << factorList_[i]->params() << endl; cout << factorList_[i]->params();
cout << endl;
} }
} }
} }

View File

@ -16,7 +16,7 @@ class VarElim : public GroundSolver
public: public:
VarElim (const FactorGraph& fg) : GroundSolver (fg) { } VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
~VarElim (void); ~VarElim (void) { }
Params solveQuery (VarIds); Params solveQuery (VarIds);
@ -27,19 +27,16 @@ class VarElim : public GroundSolver
void absorveEvidence (void); void absorveEvidence (void);
void findEliminationOrder (const VarIds&); Params processFactorList (const VarIds&);
void processFactorList (const VarIds&);
void eliminate (VarId); void eliminate (VarId);
void printActiveFactors (void); void printActiveFactors (void);
Factors factorList_; Factors factorList_;
VarIds elimOrder_; unsigned largestFactorSize_;
unsigned largestFactorSize_; unsigned totalFactorSize_;
unsigned totalFactorSize_; unordered_map<VarId, vector<size_t>> varMap_;
unordered_map<VarId, vector<size_t>> varFactors_;
}; };
#endif // HORUS_VARELIM_H #endif // HORUS_VARELIM_H