support setting flags in horus cli and fix a bug when ordering the variables
This commit is contained in:
parent
085ebe1e96
commit
8c689665a0
@ -24,6 +24,12 @@ CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
|||||||
Statistics::updateCompressingStatistics (nrGroundVars,
|
Statistics::updateCompressingStatistics (nrGroundVars,
|
||||||
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
|
nrGroundFacs, nrClusterVars, nrClusterFacs, nrNeighborless);
|
||||||
}
|
}
|
||||||
|
// cout << "uncompressed factor graph:" << endl;
|
||||||
|
// cout << " " << fg.nrVarNodes() << " variables " << endl;
|
||||||
|
// cout << " " << fg.nrFacNodes() << " factors " << endl;
|
||||||
|
// cout << "compressed factor graph:" << endl;
|
||||||
|
// cout << " " << fg_->nrVarNodes() << " variables " << endl;
|
||||||
|
// cout << " " << fg_->nrFacNodes() << " factors " << endl;
|
||||||
// Util::printHeader ("Compressed Factor Graph");
|
// Util::printHeader ("Compressed Factor Graph");
|
||||||
// fg_->print();
|
// fg_->print();
|
||||||
// Util::printHeader ("Uncompressed Factor Graph");
|
// Util::printHeader ("Uncompressed Factor Graph");
|
||||||
@ -60,7 +66,7 @@ CbpSolver::printSolverFlags (void) const
|
|||||||
ss << ",max_iter=" << BpOptions::maxIter;
|
ss << ",max_iter=" << BpOptions::maxIter;
|
||||||
ss << ",accuracy=" << BpOptions::accuracy;
|
ss << ",accuracy=" << BpOptions::accuracy;
|
||||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
ss << ",order_vars=" << Util::toString (FactorGraph::orderVariables);
|
ss << ",order_vars=" << Util::toString (FactorGraph::orderVars);
|
||||||
ss << ",chkif=" <<
|
ss << ",chkif=" <<
|
||||||
Util::toString (CFactorGraph::checkForIdenticalFactors);
|
Util::toString (CFactorGraph::checkForIdenticalFactors);
|
||||||
ss << "]" ;
|
ss << "]" ;
|
||||||
@ -69,7 +75,6 @@ CbpSolver::printSolverFlags (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
CbpSolver::getPosterioriOf (VarId vid)
|
CbpSolver::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
|
|
||||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
ElimHeuristic ElimGraph::elimHeuristic = MIN_NEIGHBORS;
|
||||||
|
|
||||||
|
|
||||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||||
@ -174,7 +174,7 @@ ElimGraph::getLowestCostNode (void) const
|
|||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
if (marked_[i]) continue;
|
if (marked_[i]) continue;
|
||||||
unsigned cost = 0;
|
unsigned cost = 0;
|
||||||
switch (elimHeuristic_) {
|
switch (elimHeuristic) {
|
||||||
case MIN_NEIGHBORS:
|
case MIN_NEIGHBORS:
|
||||||
cost = getNeighborsCost (nodes_[i]);
|
cost = getNeighborsCost (nodes_[i]);
|
||||||
break;
|
break;
|
||||||
|
@ -47,15 +47,7 @@ class ElimGraph
|
|||||||
|
|
||||||
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
|
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
|
||||||
|
|
||||||
static void setEliminationHeuristic (ElimHeuristic h)
|
static ElimHeuristic elimHeuristic;
|
||||||
{
|
|
||||||
elimHeuristic_ = h;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ElimHeuristic getEliminationHeuristic (void)
|
|
||||||
{
|
|
||||||
return elimHeuristic_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
@ -86,7 +78,6 @@ class ElimGraph
|
|||||||
vector<EgNode*> nodes_;
|
vector<EgNode*> nodes_;
|
||||||
vector<bool> marked_;
|
vector<bool> marked_;
|
||||||
unordered_map<VarId, EgNode*> varMap_;
|
unordered_map<VarId, EgNode*> varMap_;
|
||||||
static ElimHeuristic elimHeuristic_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_ELIMGRAPH_H
|
#endif // HORUS_ELIMGRAPH_H
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
bool FactorGraph::orderVariables = false;
|
bool FactorGraph::orderVars = false;
|
||||||
|
|
||||||
|
|
||||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||||
@ -186,14 +186,17 @@ void
|
|||||||
FactorGraph::addFactor (const Factor& factor)
|
FactorGraph::addFactor (const Factor& factor)
|
||||||
{
|
{
|
||||||
FacNode* fn = new FacNode (factor);
|
FacNode* fn = new FacNode (factor);
|
||||||
|
if (orderVars) {
|
||||||
|
fn->factor().reorderAccordingVarIds();
|
||||||
|
}
|
||||||
addFacNode (fn);
|
addFacNode (fn);
|
||||||
const VarIds& vids = factor.arguments();
|
const VarIds& vids = fn->factor().arguments();
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
VarMap::const_iterator it = varMap_.find (vids[i]);
|
VarMap::const_iterator it = varMap_.find (vids[i]);
|
||||||
if (it != varMap_.end()) {
|
if (it != varMap_.end()) {
|
||||||
addEdge (it->second, fn);
|
addEdge (it->second, fn);
|
||||||
} else {
|
} else {
|
||||||
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
VarNode* vn = new VarNode (vids[i], fn->factor().range (i));
|
||||||
addVarNode (vn);
|
addVarNode (vn);
|
||||||
addEdge (vn, fn);
|
addEdge (vn, fn);
|
||||||
}
|
}
|
||||||
@ -261,6 +264,7 @@ FactorGraph::getStructure (void)
|
|||||||
void
|
void
|
||||||
FactorGraph::print (void) const
|
FactorGraph::print (void) const
|
||||||
{
|
{
|
||||||
|
/*
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
cout << "var id = " << varNodes_[i]->varId() << endl;
|
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||||
cout << "label = " << varNodes_[i]->label() << endl;
|
cout << "label = " << varNodes_[i]->label() << endl;
|
||||||
@ -272,6 +276,7 @@ FactorGraph::print (void) const
|
|||||||
}
|
}
|
||||||
cout << endl << endl;
|
cout << endl << endl;
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
facNodes_[i]->factor().print();
|
facNodes_[i]->factor().print();
|
||||||
}
|
}
|
||||||
|
@ -117,7 +117,7 @@ class FactorGraph
|
|||||||
|
|
||||||
void exportToLibDaiFormat (const char*) const;
|
void exportToLibDaiFormat (const char*) const;
|
||||||
|
|
||||||
static bool orderVariables;
|
static bool orderVars;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||||
|
@ -10,43 +10,68 @@
|
|||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
void processArguments (FactorGraph&, int, const char* []);
|
int readHorusFlags (int, const char* []);
|
||||||
|
void readFactorGraph (FactorGraph&, const char*);
|
||||||
|
VarIds readQueryAndEvidence (FactorGraph&, int, const char* [], int);
|
||||||
|
|
||||||
void runSolver (const FactorGraph&, const VarIds&);
|
void runSolver (const FactorGraph&, const VarIds&);
|
||||||
|
|
||||||
const string USAGE = "usage: \
|
const string USAGE = "usage: ./hcli [HORUS_FLAG=VALUE] \
|
||||||
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE] ..." ;
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
main (int argc, const char* argv[])
|
main (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
if (argc <= 1) {
|
if (argc <= 1) {
|
||||||
cerr << "error: no solver specified" << endl;
|
|
||||||
cerr << "error: no graphical model specified" << endl;
|
cerr << "error: no graphical model specified" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
if (argc <= 2) {
|
int idx = readHorusFlags (argc, argv);
|
||||||
cerr << "error: no graphical model specified" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
string solver (argv[1]);
|
|
||||||
if (solver == "ve") {
|
|
||||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
|
||||||
} else if (solver == "bp") {
|
|
||||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
|
||||||
} else if (solver == "cbp") {
|
|
||||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
|
||||||
} else {
|
|
||||||
cerr << "error: unknow solver `" << solver << "'" << endl ;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
string fileName (argv[2]);
|
|
||||||
string extension = fileName.substr (
|
|
||||||
fileName.find_last_of ('.') + 1);
|
|
||||||
FactorGraph fg;
|
FactorGraph fg;
|
||||||
|
readFactorGraph (fg, argv[idx]);
|
||||||
|
VarIds queryIds = readQueryAndEvidence (fg, argc, argv, idx + 1);
|
||||||
|
runSolver (fg, queryIds);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
readHorusFlags (int argc, const char* argv[])
|
||||||
|
{
|
||||||
|
int i = 1;
|
||||||
|
for (; i < argc; i++) {
|
||||||
|
const string& arg = argv[i];
|
||||||
|
size_t pos = arg.find ('=');
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
string leftArg = arg.substr (0, pos);
|
||||||
|
string rightArg = arg.substr (pos + 1);
|
||||||
|
if (leftArg.empty()) {
|
||||||
|
cerr << "error: missing left argument" << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
if (rightArg.empty()) {
|
||||||
|
cerr << "error: missing right argument" << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
Util::setHorusFlag (leftArg, rightArg);
|
||||||
|
}
|
||||||
|
return i + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
readFactorGraph (FactorGraph& fg, const char* s)
|
||||||
|
{
|
||||||
|
string fileName (s);
|
||||||
|
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||||
if (extension == "uai") {
|
if (extension == "uai") {
|
||||||
fg.readFromUaiFormat (fileName.c_str());
|
fg.readFromUaiFormat (fileName.c_str());
|
||||||
} else if (extension == "fg") {
|
} else if (extension == "fg") {
|
||||||
@ -56,90 +81,78 @@ main (int argc, const char* argv[])
|
|||||||
cerr << "in a UAI or libDAI file" << endl;
|
cerr << "in a UAI or libDAI file" << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
processArguments (fg, argc, argv);
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
VarIds
|
||||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
readQueryAndEvidence (
|
||||||
|
FactorGraph& fg,
|
||||||
|
int argc,
|
||||||
|
const char* argv[],
|
||||||
|
int start)
|
||||||
{
|
{
|
||||||
VarIds queryIds;
|
VarIds queryIds;
|
||||||
for (int i = 3; i < argc; i++) {
|
for (int i = start; i < argc; i++) {
|
||||||
const string& arg = argv[i];
|
const string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
if (!Util::isInteger (arg)) {
|
if (Util::isInteger (arg) == false) {
|
||||||
cerr << "error: `" << arg << "' " ;
|
cerr << "error: `" << arg << "' " ;
|
||||||
cerr << "is not a valid variable id" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
VarId vid;
|
|
||||||
stringstream ss;
|
|
||||||
ss << arg;
|
|
||||||
ss >> vid;
|
|
||||||
VarNode* queryVar = fg.getVarNode (vid);
|
|
||||||
if (queryVar) {
|
|
||||||
queryIds.push_back (vid);
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable with " ;
|
|
||||||
cerr << "`" << vid << "' as id" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
size_t pos = arg.find ('=');
|
|
||||||
if (arg.substr (0, pos).empty()) {
|
|
||||||
cerr << "error: missing left argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
if (arg.substr (pos + 1).empty()) {
|
|
||||||
cerr << "error: missing right argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
if (!Util::isInteger (arg.substr (0, pos))) {
|
|
||||||
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
|
||||||
cerr << "is not a variable id" ;
|
cerr << "is not a variable id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
VarId vid;
|
VarId vid = Util::stringToUnsigned (arg);
|
||||||
stringstream ss;
|
VarNode* queryVar = fg.getVarNode (vid);
|
||||||
ss << arg.substr (0, pos);
|
if (queryVar == false) {
|
||||||
ss >> vid;
|
cerr << "error: unknow variable with id " ;
|
||||||
VarNode* var = fg.getVarNode (vid);
|
cerr << "`" << vid << "'" << endl;
|
||||||
if (var) {
|
|
||||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
|
||||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
|
||||||
cerr << "is not a state index" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
int stateIndex;
|
|
||||||
stringstream ss;
|
|
||||||
ss << arg.substr (pos + 1);
|
|
||||||
ss >> stateIndex;
|
|
||||||
if (var->isValidState (stateIndex)) {
|
|
||||||
var->setEvidence (stateIndex);
|
|
||||||
} else {
|
|
||||||
cerr << "error: `" << stateIndex << "' " ;
|
|
||||||
cerr << "is not a valid state index for variable " ;
|
|
||||||
cerr << "`" << var->varId() << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable with " ;
|
|
||||||
cerr << "`" << vid << "' as id" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
|
queryIds.push_back (vid);
|
||||||
|
} else {
|
||||||
|
size_t pos = arg.find ('=');
|
||||||
|
string leftArg = arg.substr (0, pos);
|
||||||
|
string rightArg = arg.substr (pos + 1);
|
||||||
|
if (leftArg.empty()) {
|
||||||
|
cerr << "error: missing left argument" << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
if (Util::isInteger (leftArg) == false) {
|
||||||
|
cerr << "error: `" << leftArg << "' " ;
|
||||||
|
cerr << "is not a variable id" << endl ;
|
||||||
|
exit (0);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
VarId vid = Util::stringToUnsigned (leftArg);
|
||||||
|
VarNode* observedVar = fg.getVarNode (vid);
|
||||||
|
if (observedVar == false) {
|
||||||
|
cerr << "error: unknow variable with id " ;
|
||||||
|
cerr << "`" << vid << "'" << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
if (rightArg.empty()) {
|
||||||
|
cerr << "error: missing right argument" << endl;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
if (Util::isInteger (rightArg) == false) {
|
||||||
|
cerr << "error: `" << rightArg << "' " ;
|
||||||
|
cerr << "is not a state index" << endl ;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
unsigned stateIdx = Util::stringToUnsigned (rightArg);
|
||||||
|
if (observedVar->isValidState (stateIdx) == false) {
|
||||||
|
cerr << "error: `" << stateIdx << "' " ;
|
||||||
|
cerr << "is not a valid state index for variable with id " ;
|
||||||
|
cerr << "`" << vid << "'" << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
observedVar->setEvidence (stateIdx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
runSolver (fg, queryIds);
|
return queryIds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -161,6 +174,8 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
|||||||
default:
|
default:
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
|
solver->printSolverFlags();
|
||||||
|
cout << endl;
|
||||||
if (queryIds.size() == 0) {
|
if (queryIds.size() == 0) {
|
||||||
solver->printAllPosterioris();
|
solver->printAllPosterioris();
|
||||||
} else {
|
} else {
|
||||||
|
@ -67,12 +67,12 @@ int createLiftedNetwork (void)
|
|||||||
if (Constants::DEBUG > 2) {
|
if (Constants::DEBUG > 2) {
|
||||||
Util::printHeader ("INITIAL PARFACTORS");
|
Util::printHeader ("INITIAL PARFACTORS");
|
||||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||||
parfactors[i]->print();
|
parfactors[i]->print();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ParfactorList* pfList = new ParfactorList (parfactors);
|
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||||
|
|
||||||
if (Constants::DEBUG >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
Util::printHeader ("SHATTERED PARFACTORS");
|
Util::printHeader ("SHATTERED PARFACTORS");
|
||||||
pfList->print();
|
pfList->print();
|
||||||
@ -402,10 +402,10 @@ void runBpSolver (
|
|||||||
*fg, VarIds (vids.begin(),vids.end()));
|
*fg, VarIds (vids.begin(),vids.end()));
|
||||||
}
|
}
|
||||||
if (Globals::infAlgorithm == InfAlgorithms::BP) {
|
if (Globals::infAlgorithm == InfAlgorithms::BP) {
|
||||||
solver = new BpSolver (*mfg);
|
solver = new BpSolver (*fg); // FIXME
|
||||||
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||||
CFactorGraph::checkForIdenticalFactors = false;
|
CFactorGraph::checkForIdenticalFactors = false;
|
||||||
solver = new CbpSolver (*mfg);
|
solver = new CbpSolver (*fg); // FIXME
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: unknow solver" << endl;
|
cerr << "error: unknow solver" << endl;
|
||||||
abort();
|
abort();
|
||||||
@ -507,80 +507,19 @@ int
|
|||||||
setHorusFlag (void)
|
setHorusFlag (void)
|
||||||
{
|
{
|
||||||
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||||
if (key == "inf_alg") {
|
string value;
|
||||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
if (key == "accuracy") {
|
||||||
if ( value == "ve") {
|
stringstream ss;
|
||||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
|
||||||
} else if (value == "bp") {
|
ss >> value;
|
||||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
|
||||||
} else if (value == "cbp") {
|
|
||||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
|
||||||
} else {
|
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
|
||||||
cerr << "for `" << key << "'" << endl;
|
|
||||||
return FALSE;
|
|
||||||
}
|
|
||||||
} else if (key == "elim_heuristic") {
|
|
||||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
|
||||||
if ( value == "min_neighbors") {
|
|
||||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_NEIGHBORS);
|
|
||||||
} else if (value == "min_weight") {
|
|
||||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_WEIGHT);
|
|
||||||
} else if (value == "min_fill") {
|
|
||||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_FILL);
|
|
||||||
} else if (value == "weighted_min_fill") {
|
|
||||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
|
|
||||||
} else {
|
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
|
||||||
cerr << "for `" << key << "'" << endl;
|
|
||||||
return FALSE;
|
|
||||||
}
|
|
||||||
} else if (key == "schedule") {
|
|
||||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
|
||||||
if ( value == "seq_fixed") {
|
|
||||||
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
|
|
||||||
} else if (value == "seq_random") {
|
|
||||||
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
|
|
||||||
} else if (value == "parallel") {
|
|
||||||
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
|
|
||||||
} else if (value == "max_residual") {
|
|
||||||
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
|
||||||
} else {
|
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
|
||||||
cerr << "for `" << key << "'" << endl;
|
|
||||||
return FALSE;
|
|
||||||
}
|
|
||||||
} else if (key == "accuracy") {
|
|
||||||
BpOptions::accuracy = (double) YAP_FloatOfTerm (YAP_ARG2);
|
|
||||||
} else if (key == "max_iter") {
|
} else if (key == "max_iter") {
|
||||||
BpOptions::maxIter = (int) YAP_IntOfTerm (YAP_ARG2);
|
stringstream ss;
|
||||||
} else if (key == "use_logarithms") {
|
ss << (int) YAP_IntOfTerm (YAP_ARG2);
|
||||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
ss >> value;
|
||||||
if ( value == "true") {
|
|
||||||
Globals::logDomain = true;
|
|
||||||
} else if (value == "false") {
|
|
||||||
Globals::logDomain = false;
|
|
||||||
} else {
|
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
|
||||||
cerr << "for `" << key << "'" << endl;
|
|
||||||
return FALSE;
|
|
||||||
}
|
|
||||||
} else if (key == "order_variables") {
|
|
||||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
|
||||||
if ( value == "true") {
|
|
||||||
FactorGraph::orderVariables = true;
|
|
||||||
} else if (value == "false") {
|
|
||||||
FactorGraph::orderVariables = false;
|
|
||||||
} else {
|
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
|
||||||
cerr << "for `" << key << "'" << endl;
|
|
||||||
return FALSE;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
cerr << "warning: invalid key `" << key << "'" << endl;
|
value = ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||||
return FALSE;
|
|
||||||
}
|
}
|
||||||
return TRUE;
|
return Util::setHorusFlag (key, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
|
#include "ElimGraph.h"
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
namespace Globals {
|
||||||
@ -48,6 +49,34 @@ toString (const bool& b)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
stringToUnsigned (string str)
|
||||||
|
{
|
||||||
|
int val;
|
||||||
|
stringstream ss;
|
||||||
|
ss << str;
|
||||||
|
ss >> val;
|
||||||
|
if (val < 0) {
|
||||||
|
cerr << "error: the value tried to read is negative" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
return static_cast<unsigned> (val);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
stringToDouble (string str)
|
||||||
|
{
|
||||||
|
double val;
|
||||||
|
stringstream ss;
|
||||||
|
ss << str;
|
||||||
|
ss >> val;
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
toLog (Params& v)
|
toLog (Params& v)
|
||||||
{
|
{
|
||||||
@ -194,6 +223,87 @@ getStateLines (const Vars& vars)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
setHorusFlag (string key, string value)
|
||||||
|
{
|
||||||
|
bool returnVal = true;
|
||||||
|
if (key == "inf_alg") {
|
||||||
|
if ( value == "ve") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||||
|
} else if (value == "bp") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||||
|
} else if (value == "cbp") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else if (key == "elim_heuristic") {
|
||||||
|
if ( value == "min_neighbors") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::MIN_NEIGHBORS;
|
||||||
|
} else if (value == "min_weight") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::MIN_WEIGHT;
|
||||||
|
} else if (value == "min_fill") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::MIN_FILL;
|
||||||
|
} else if (value == "weighted_min_fill") {
|
||||||
|
ElimGraph::elimHeuristic = ElimHeuristic::WEIGHTED_MIN_FILL;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else if (key == "schedule") {
|
||||||
|
if ( value == "seq_fixed") {
|
||||||
|
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||||
|
} else if (value == "seq_random") {
|
||||||
|
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||||
|
} else if (value == "parallel") {
|
||||||
|
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
|
||||||
|
} else if (value == "max_residual") {
|
||||||
|
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else if (key == "accuracy") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << value;
|
||||||
|
ss >> BpOptions::accuracy;
|
||||||
|
} else if (key == "max_iter") {
|
||||||
|
stringstream ss;
|
||||||
|
ss << value;
|
||||||
|
ss >> BpOptions::maxIter;
|
||||||
|
} else if (key == "use_logarithms") {
|
||||||
|
if ( value == "true") {
|
||||||
|
Globals::logDomain = true;
|
||||||
|
} else if (value == "false") {
|
||||||
|
Globals::logDomain = false;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else if (key == "order_vars") {
|
||||||
|
if ( value == "true") {
|
||||||
|
FactorGraph::orderVars = true;
|
||||||
|
} else if (value == "false") {
|
||||||
|
FactorGraph::orderVars = false;
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
|
cerr << "for `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cerr << "warning: invalid key `" << key << "'" << endl;
|
||||||
|
returnVal = false;
|
||||||
|
}
|
||||||
|
return returnVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
printHeader (string header, std::ostream& os)
|
printHeader (string header, std::ostream& os)
|
||||||
{
|
{
|
||||||
|
@ -38,6 +38,10 @@ template <typename T> std::string toString (const T&);
|
|||||||
|
|
||||||
template <> std::string toString (const bool&);
|
template <> std::string toString (const bool&);
|
||||||
|
|
||||||
|
unsigned stringToUnsigned (string);
|
||||||
|
|
||||||
|
double stringToDouble (string);
|
||||||
|
|
||||||
void toLog (Params&);
|
void toLog (Params&);
|
||||||
|
|
||||||
void fromLog (Params&);
|
void fromLog (Params&);
|
||||||
@ -52,6 +56,8 @@ void add (Params&, const Params&);
|
|||||||
|
|
||||||
void add (Params&, const Params&, unsigned);
|
void add (Params&, const Params&, unsigned);
|
||||||
|
|
||||||
|
unsigned maxUnsigned (void);
|
||||||
|
|
||||||
double factorial (unsigned);
|
double factorial (unsigned);
|
||||||
|
|
||||||
double logFactorial (unsigned);
|
double logFactorial (unsigned);
|
||||||
@ -68,6 +74,8 @@ string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
|||||||
|
|
||||||
vector<string> getStateLines (const Vars&);
|
vector<string> getStateLines (const Vars&);
|
||||||
|
|
||||||
|
bool setHorusFlag (string key, string value);
|
||||||
|
|
||||||
void printHeader (string, std::ostream& os = std::cout);
|
void printHeader (string, std::ostream& os = std::cout);
|
||||||
|
|
||||||
void printSubHeader (string, std::ostream& os = std::cout);
|
void printSubHeader (string, std::ostream& os = std::cout);
|
||||||
@ -76,8 +84,6 @@ void printAsteriskLine (std::ostream& os = std::cout);
|
|||||||
|
|
||||||
void printDashedLine (std::ostream& os = std::cout);
|
void printDashedLine (std::ostream& os = std::cout);
|
||||||
|
|
||||||
unsigned maxUnsigned (void);
|
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ VarElimSolver::printSolverFlags (void) const
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "variable elimination [" ;
|
ss << "variable elimination [" ;
|
||||||
ss << "elim_heuristic=" ;
|
ss << "elim_heuristic=" ;
|
||||||
ElimHeuristic eh = ElimGraph::getEliminationHeuristic();
|
ElimHeuristic eh = ElimGraph::elimHeuristic;
|
||||||
switch (eh) {
|
switch (eh) {
|
||||||
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
case MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
||||||
case MIN_WEIGHT: ss << "min_weight"; break;
|
case MIN_WEIGHT: ss << "min_weight"; break;
|
||||||
|
@ -59,8 +59,8 @@ set_solver(S) :- throw(error('unknow solver ', S)).
|
|||||||
|
|
||||||
:- set_horus_flag(max_iter, 1000).
|
:- set_horus_flag(max_iter, 1000).
|
||||||
|
|
||||||
:- set_horus_flag(order_variables, false).
|
:- set_horus_flag(order_vars, false).
|
||||||
%:- set_horus_flag(order_variables, true).
|
%:- set_horus_flag(order_vars, true).
|
||||||
|
|
||||||
:- set_horus_flag(use_logarithms, false).
|
:- set_horus_flag(use_logarithms, false).
|
||||||
% :- set_horus_flag(use_logarithms, true).
|
% :- set_horus_flag(use_logarithms, true).
|
||||||
|
Reference in New Issue
Block a user