minor improvements

This commit is contained in:
Tiago Gomes 2012-06-19 15:29:09 +01:00
parent ee5b8e693a
commit 57edd6adb9
4 changed files with 45 additions and 90 deletions

View File

@ -15,7 +15,7 @@
cpp_run_ground_solver/3,
cpp_set_vars_information/2,
cpp_set_horus_flag/2,
cpp_free_parfactors/1,
cpp_free_lifted_network/1,
cpp_free_ground_network/1
]).

View File

@ -17,7 +17,7 @@
[cpp_create_lifted_network/3,
cpp_set_parfactors_params/2,
cpp_run_lifted_solver/3,
cpp_free_parfactors/1
cpp_free_lifted_network/1
]).
:- use_module(library('clpbn/display'),
@ -144,5 +144,5 @@ run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds))
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
cpp_free_parfactors(ParfactorList).
cpp_free_lifted_network(ParfactorList).

View File

@ -9,12 +9,8 @@
#include "Var.h"
#include "Horus.h"
using namespace std;
class Var;
class BBNode : public Var
{
public:

View File

@ -20,10 +20,8 @@
using namespace std;
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
Params readParameters (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term);
@ -32,14 +30,6 @@ void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term);
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBeliefProp (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
vector<unsigned>
readUnsignedList (YAP_Term list)
@ -54,7 +44,8 @@ readUnsignedList (YAP_Term list)
int createLiftedNetwork (void)
int
createLiftedNetwork (void)
{
Parfactors parfactors;
YAP_Term parfactorList = YAP_ARG1;
@ -91,7 +82,8 @@ int createLiftedNetwork (void)
Parfactor* readParfactor (YAP_Term pfTerm)
Parfactor*
readParfactor (YAP_Term pfTerm)
{
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
@ -171,7 +163,8 @@ Parfactor* readParfactor (YAP_Term pfTerm)
void readLiftedEvidence (
void
readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{
@ -360,11 +353,42 @@ runGroundSolver (void)
taskList = YAP_TailOfTerm (taskList);
}
vector<Params> results;
std::set<VarId> vids;
for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(), vids.end()));
}
if (Globals::groundSolver == GroundSolver::VE) {
runVeSolver (fg, tasks, results);
solver = new VarElim (*mfg);
} else if (Globals::groundSolver == GroundSolver::BP) {
solver = new BeliefProp (*mfg);
} else if (Globals::groundSolver == GroundSolver::CBP) {
CountingBp::checkForIdenticalFactors = false;
solver = new CountingBp (*mfg);
} else {
runBeliefProp (fg, tasks, results);
assert (false);
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
vector<Params> results;
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
}
delete solver;
if (fg->bayesianFactors()) {
delete mfg;
}
YAP_Term list = YAP_TermNil();
@ -385,71 +409,6 @@ runGroundSolver (void)
void runVeSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
}
VarElim solver (*mfg);
if (Globals::verbosity > 0 && i == 0) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (tasks[i]));
if (fg->bayesianFactors()) {
delete mfg;
}
}
}
void runBeliefProp (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
std::set<VarId> vids;
for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(),vids.end()));
}
if (Globals::groundSolver == GroundSolver::BP) {
solver = new BeliefProp (*mfg);
} else if (Globals::groundSolver == GroundSolver::CBP) {
CountingBp::checkForIdenticalFactors = false;
solver = new CountingBp (*mfg);
} else {
cerr << "error: unknow solver" << endl;
abort();
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
}
if (fg->bayesianFactors()) {
delete mfg;
}
delete solver;
}
int
setParfactorsParams (void)
{
@ -565,7 +524,7 @@ freeGroundNetwork (void)
int
freeParfactors (void)
freeLiftedNetwork (void)
{
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
delete network->first;
@ -587,7 +546,7 @@ init_predicates (void)
YAP_UserCPredicate ("cpp_cpp_set_factors_params", setFactorsParams, 2);
YAP_UserCPredicate ("cpp_set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("cpp_set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("cpp_free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("cpp_free_lifted_network", freeLiftedNetwork, 1);
YAP_UserCPredicate ("cpp_free_ground_network", freeGroundNetwork, 1);
}