This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/VarElim.cpp
Tiago Gomes ef4ebb4d7f Use camel case for constants and enumerators.
All capitals case should be reserved for macros and besides there is no big need to emphasize constness in general.
2013-02-13 18:54:15 +00:00

205 lines
5.1 KiB
C++

#include <algorithm>
#include <iostream>
#include <sstream>
#include "VarElim.h"
#include "ElimGraph.h"
#include "Factor.h"
#include "Util.h"
namespace Horus {
Params
VarElim::solveQuery (VarIds queryVids)
{
if (Globals::verbosity > 1) {
std::cout << "Solving query on " ;
for (size_t i = 0; i < queryVids.size(); i++) {
if (i != 0) std::cout << ", " ;
std::cout << fg.getVarNode (queryVids[i])->label();
}
std::cout << std::endl;
}
totalFactorSize_ = 0;
largestFactorSize_ = 0;
factorList_.clear();
varMap_.clear();
createFactorList();
absorveEvidence();
Params params = processFactorList (queryVids);
if (Globals::logDomain) {
Util::exp (params);
}
return params;
}
void
VarElim::printSolverFlags (void) const
{
std::stringstream ss;
ss << "variable elimination [" ;
ss << "elim_heuristic=" ;
switch (ElimGraph::elimHeuristic()) {
case ElimHeuristic::sequentialEh: ss << "sequential"; break;
case ElimHeuristic::minNeighborsEh: ss << "min_neighbors"; break;
case ElimHeuristic::minWeightEh: ss << "min_weight"; break;
case ElimHeuristic::minFillEh: ss << "min_fill"; break;
case ElimHeuristic::weightedMinFillEh: ss << "weighted_min_fill"; break;
}
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
std::cout << ss.str() << std::endl;
}
void
VarElim::createFactorList (void)
{
const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2);
for (size_t i = 0; i < facNodes.size(); i++) {
factorList_.push_back (new Factor (facNodes[i]->factor()));
const VarIds& args = facNodes[i]->factor().arguments();
for (size_t j = 0; j < args.size(); j++) {
std::unordered_map<VarId, std::vector<size_t>>::iterator it;
it = varMap_.find (args[j]);
if (it != varMap_.end()) {
it->second.push_back (i);
} else {
varMap_[args[j]] = { i };
}
}
}
}
void
VarElim::absorveEvidence (void)
{
if (Globals::verbosity > 2) {
Util::printDashedLine();
std::cout << "(initial factor list)" << std::endl;
printActiveFactors();
}
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) {
if (Globals::verbosity > 1) {
std::cout << "-> aborving evidence on ";
std::cout << varNodes[i]->label() << " = " ;
std::cout << varNodes[i]->getEvidence() << std::endl;
}
const std::vector<size_t>& indices = varMap_[varNodes[i]->varId()];
for (size_t j = 0; j < indices.size(); j++) {
size_t idx = indices[j];
if (factorList_[idx]->nrArguments() > 1) {
factorList_[idx]->absorveEvidence (
varNodes[i]->varId(), varNodes[i]->getEvidence());
} else {
delete factorList_[idx];
factorList_[idx] = 0;
}
}
}
}
}
Params
VarElim::processFactorList (const VarIds& queryVids)
{
VarIds elimOrder = ElimGraph::getEliminationOrder (
factorList_, queryVids);
for (size_t i = 0; i < elimOrder.size(); i++) {
if (Globals::verbosity >= 2) {
if (Globals::verbosity >= 3) {
Util::printDashedLine();
printActiveFactors();
}
std::cout << "-> summing out " ;
std::cout << fg.getVarNode (elimOrder[i])->label() << std::endl;
}
eliminate (elimOrder[i]);
}
Factor result;
for (size_t i = 0; i < factorList_.size(); i++) {
if (factorList_[i]) {
result.multiply (*factorList_[i]);
delete factorList_[i];
factorList_[i] = 0;
}
}
VarIds unobservedVids;
for (size_t i = 0; i < queryVids.size(); i++) {
if (fg.getVarNode (queryVids[i])->hasEvidence() == false) {
unobservedVids.push_back (queryVids[i]);
}
}
result.reorderArguments (unobservedVids);
result.normalize();
if (Globals::verbosity > 0) {
std::cout << "total factor size: " << totalFactorSize_ << std::endl;
std::cout << "largest factor size: " << largestFactorSize_ << std::endl;
std::cout << std::endl;
}
return result.params();
}
void
VarElim::eliminate (VarId vid)
{
Factor* result = new Factor();
const std::vector<size_t>& indices = varMap_[vid];
for (size_t i = 0; i < indices.size(); i++) {
size_t idx = indices[i];
if (factorList_[idx]) {
result->multiply (*factorList_[idx]);
delete factorList_[idx];
factorList_[idx] = 0;
}
}
totalFactorSize_ += result->size();
if (result->size() > largestFactorSize_) {
largestFactorSize_ = result->size();
}
if (result->nrArguments() > 1) {
result->sumOut (vid);
const VarIds& args = result->arguments();
for (size_t i = 0; i < args.size(); i++) {
std::vector<size_t>& indices2 = varMap_[args[i]];
indices2.push_back (factorList_.size());
}
factorList_.push_back (result);
} else {
delete result;
}
}
void
VarElim::printActiveFactors (void)
{
for (size_t i = 0; i < factorList_.size(); i++) {
if (factorList_[i]) {
std::cout << factorList_[i]->getLabel() << " " ;
std::cout << factorList_[i]->params();
std::cout << std::endl;
}
}
}
} // namespace Horus