Improve variable elimination
This commit is contained in:
parent
188f359496
commit
8bdcb65907
@ -6,13 +6,6 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
VarElim::~VarElim (void)
|
|
||||||
{
|
|
||||||
delete factorList_.back();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
VarElim::solveQuery (VarIds queryVids)
|
VarElim::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
@ -24,14 +17,13 @@ VarElim::solveQuery (VarIds queryVids)
|
|||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
totalFactorSize_ = 0;
|
||||||
|
largestFactorSize_ = 0;
|
||||||
factorList_.clear();
|
factorList_.clear();
|
||||||
varFactors_.clear();
|
varMap_.clear();
|
||||||
elimOrder_.clear();
|
|
||||||
createFactorList();
|
createFactorList();
|
||||||
absorveEvidence();
|
absorveEvidence();
|
||||||
findEliminationOrder (queryVids);
|
Params params = processFactorList (queryVids);
|
||||||
processFactorList (queryVids);
|
|
||||||
Params params = factorList_.back()->params();
|
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::exp (params);
|
Util::exp (params);
|
||||||
}
|
}
|
||||||
@ -68,15 +60,15 @@ VarElim::createFactorList (void)
|
|||||||
factorList_.reserve (facNodes.size() * 2);
|
factorList_.reserve (facNodes.size() * 2);
|
||||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
const VarIds& args = facNodes[i]->factor().arguments();
|
||||||
for (size_t j = 0; j < neighs.size(); j++) {
|
for (size_t j = 0; j < args.size(); j++) {
|
||||||
unordered_map<VarId, vector<size_t>>::iterator it
|
unordered_map<VarId, vector<size_t>>::iterator it;
|
||||||
= varFactors_.find (neighs[j]->varId());
|
it = varMap_.find (args[j]);
|
||||||
if (it == varFactors_.end()) {
|
if (it != varMap_.end()) {
|
||||||
it = varFactors_.insert (make_pair (
|
it->second.push_back (i);
|
||||||
neighs[j]->varId(), vector<size_t>())).first;
|
} else {
|
||||||
|
varMap_[args[j]] = { i };
|
||||||
}
|
}
|
||||||
it->second.push_back (i);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -99,15 +91,15 @@ VarElim::absorveEvidence (void)
|
|||||||
cout << varNodes[i]->label() << " = " ;
|
cout << varNodes[i]->label() << " = " ;
|
||||||
cout << varNodes[i]->getEvidence() << endl;
|
cout << varNodes[i]->getEvidence() << endl;
|
||||||
}
|
}
|
||||||
const vector<size_t>& idxs =
|
const vector<size_t>& indices = varMap_[varNodes[i]->varId()];
|
||||||
varFactors_.find (varNodes[i]->varId())->second;
|
for (size_t j = 0; j < indices.size(); j++) {
|
||||||
for (size_t j = 0; j < idxs.size(); j++) {
|
size_t idx = indices[j];
|
||||||
Factor* factor = factorList_[idxs[j]];
|
if (factorList_[idx]->nrArguments() > 1) {
|
||||||
if (factor->nrArguments() == 1) {
|
factorList_[idx]->absorveEvidence (
|
||||||
factorList_[idxs[j]] = 0;
|
|
||||||
} else {
|
|
||||||
factorList_[idxs[j]]->absorveEvidence (
|
|
||||||
varNodes[i]->varId(), varNodes[i]->getEvidence());
|
varNodes[i]->varId(), varNodes[i]->getEvidence());
|
||||||
|
} else {
|
||||||
|
delete factorList_[idx];
|
||||||
|
factorList_[idx] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -116,72 +108,60 @@ VarElim::absorveEvidence (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
Params
|
||||||
VarElim::findEliminationOrder (const VarIds& vids)
|
VarElim::processFactorList (const VarIds& queryVids)
|
||||||
{
|
{
|
||||||
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
VarIds elimOrder = ElimGraph::getEliminationOrder (
|
||||||
}
|
factorList_, queryVids);
|
||||||
|
for (size_t i = 0; i < elimOrder.size(); i++) {
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
VarElim::processFactorList (const VarIds& vids)
|
|
||||||
{
|
|
||||||
totalFactorSize_ = 0;
|
|
||||||
largestFactorSize_ = 0;
|
|
||||||
for (size_t i = 0; i < elimOrder_.size(); i++) {
|
|
||||||
if (Globals::verbosity >= 2) {
|
if (Globals::verbosity >= 2) {
|
||||||
if (Globals::verbosity >= 3) {
|
if (Globals::verbosity >= 3) {
|
||||||
Util::printDashedLine();
|
Util::printDashedLine();
|
||||||
printActiveFactors();
|
printActiveFactors();
|
||||||
}
|
}
|
||||||
cout << "-> summing out " ;
|
cout << "-> summing out " ;
|
||||||
cout << fg.getVarNode (elimOrder_[i])->label() << endl;
|
cout << fg.getVarNode (elimOrder[i])->label() << endl;
|
||||||
}
|
}
|
||||||
eliminate (elimOrder_[i]);
|
eliminate (elimOrder[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Factor* finalFactor = new Factor();
|
Factor result;
|
||||||
for (size_t i = 0; i < factorList_.size(); i++) {
|
for (size_t i = 0; i < factorList_.size(); i++) {
|
||||||
if (factorList_[i]) {
|
if (factorList_[i]) {
|
||||||
finalFactor->multiply (*factorList_[i]);
|
result.multiply (*factorList_[i]);
|
||||||
delete factorList_[i];
|
delete factorList_[i];
|
||||||
factorList_[i] = 0;
|
factorList_[i] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VarIds unobservedVids;
|
VarIds unobservedVids;
|
||||||
for (size_t i = 0; i < vids.size(); i++) {
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
|
if (fg.getVarNode (queryVids[i])->hasEvidence() == false) {
|
||||||
unobservedVids.push_back (vids[i]);
|
unobservedVids.push_back (queryVids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
finalFactor->reorderArguments (unobservedVids);
|
result.reorderArguments (unobservedVids);
|
||||||
finalFactor->normalize();
|
result.normalize();
|
||||||
factorList_.push_back (finalFactor);
|
|
||||||
if (Globals::verbosity > 0) {
|
if (Globals::verbosity > 0) {
|
||||||
cout << "total factor size: " << totalFactorSize_ << endl;
|
cout << "total factor size: " << totalFactorSize_ << endl;
|
||||||
cout << "largest factor size: " << largestFactorSize_ << endl;
|
cout << "largest factor size: " << largestFactorSize_ << endl;
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
return result.params();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElim::eliminate (VarId elimVar)
|
VarElim::eliminate (VarId vid)
|
||||||
{
|
{
|
||||||
Factor* result = 0;
|
Factor* result = new Factor();
|
||||||
vector<size_t>& idxs = varFactors_.find (elimVar)->second;
|
const vector<size_t>& indices = varMap_[vid];
|
||||||
for (size_t i = 0; i < idxs.size(); i++) {
|
for (size_t i = 0; i < indices.size(); i++) {
|
||||||
size_t idx = idxs[i];
|
size_t idx = indices[i];
|
||||||
if (factorList_[idx]) {
|
if (factorList_[idx]) {
|
||||||
if (result == 0) {
|
result->multiply (*factorList_[idx]);
|
||||||
result = new Factor (*factorList_[idx]);
|
|
||||||
} else {
|
|
||||||
result->multiply (*factorList_[idx]);
|
|
||||||
}
|
|
||||||
delete factorList_[idx];
|
delete factorList_[idx];
|
||||||
factorList_[idx] = 0;
|
factorList_[idx] = 0;
|
||||||
}
|
}
|
||||||
@ -190,15 +170,16 @@ VarElim::eliminate (VarId elimVar)
|
|||||||
if (result->size() > largestFactorSize_) {
|
if (result->size() > largestFactorSize_) {
|
||||||
largestFactorSize_ = result->size();
|
largestFactorSize_ = result->size();
|
||||||
}
|
}
|
||||||
if (result != 0 && result->nrArguments() != 1) {
|
if (result->nrArguments() > 1) {
|
||||||
result->sumOut (elimVar);
|
result->sumOut (vid);
|
||||||
factorList_.push_back (result);
|
const VarIds& args = result->arguments();
|
||||||
const VarIds& resultVarIds = result->arguments();
|
for (size_t i = 0; i < args.size(); i++) {
|
||||||
for (size_t i = 0; i < resultVarIds.size(); i++) {
|
vector<size_t>& indices2 = varMap_[args[i]];
|
||||||
vector<size_t>& idxs =
|
indices2.push_back (factorList_.size());
|
||||||
varFactors_.find (resultVarIds[i])->second;
|
|
||||||
idxs.push_back (factorList_.size() - 1);
|
|
||||||
}
|
}
|
||||||
|
factorList_.push_back (result);
|
||||||
|
} else {
|
||||||
|
delete result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -208,9 +189,10 @@ void
|
|||||||
VarElim::printActiveFactors (void)
|
VarElim::printActiveFactors (void)
|
||||||
{
|
{
|
||||||
for (size_t i = 0; i < factorList_.size(); i++) {
|
for (size_t i = 0; i < factorList_.size(); i++) {
|
||||||
if (factorList_[i] != 0) {
|
if (factorList_[i]) {
|
||||||
cout << factorList_[i]->getLabel() << " " ;
|
cout << factorList_[i]->getLabel() << " " ;
|
||||||
cout << factorList_[i]->params() << endl;
|
cout << factorList_[i]->params();
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ class VarElim : public GroundSolver
|
|||||||
public:
|
public:
|
||||||
VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
|
VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
|
||||||
|
|
||||||
~VarElim (void);
|
~VarElim (void) { }
|
||||||
|
|
||||||
Params solveQuery (VarIds);
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
@ -27,19 +27,16 @@ class VarElim : public GroundSolver
|
|||||||
|
|
||||||
void absorveEvidence (void);
|
void absorveEvidence (void);
|
||||||
|
|
||||||
void findEliminationOrder (const VarIds&);
|
Params processFactorList (const VarIds&);
|
||||||
|
|
||||||
void processFactorList (const VarIds&);
|
|
||||||
|
|
||||||
void eliminate (VarId);
|
void eliminate (VarId);
|
||||||
|
|
||||||
void printActiveFactors (void);
|
void printActiveFactors (void);
|
||||||
|
|
||||||
Factors factorList_;
|
Factors factorList_;
|
||||||
VarIds elimOrder_;
|
unsigned largestFactorSize_;
|
||||||
unsigned largestFactorSize_;
|
unsigned totalFactorSize_;
|
||||||
unsigned totalFactorSize_;
|
unordered_map<VarId, vector<size_t>> varMap_;
|
||||||
unordered_map<VarId, vector<size_t>> varFactors_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_VARELIM_H
|
#endif // HORUS_VARELIM_H
|
||||||
|
Reference in New Issue
Block a user