This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/VarElimSolver.cpp

217 lines
5.3 KiB
C++
Raw Normal View History

2012-05-23 14:56:01 +01:00
#include <algorithm>
#include "VarElimSolver.h"
#include "ElimGraph.h"
#include "Factor.h"
#include "Util.h"
// Releases every factor still held in factorList_.
//
// The list may be empty (no query was ever solved), in which case calling
// back() on it would be undefined behaviour, so the whole list is swept
// instead. Slots whose factor was already freed are reset to 0 by the rest
// of this class, and deleting a null pointer is a no-op.
VarElimSolver::~VarElimSolver (void)
{
  for (size_t i = 0; i < factorList_.size(); i++) {
    delete factorList_[i];
  }
}
// Runs variable elimination for the variables in queryVids and returns the
// parameters of the resulting final factor.
//
// The solver state (factorList_, varFactors_, elimOrder_) is rebuilt from
// scratch on every call. When Globals::logDomain is set, the parameters are
// passed through Util::exp before being returned (presumably converting
// them back from log space -- confirm against Util).
Params
VarElimSolver::solveQuery (VarIds queryVids)
{
  if (Globals::verbosity > 1) {
    cout << "Solving query on " ;
    for (size_t i = 0; i < queryVids.size(); i++) {
      if (i != 0) cout << ", " ;
      cout << fg.getVarNode (queryVids[i])->label();
    }
    cout << endl;
  }
  // A previous query leaves its final factor (and possibly other entries)
  // in factorList_; clear() alone would only drop the pointers and leak
  // the objects. Null entries are harmless to delete.
  for (size_t i = 0; i < factorList_.size(); i++) {
    delete factorList_[i];
  }
  factorList_.clear();
  varFactors_.clear();
  elimOrder_.clear();
  createFactorList();
  absorveEvidence();
  findEliminationOrder (queryVids);
  processFactorList (queryVids);
  // processFactorList() appends the final factor as the last entry.
  Params params = factorList_.back()->params();
  if (Globals::logDomain) {
    Util::exp (params);
  }
  return params;
}
void
VarElimSolver::printSolverFlags (void) const
{
stringstream ss;
ss << "variable elimination [" ;
ss << "elim_heuristic=" ;
ElimHeuristic eh = ElimGraph::elimHeuristic;
switch (eh) {
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
case MIN_WEIGHT: ss << "min_weight"; break;
case MIN_FILL: ss << "min_fill"; break;
case WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
}
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}
void
VarElimSolver::createFactorList (void)
{
const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2);
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < facNodes.size(); i++) {
2012-05-23 14:56:01 +01:00
factorList_.push_back (new Factor (facNodes[i]->factor()));
const VarNodes& neighs = facNodes[i]->neighbors();
2012-05-24 22:55:20 +01:00
for (size_t j = 0; j < neighs.size(); j++) {
unordered_map<VarId, vector<size_t>>::iterator it
2012-05-23 14:56:01 +01:00
= varFactors_.find (neighs[j]->varId());
if (it == varFactors_.end()) {
it = varFactors_.insert (make_pair (
2012-05-24 22:55:20 +01:00
neighs[j]->varId(), vector<size_t>())).first;
2012-05-23 14:56:01 +01:00
}
it->second.push_back (i);
}
}
}
void
VarElimSolver::absorveEvidence (void)
{
if (Globals::verbosity > 2) {
Util::printDashedLine();
cout << "(initial factor list)" << endl;
printActiveFactors();
}
const VarNodes& varNodes = fg.varNodes();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < varNodes.size(); i++) {
2012-05-23 14:56:01 +01:00
if (varNodes[i]->hasEvidence()) {
if (Globals::verbosity > 1) {
cout << "-> aborving evidence on ";
cout << varNodes[i]->label() << " = " ;
cout << varNodes[i]->getEvidence() << endl;
}
2012-05-24 22:55:20 +01:00
const vector<size_t>& idxs =
2012-05-23 14:56:01 +01:00
varFactors_.find (varNodes[i]->varId())->second;
2012-05-24 22:55:20 +01:00
for (size_t j = 0; j < idxs.size(); j++) {
2012-05-23 14:56:01 +01:00
Factor* factor = factorList_[idxs[j]];
if (factor->nrArguments() == 1) {
factorList_[idxs[j]] = 0;
} else {
factorList_[idxs[j]]->absorveEvidence (
varNodes[i]->varId(), varNodes[i]->getEvidence());
}
}
}
}
}
void
VarElimSolver::findEliminationOrder (const VarIds& vids)
{
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
}
void
VarElimSolver::processFactorList (const VarIds& vids)
{
totalFactorSize_ = 0;
largestFactorSize_ = 0;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < elimOrder_.size(); i++) {
2012-05-23 14:56:01 +01:00
if (Globals::verbosity >= 2) {
if (Globals::verbosity >= 3) {
Util::printDashedLine();
printActiveFactors();
}
cout << "-> summing out " ;
cout << fg.getVarNode (elimOrder_[i])->label() << endl;
}
eliminate (elimOrder_[i]);
}
Factor* finalFactor = new Factor();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < factorList_.size(); i++) {
2012-05-23 14:56:01 +01:00
if (factorList_[i]) {
finalFactor->multiply (*factorList_[i]);
delete factorList_[i];
factorList_[i] = 0;
}
}
VarIds unobservedVids;
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < vids.size(); i++) {
2012-05-23 14:56:01 +01:00
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
unobservedVids.push_back (vids[i]);
}
}
finalFactor->reorderArguments (unobservedVids);
finalFactor->normalize();
factorList_.push_back (finalFactor);
if (Globals::verbosity > 0) {
cout << "total factor size: " << totalFactorSize_ << endl;
cout << "largest factor size: " << largestFactorSize_ << endl;
cout << endl;
}
}
// Sums elimVar out of the active factor list.
//
// All active factors that mention elimVar are multiplied together (and
// released). If the product has arguments besides elimVar, elimVar is summed
// out and the resulting factor is appended to factorList_ and registered in
// varFactors_ for each of its remaining arguments. If elimVar is the
// product's only argument, the product is discarded.
void
VarElimSolver::eliminate (VarId elimVar)
{
  unordered_map<VarId, vector<size_t>>::iterator it
      = varFactors_.find (elimVar);
  if (it == varFactors_.end()) {
    // no factor was ever registered for this variable
    return;
  }
  Factor* result = 0;
  const vector<size_t>& idxs = it->second;
  for (size_t i = 0; i < idxs.size(); i++) {
    size_t idx = idxs[i];
    if (factorList_[idx]) {
      if (result == 0) {
        result = new Factor (*factorList_[idx]);
      } else {
        result->multiply (*factorList_[idx]);
      }
      delete factorList_[idx];
      factorList_[idx] = 0;
    }
  }
  if (result == 0) {
    // No active factor mentions elimVar, so there is nothing to sum out
    // and no size to account for (the statistics below would otherwise
    // dereference a null pointer).
    return;
  }
  totalFactorSize_ += result->size();
  if (result->size() > largestFactorSize_) {
    largestFactorSize_ = result->size();
  }
  if (result->nrArguments() == 1) {
    // The product is not kept in the factor list; free it instead of
    // leaking it.
    delete result;
    return;
  }
  result->sumOut (elimVar);
  factorList_.push_back (result);
  const size_t resultIdx = factorList_.size() - 1;
  const VarIds& resultVarIds = result->arguments();
  for (size_t i = 0; i < resultVarIds.size(); i++) {
    // operator[] avoids dereferencing end() should a variable be missing
    // from the map; inserting does not invalidate existing references.
    varFactors_[resultVarIds[i]].push_back (resultIdx);
  }
}
void
VarElimSolver::printActiveFactors (void)
{
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < factorList_.size(); i++) {
2012-05-23 14:56:01 +01:00
if (factorList_[i] != 0) {
cout << factorList_[i]->getLabel() << " " ;
cout << factorList_[i]->params() << endl;
}
}
}