This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/horus/VarElim.cpp

206 lines
5.2 KiB
C++
Raw Permalink Normal View History

2012-05-23 14:56:01 +01:00
#include <algorithm>
2013-02-07 20:09:10 +00:00
#include <iostream>
#include <sstream>
2012-05-23 14:56:01 +01:00
#include "VarElim.h"
2012-05-23 14:56:01 +01:00
#include "ElimGraph.h"
#include "Factor.h"
#include "Util.h"
namespace Horus {
2013-02-07 23:53:13 +00:00
2012-05-23 14:56:01 +01:00
/**
 * Runs variable elimination for the given query variables.
 *
 * Resets the per-query statistics and working state, copies the factor
 * graph's factors, absorbs evidence, then eliminates the non-query
 * variables.  When Globals::logDomain is set the resulting parameters are
 * converted back with Util::exp before being returned.
 *
 * @param queryVids variables whose (joint) distribution is requested.
 * @return the normalized parameters over the unobserved query variables.
 */
Params
VarElim::solveQuery (VarIds queryVids)
{
  if (Globals::verbosity > 1) {
    std::cout << "Solving query on " ;
    // Emit the labels as a comma-separated list.
    const char* separator = "";
    for (const auto& vid : queryVids) {
      std::cout << separator << fg.getVarNode (vid)->label();
      separator = ", ";
    }
    std::cout << std::endl;
  }

  // Fresh working state for this query.
  factorList_.clear();
  varMap_.clear();
  totalFactorSize_ = 0;
  largestFactorSize_ = 0;

  createFactorList();
  absorveEvidence();

  Params result = processFactorList (queryVids);
  if (Globals::logDomain) {
    Util::exp (result);
  }
  return result;
}
void
VarElim::printSolverFlags() const
2012-05-23 14:56:01 +01:00
{
2013-02-07 13:37:15 +00:00
std::stringstream ss;
2012-05-23 14:56:01 +01:00
ss << "variable elimination [" ;
ss << "elim_heuristic=" ;
typedef ElimGraph::ElimHeuristic ElimHeuristic;
switch (ElimGraph::elimHeuristic()) {
case ElimHeuristic::sequentialEh: ss << "sequential"; break;
case ElimHeuristic::minNeighborsEh: ss << "min_neighbors"; break;
case ElimHeuristic::minWeightEh: ss << "min_weight"; break;
case ElimHeuristic::minFillEh: ss << "min_fill"; break;
case ElimHeuristic::weightedMinFillEh: ss << "weighted_min_fill"; break;
2012-05-23 14:56:01 +01:00
}
ss << ",log_domain=" << Util::toString (Globals::logDomain);
2012-05-23 14:56:01 +01:00
ss << "]" ;
2013-02-07 13:37:15 +00:00
std::cout << ss.str() << std::endl;
2012-05-23 14:56:01 +01:00
}
void
VarElim::createFactorList()
2012-05-23 14:56:01 +01:00
{
const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2);
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < facNodes.size(); i++) {
2012-05-23 14:56:01 +01:00
factorList_.push_back (new Factor (facNodes[i]->factor()));
2012-12-26 22:55:48 +00:00
const VarIds& args = facNodes[i]->factor().arguments();
for (size_t j = 0; j < args.size(); j++) {
2013-02-07 13:37:15 +00:00
std::unordered_map<VarId, std::vector<size_t>>::iterator it;
2012-12-26 22:55:48 +00:00
it = varMap_.find (args[j]);
if (it != varMap_.end()) {
it->second.push_back (i);
} else {
varMap_[args[j]] = { i };
2012-05-23 14:56:01 +01:00
}
}
}
2012-12-17 18:39:42 +00:00
}
2012-05-23 14:56:01 +01:00
void
VarElim::absorveEvidence()
2012-05-23 14:56:01 +01:00
{
if (Globals::verbosity > 2) {
Util::printDashedLine();
2013-02-07 13:37:15 +00:00
std::cout << "(initial factor list)" << std::endl;
2012-05-23 14:56:01 +01:00
printActiveFactors();
}
const VarNodes& varNodes = fg.varNodes();
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < varNodes.size(); i++) {
2012-05-23 14:56:01 +01:00
if (varNodes[i]->hasEvidence()) {
if (Globals::verbosity > 1) {
2013-02-07 13:37:15 +00:00
std::cout << "-> aborving evidence on ";
std::cout << varNodes[i]->label() << " = " ;
std::cout << varNodes[i]->getEvidence() << std::endl;
2012-05-23 14:56:01 +01:00
}
2013-02-07 13:37:15 +00:00
const std::vector<size_t>& indices = varMap_[varNodes[i]->varId()];
2012-12-26 22:55:48 +00:00
for (size_t j = 0; j < indices.size(); j++) {
size_t idx = indices[j];
if (factorList_[idx]->nrArguments() > 1) {
factorList_[idx]->absorveEvidence (
2012-05-23 14:56:01 +01:00
varNodes[i]->varId(), varNodes[i]->getEvidence());
2012-12-26 22:55:48 +00:00
} else {
delete factorList_[idx];
factorList_[idx] = 0;
2012-05-23 14:56:01 +01:00
}
}
}
}
}
2012-12-26 22:55:48 +00:00
/**
 * Eliminates every variable in the heuristic's elimination order, multiplies
 * the surviving factors together (freeing them), and returns the normalized
 * parameters ordered by the unobserved query variables.
 *
 * @param queryVids the query variables, passed to the elimination-order
 *        heuristic; observed ones are left out of the result ordering.
 * @return params of the normalized result factor.
 */
