diff --git a/packages/CLPBN/horus/VarElim.cpp b/packages/CLPBN/horus/VarElim.cpp index 54ee18d20..d31f6ce51 100644 --- a/packages/CLPBN/horus/VarElim.cpp +++ b/packages/CLPBN/horus/VarElim.cpp @@ -6,13 +6,6 @@ #include "Util.h" -VarElim::~VarElim (void) -{ - delete factorList_.back(); -} - - - Params VarElim::solveQuery (VarIds queryVids) { @@ -24,14 +17,13 @@ VarElim::solveQuery (VarIds queryVids) } cout << endl; } + totalFactorSize_ = 0; + largestFactorSize_ = 0; factorList_.clear(); - varFactors_.clear(); - elimOrder_.clear(); + varMap_.clear(); createFactorList(); absorveEvidence(); - findEliminationOrder (queryVids); - processFactorList (queryVids); - Params params = factorList_.back()->params(); + Params params = processFactorList (queryVids); if (Globals::logDomain) { Util::exp (params); } @@ -68,15 +60,15 @@ VarElim::createFactorList (void) factorList_.reserve (facNodes.size() * 2); for (size_t i = 0; i < facNodes.size(); i++) { factorList_.push_back (new Factor (facNodes[i]->factor())); - const VarNodes& neighs = facNodes[i]->neighbors(); - for (size_t j = 0; j < neighs.size(); j++) { - unordered_map>::iterator it - = varFactors_.find (neighs[j]->varId()); - if (it == varFactors_.end()) { - it = varFactors_.insert (make_pair ( - neighs[j]->varId(), vector())).first; + const VarIds& args = facNodes[i]->factor().arguments(); + for (size_t j = 0; j < args.size(); j++) { + unordered_map>::iterator it; + it = varMap_.find (args[j]); + if (it != varMap_.end()) { + it->second.push_back (i); + } else { + varMap_[args[j]] = { i }; } - it->second.push_back (i); } } } @@ -99,15 +91,15 @@ VarElim::absorveEvidence (void) cout << varNodes[i]->label() << " = " ; cout << varNodes[i]->getEvidence() << endl; } - const vector& idxs = - varFactors_.find (varNodes[i]->varId())->second; - for (size_t j = 0; j < idxs.size(); j++) { - Factor* factor = factorList_[idxs[j]]; - if (factor->nrArguments() == 1) { - factorList_[idxs[j]] = 0; - } else { - factorList_[idxs[j]]->absorveEvidence ( + const vector& indices = varMap_[varNodes[i]->varId()]; + for (size_t j = 0; j < indices.size(); j++) { + size_t idx = indices[j]; + if (factorList_[idx]->nrArguments() > 1) { + factorList_[idx]->absorveEvidence ( varNodes[i]->varId(), varNodes[i]->getEvidence()); + } else { + delete factorList_[idx]; + factorList_[idx] = 0; } } } @@ -116,72 +108,60 @@ VarElim::absorveEvidence (void) -void -VarElim::findEliminationOrder (const VarIds& vids) +Params +VarElim::processFactorList (const VarIds& queryVids) { - elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids); -} - - - -void -VarElim::processFactorList (const VarIds& vids) -{ - totalFactorSize_ = 0; - largestFactorSize_ = 0; - for (size_t i = 0; i < elimOrder_.size(); i++) { + VarIds elimOrder = ElimGraph::getEliminationOrder ( + factorList_, queryVids); + for (size_t i = 0; i < elimOrder.size(); i++) { if (Globals::verbosity >= 2) { if (Globals::verbosity >= 3) { Util::printDashedLine(); printActiveFactors(); } cout << "-> summing out " ; - cout << fg.getVarNode (elimOrder_[i])->label() << endl; + cout << fg.getVarNode (elimOrder[i])->label() << endl; } - eliminate (elimOrder_[i]); + eliminate (elimOrder[i]); } - Factor* finalFactor = new Factor(); + Factor result; for (size_t i = 0; i < factorList_.size(); i++) { if (factorList_[i]) { - finalFactor->multiply (*factorList_[i]); + result.multiply (*factorList_[i]); delete factorList_[i]; factorList_[i] = 0; } } VarIds unobservedVids; - for (size_t i = 0; i < vids.size(); i++) { - if (fg.getVarNode (vids[i])->hasEvidence() == false) { - unobservedVids.push_back (vids[i]); + for (size_t i = 0; i < queryVids.size(); i++) { + if (fg.getVarNode (queryVids[i])->hasEvidence() == false) { + unobservedVids.push_back (queryVids[i]); } } - finalFactor->reorderArguments (unobservedVids); - finalFactor->normalize(); - factorList_.push_back (finalFactor); + result.reorderArguments (unobservedVids); + result.normalize(); if (Globals::verbosity > 0) { cout << "total factor size: " << totalFactorSize_ << endl; cout << "largest factor size: " << largestFactorSize_ << endl; cout << endl; } + return result.params(); } void -VarElim::eliminate (VarId elimVar) +VarElim::eliminate (VarId vid) { - Factor* result = 0; - vector& idxs = varFactors_.find (elimVar)->second; - for (size_t i = 0; i < idxs.size(); i++) { - size_t idx = idxs[i]; + Factor* result = new Factor(); + const vector& indices = varMap_[vid]; + for (size_t i = 0; i < indices.size(); i++) { + size_t idx = indices[i]; if (factorList_[idx]) { - if (result == 0) { - result = new Factor (*factorList_[idx]); - } else { - result->multiply (*factorList_[idx]); - } + result->multiply (*factorList_[idx]); delete factorList_[idx]; factorList_[idx] = 0; } @@ -190,15 +170,16 @@ VarElim::eliminate (VarId elimVar) if (result->size() > largestFactorSize_) { largestFactorSize_ = result->size(); } - if (result != 0 && result->nrArguments() != 1) { - result->sumOut (elimVar); - factorList_.push_back (result); - const VarIds& resultVarIds = result->arguments(); - for (size_t i = 0; i < resultVarIds.size(); i++) { - vector& idxs = - varFactors_.find (resultVarIds[i])->second; - idxs.push_back (factorList_.size() - 1); + if (result->nrArguments() > 1) { + result->sumOut (vid); + const VarIds& args = result->arguments(); + for (size_t i = 0; i < args.size(); i++) { + vector& indices2 = varMap_[args[i]]; + indices2.push_back (factorList_.size()); } + factorList_.push_back (result); + } else { + delete result; } } @@ -208,9 +189,10 @@ void VarElim::printActiveFactors (void) { for (size_t i = 0; i < factorList_.size(); i++) { - if (factorList_[i] != 0) { + if (factorList_[i]) { cout << factorList_[i]->getLabel() << " " ; - cout << factorList_[i]->params() << endl; + cout << factorList_[i]->params(); + cout << endl; } } } diff --git a/packages/CLPBN/horus/VarElim.h b/packages/CLPBN/horus/VarElim.h index fe1327fc0..96906bb00 100644 --- a/packages/CLPBN/horus/VarElim.h +++ b/packages/CLPBN/horus/VarElim.h @@ -16,7 +16,7 @@ class VarElim : public GroundSolver public: VarElim (const FactorGraph& fg) : GroundSolver (fg) { } - ~VarElim (void); + ~VarElim (void) { } Params solveQuery (VarIds); @@ -27,19 +27,16 @@ class VarElim : public GroundSolver void absorveEvidence (void); - void findEliminationOrder (const VarIds&); - - void processFactorList (const VarIds&); + Params processFactorList (const VarIds&); void eliminate (VarId); void printActiveFactors (void); - Factors factorList_; - VarIds elimOrder_; - unsigned largestFactorSize_; - unsigned totalFactorSize_; - unordered_map> varFactors_; + Factors factorList_; + unsigned largestFactorSize_; + unsigned totalFactorSize_; + unordered_map> varMap_; }; #endif // HORUS_VARELIM_H