This repository has been archived on 2023-08-20. You can view files and clone it, but cannot push or open issues or pull requests.
yap-6.3/packages/CLPBN/clpbn/bp/Util.cpp

503 lines
9.8 KiB
C++
Raw Normal View History

2012-03-31 23:27:37 +01:00
#include <cassert>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>

#include "Util.h"
#include "Indexer.h"
namespace Globals {
2012-03-31 23:27:37 +01:00
bool logDomain = false;
2012-03-22 11:33:24 +00:00
//InfAlgs infAlgorithm = InfAlgorithms::VE;
2011-12-12 15:29:51 +00:00
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
2012-03-31 23:27:37 +01:00
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
};
2011-12-12 15:29:51 +00:00
namespace BpOptions {

// Message-passing schedule. Other values recognised by
// Statistics::getStatisticString: SEQ_RANDOM, PARALLEL, MAX_RESIDUAL.
Schedule schedule (BpOptions::Schedule::SEQ_FIXED);

// Printed as "accuracy" in the statistics report; presumably the
// convergence tolerance of the BP solvers -- confirm against the solver.
double accuracy (1e-4);

// Printed as "max iterations" in the statistics report.
unsigned maxIter (1000);

}  // namespace BpOptions
vector<NetInfo> Statistics::netInfo_;
vector<CompressInfo> Statistics::compressInfo_;
2011-12-12 15:29:51 +00:00
unsigned Statistics::primaryNetCount_;
namespace Util {
void
2012-03-22 11:33:24 +00:00
toLog (Params& v)
{
for (unsigned i = 0; i < v.size(); i++) {
2011-12-12 15:29:51 +00:00
v[i] = log (v[i]);
}
2011-12-12 15:29:51 +00:00
}
void
2012-03-22 11:33:24 +00:00
fromLog (Params& v)
2011-12-12 15:29:51 +00:00
{
for (unsigned i = 0; i < v.size(); i++) {
2011-12-12 15:29:51 +00:00
v[i] = exp (v[i]);
}
}
2011-12-12 15:29:51 +00:00
2012-03-31 23:27:37 +01:00
double
factorial (double num)
2011-12-12 15:29:51 +00:00
{
2012-03-31 23:27:37 +01:00
double result = 1.0;
for (int i = 1; i <= num; i++) {
result *= i;
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
return result;
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
unsigned
nrCombinations (unsigned n, unsigned r)
2012-03-22 11:33:24 +00:00
{
2012-03-31 23:27:37 +01:00
assert (n >= r);
unsigned prod = 1;
for (int i = (int)n; i > (int)(n - r); i--) {
prod *= i;
2011-12-12 15:29:51 +00:00
}
2012-03-31 23:27:37 +01:00
return (prod / factorial (r));
2011-12-12 15:29:51 +00:00
}
2012-03-31 23:27:37 +01:00
unsigned
expectedSize (const Ranges& ranges)
{
2012-03-31 23:27:37 +01:00
unsigned prod = 1;
for (unsigned i = 0; i < ranges.size(); i++) {
prod *= ranges[i];
2011-12-12 15:29:51 +00:00
}
2012-03-31 23:27:37 +01:00
return prod;
}
unsigned
getNumberOfDigits (int number)
{
unsigned count = 1;
while (number >= 10) {
number /= 10;
count ++;
}
2012-03-31 23:27:37 +01:00
return count;
}
2011-12-12 15:29:51 +00:00
2012-03-31 23:27:37 +01:00
bool
isInteger (const string& s)
2011-12-12 15:29:51 +00:00
{
2012-03-31 23:27:37 +01:00
stringstream ss1 (s);
stringstream ss2;
int integer;
ss1 >> integer;
ss2 << integer;
return (ss1.str() == ss2.str());
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
string
parametersToString (const Params& v, unsigned precision)
2012-03-22 11:33:24 +00:00
{
2012-03-31 23:27:37 +01:00
stringstream ss;
ss.precision (precision);
ss << "[" ;
for (unsigned i = 0; i < v.size(); i++) {
if (i != 0) ss << ", " ;
ss << v[i];
2011-12-12 15:29:51 +00:00
}
2012-03-31 23:27:37 +01:00
ss << "]" ;
return ss.str();
2012-03-22 11:33:24 +00:00
}
2012-03-31 23:27:37 +01:00
vector<string>
2012-04-10 15:00:18 +01:00
getStateLines (const Vars& vars)
2012-03-22 11:33:24 +00:00
{
2012-03-31 23:27:37 +01:00
StatesIndexer idx (vars);
vector<string> jointStrings;
while (idx.valid()) {
stringstream ss;
for (unsigned i = 0; i < vars.size(); i++) {
if (i != 0) ss << ", " ;
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
}
jointStrings.push_back (ss.str());
++ idx;
}
return jointStrings;
}
void
printHeader (string header, std::ostream& os)
2012-03-31 23:27:37 +01:00
{
printAsteriskLine (os);
os << header << endl;
printAsteriskLine (os);
}
void
printSubHeader (string header, std::ostream& os)
2012-03-31 23:27:37 +01:00
{
printDashedLine (os);
os << header << endl;
printDashedLine (os);
}
void
printAsteriskLine (std::ostream& os)
2012-03-31 23:27:37 +01:00
{
os << "********************************" ;
os << "********************************" ;
os << endl;
}
void
printDashedLine (std::ostream& os)
2012-03-31 23:27:37 +01:00
{
os << "--------------------------------" ;
os << "--------------------------------" ;
os << endl;
}
}
namespace LogAware {
void
normalize (Params& v)
{
double sum;
if (Globals::logDomain) {
sum = LogAware::addIdenty();
for (unsigned i = 0; i < v.size(); i++) {
sum = Util::logSum (sum, v[i]);
}
assert (sum != -numeric_limits<double>::infinity());
for (unsigned i = 0; i < v.size(); i++) {
v[i] -= sum;
}
} else {
sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
2012-03-22 11:33:24 +00:00
}
2011-12-12 15:29:51 +00:00
}
double
2012-03-22 11:33:24 +00:00
getL1Distance (const Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
double dist = 0.0;
2012-03-22 11:33:24 +00:00
if (Globals::logDomain) {
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (exp(v1[i]) - exp(v2[i]));
}
} else {
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (v1[i] - v2[i]);
}
}
return dist;
}
2011-12-12 15:29:51 +00:00
double
2012-03-22 11:33:24 +00:00
getMaxNorm (const Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
double max = 0.0;
2012-03-22 11:33:24 +00:00
if (Globals::logDomain) {
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (exp(v1[i]) - exp(v2[i]));
if (diff > max) {
max = diff;
2011-12-12 15:29:51 +00:00
}
2012-03-22 11:33:24 +00:00
}
} else {
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (v1[i] - v2[i]);
if (diff > max) {
max = diff;
2011-12-12 15:29:51 +00:00
}
2012-03-22 11:33:24 +00:00
}
}
return max;
}
2012-03-31 23:27:37 +01:00
double
pow (double p, unsigned expoent)
{
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
2011-12-12 15:29:51 +00:00
}
2012-03-31 23:27:37 +01:00
double
pow (double p, double expoent)
{
2012-03-31 23:27:37 +01:00
// assumes that `expoent' is never in log domain
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
2012-03-31 23:27:37 +01:00
void
pow (Params& v, unsigned expoent)
{
2012-03-31 23:27:37 +01:00
if (expoent == 1) {
return;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
2012-03-31 23:27:37 +01:00
void
pow (Params& v, double expoent)
2011-12-12 15:29:51 +00:00
{
2012-03-31 23:27:37 +01:00
// assumes that `expoent' is never in log domain
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
2011-12-12 15:29:51 +00:00
}
}
2011-12-12 15:29:51 +00:00
}
2011-12-12 15:29:51 +00:00
}
// Number of networks recorded so far via updateStatistics().
unsigned
Statistics::getSolvedNetworksCounting (void)
{
  const unsigned nSolved = static_cast<unsigned> (netInfo_.size());
  return nSolved;
}
// Bumps the primary-network counter by one.
void
Statistics::incrementPrimaryNetworksCounting (void)
{
  ++ primaryNetCount_;
}
// Current value of the primary-network counter.
unsigned
Statistics::getPrimaryNetworksCounting (void)
{
  return Statistics::primaryNetCount_;
}
void
2012-03-31 23:27:37 +01:00
Statistics::updateStatistics (
unsigned size,
bool loopy,
unsigned nIters,
double time)
2011-12-12 15:29:51 +00:00
{
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
}
void
Statistics::printStatistics (void)
{
cout << getStatisticString();
}
void
Statistics::writeStatistics (const char* fileName)
2011-12-12 15:29:51 +00:00
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Statistics::writeStats()" << endl;
2011-12-12 15:29:51 +00:00
abort();
}
out << getStatisticString();
out.close();
}
void
2012-03-31 23:27:37 +01:00
Statistics::updateCompressingStatistics (
unsigned nrGroundVars,
unsigned nrGroundFactors,
unsigned nrClusterVars,
unsigned nrClusterFactors,
unsigned nrNeighborless) {
compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
nrClusterVars, nrClusterFactors, nrNeighborless));
2011-12-12 15:29:51 +00:00
}
// Builds the full statistics report. Sections, in output order:
//   1. run configuration (algorithm, schedule, iteration cap, accuracy);
//   2. summary counters and compression ratios;
//   3. per-network table;
//   4. per-run compression table (only if any compression was recorded).
string
Statistics::getStatisticString (void)
{
  stringstream config, netTable, compressTable, summary;

  // --- 1. run configuration -------------------------------------------
  config << "running mode: " ;
  switch (Globals::infAlgorithm) {
    case InfAlgorithms::VE:  config << "ve"  << endl; break;
    case InfAlgorithms::BP:  config << "bp"  << endl; break;
    case InfAlgorithms::CBP: config << "cbp" << endl; break;
  }
  config << "message schedule: " ;
  switch (BpOptions::schedule) {
    case BpOptions::Schedule::SEQ_FIXED:
      config << "sequential fixed" << endl;
      break;
    case BpOptions::Schedule::SEQ_RANDOM:
      config << "sequential random" << endl;
      break;
    case BpOptions::Schedule::PARALLEL:
      config << "parallel" << endl;
      break;
    case BpOptions::Schedule::MAX_RESIDUAL:
      config << "max residual" << endl;
      break;
  }
  config << "max iterations: " << BpOptions::maxIter << endl;
  config << "accuracy: " << BpOptions::accuracy << endl;
  config << endl << endl;

  // --- 3. per-network table (also accumulates summary counters) -------
  Util::printSubHeader ("Network information", netTable);
  netTable << left;
  netTable << setw (15) << "Network Size" ;
  netTable << setw (9)  << "Loopy" ;
  netTable << setw (15) << "Iterations" ;
  netTable << setw (15) << "Solving Time" ;
  netTable << endl;
  unsigned nLoopyNets = 0;
  unsigned nUnconvergedRuns = 0;
  double totalSolvingTime = 0.0;
  for (unsigned i = 0; i < netInfo_.size(); i++) {
    const NetInfo& info = netInfo_[i];
    netTable << setw (15) << info.size;
    if (info.loopy) {
      netTable << setw (9) << "yes";
      nLoopyNets ++;
    } else {
      netTable << setw (9) << "no";
    }
    if (info.nIters == 0) {
      netTable << setw (15) << "n/a" ;
    } else {
      netTable << setw (15) << info.nIters;
      // NOTE(review): a run counts as unconverged only when it used MORE
      // than maxIter iterations; if the solvers stop at exactly maxIter
      // this never fires -- confirm against the BP iteration loop.
      if (info.nIters > BpOptions::maxIter) {
        nUnconvergedRuns ++;
      }
    }
    netTable << setw (15) << info.time;
    totalSolvingTime += info.time;
    netTable << endl;
  }
  netTable << endl << endl;

  // --- 4. compression table (also accumulates compression totals) -----
  // Neighborless variables/factors are excluded from the ground totals,
  // and one cluster is discounted per run that had any of them.
  unsigned groundVars = 0, clusterVars = 0;
  unsigned groundFactors = 0, clusterFactors = 0;
  if (!compressInfo_.empty()) {
    Util::printSubHeader ("Compress information", compressTable);
    compressTable << left;
    compressTable << "Ground Cluster Ground Cluster Neighborless" << endl;
    compressTable << "Vars Vars Factors Factors Vars" << endl;
    for (unsigned i = 0; i < compressInfo_.size(); i++) {
      const CompressInfo& ci = compressInfo_[i];
      compressTable << setw (9)  << ci.nrGroundVars;
      compressTable << setw (10) << ci.nrClusterVars;
      compressTable << setw (10) << ci.nrGroundFactors;
      compressTable << setw (10) << ci.nrClusterFactors;
      compressTable << setw (10) << ci.nrNeighborless;
      compressTable << endl;
      groundVars     += ci.nrGroundVars - ci.nrNeighborless;
      clusterVars    += ci.nrClusterVars;
      groundFactors  += ci.nrGroundFactors - ci.nrNeighborless;
      clusterFactors += ci.nrClusterFactors;
      if (ci.nrNeighborless != 0) {
        // Guard against unsigned wrap-around on degenerate inputs.
        if (clusterVars > 0)    clusterVars --;
        if (clusterFactors > 0) clusterFactors --;
      }
    }
    compressTable << endl << endl;
  }

  // --- 2. summary ------------------------------------------------------
  summary << "primary networks: " << primaryNetCount_ << endl;
  summary << "solved networks: " << netInfo_.size() << endl;
  summary << "loopy networks: " << nLoopyNets << endl;
  summary << "unconverged runs: " << nUnconvergedRuns << endl;
  summary << "total solving time: " << totalSolvingTime << endl;
  if (!compressInfo_.empty()) {
    summary << setprecision (5);
    // A zero ground total (everything neighborless) would divide by zero
    // and print inf/nan; report "n/a" instead.
    summary << "variable compression: " ;
    if (groundVars != 0) {
      summary << (1.0 - (clusterVars / static_cast<double> (groundVars)))
          * 100.0 << "%";
    } else {
      summary << "n/a";
    }
    summary << endl;
    summary << "factor compression: " ;
    if (groundFactors != 0) {
      summary << (1.0 - (clusterFactors / static_cast<double> (groundFactors)))
          * 100.0 << "%";
    } else {
      summary << "n/a";
    }
    summary << endl;
  }
  summary << endl << endl;

  config << summary.str() << netTable.str() << compressTable.str();
  return config.str();
}