Params
VarElim::processFactorList (const VarIds& queryVids)
{
  const VarIds elimOrder = ElimGraph::getEliminationOrder (
      factorList_, queryVids);
  for (const auto& vid : elimOrder) {
    if (Globals::verbosity >= 2) {
      if (Globals::verbosity >= 3) {
        Util::printDashedLine();
        printActiveFactors();
      }
      std::cout << "-> summing out " ;
      std::cout << fg.getVarNode (vid)->label() << std::endl;
    }
    eliminate (vid);
  }

  // Combine whatever factors are still alive and release them.
  Factor result;
  for (auto& slot : factorList_) {
    if (!slot) {
      continue;
    }
    result.multiply (*slot);
    delete slot;
    slot = 0;
  }

  VarIds unobservedVids;
  for (const auto& vid : queryVids) {
    if (!fg.getVarNode (vid)->hasEvidence()) {
      unobservedVids.push_back (vid);
    }
  }
  result.reorderArguments (unobservedVids);
  result.normalize();

  if (Globals::verbosity > 0) {
    std::cout << "total factor size: " << totalFactorSize_ << std::endl;
    std::cout << "largest factor size: " << largestFactorSize_ << std::endl;
    std::cout << std::endl;
  }
  return result.params();
}
void
2012-12-26 22:55:48 +00:00
VarElim::eliminate (VarId vid)
2012-05-23 14:56:01 +01:00
{
2012-12-26 22:55:48 +00:00
Factor* result = new Factor();
2013-02-07 13:37:15 +00:00
const std::vector<size_t>& indices = varMap_[vid];
2012-12-26 22:55:48 +00:00
for (size_t i = 0; i < indices.size(); i++) {
size_t idx = indices[i];
2012-05-23 14:56:01 +01:00
if (factorList_[idx]) {
2012-12-26 22:55:48 +00:00
result->multiply (*factorList_[idx]);
2012-05-23 14:56:01 +01:00
delete factorList_[idx];
factorList_[idx] = 0;
}
}
totalFactorSize_ += result->size();
if (result->size() > largestFactorSize_) {
largestFactorSize_ = result->size();
}
2012-12-26 22:55:48 +00:00
if (result->nrArguments() > 1) {
result->sumOut (vid);
const VarIds& args = result->arguments();
for (size_t i = 0; i < args.size(); i++) {
2013-02-07 13:37:15 +00:00
std::vector<size_t>& indices2 = varMap_[args[i]];
2012-12-26 22:55:48 +00:00
indices2.push_back (factorList_.size());
2012-05-23 14:56:01 +01:00
}
2012-12-26 22:55:48 +00:00
factorList_.push_back (result);
} else {
delete result;
2012-05-23 14:56:01 +01:00
}
}
void
VarElim::printActiveFactors()
2012-05-23 14:56:01 +01:00
{
2012-05-24 22:55:20 +01:00
for (size_t i = 0; i < factorList_.size(); i++) {
2012-12-26 22:55:48 +00:00
if (factorList_[i]) {
2013-02-07 13:37:15 +00:00
std::cout << factorList_[i]->getLabel() << " " ;
std::cout << factorList_[i]->params();
std::cout << std::endl;
2012-05-23 14:56:01 +01:00
}
}
}
} // namespace Horus
2013-02-07 23:53:13 +00:00