use only 1 lifted solver instance

This commit is contained in:
Tiago Gomes 2012-11-16 16:50:19 +00:00
parent c2791748d2
commit 59fd21bf33

View File

@ -276,13 +276,24 @@ readParameters (YAP_Term paramL)
int int
runLiftedSolver (void) runLiftedSolver (void)
{ {
// TODO one solver instatiation should be used
// to solve several inference tasks
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1); LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term taskList = YAP_ARG2;
vector<Params> results;
ParfactorList pfListCopy (*network->first); ParfactorList pfListCopy (*network->first);
LiftedOperations::absorveEvidence (pfListCopy, *network->second); LiftedOperations::absorveEvidence (pfListCopy, *network->second);
LiftedSolver* solver = 0;
switch (Globals::liftedSolver) {
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
YAP_Term taskList = YAP_ARG2;
vector<Params> results;
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
Grounds queryVars; Grounds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList); YAP_Term jointList = YAP_HeadOfTerm (taskList);
@ -308,33 +319,12 @@ runLiftedSolver (void)
} }
jointList = YAP_TailOfTerm (jointList); jointList = YAP_TailOfTerm (jointList);
} }
if (Globals::liftedSolver == LiftedSolverType::LVE) { results.push_back (solver->solveQuery (queryVars));
LiftedVe solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolverType::LBP) {
LiftedBp solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolverType::LKC) {
LiftedKc solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else {
assert (false);
}
taskList = YAP_TailOfTerm (taskList); taskList = YAP_TailOfTerm (taskList);
} }
delete solver;
YAP_Term list = YAP_TermNil(); YAP_Term list = YAP_TermNil();
for (size_t i = results.size(); i-- > 0; ) { for (size_t i = results.size(); i-- > 0; ) {
const Params& beliefs = results[i]; const Params& beliefs = results[i];
@ -358,6 +348,7 @@ int
runGroundSolver (void) runGroundSolver (void)
{ {
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1); FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
vector<VarIds> tasks; vector<VarIds> tasks;
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
@ -369,22 +360,19 @@ runGroundSolver (void)
for (size_t i = 0; i < tasks.size(); i++) { for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]); Util::addToSet (vids, tasks[i]);
} }
GroundSolver* solver = 0;
FactorGraph* mfg = fg; FactorGraph* mfg = fg;
if (fg->bayesianFactors()) { if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph ( mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(), vids.end())); *fg, VarIds (vids.begin(), vids.end()));
} }
if (Globals::groundSolver == GroundSolverType::VE) { GroundSolver* solver = 0;
solver = new VarElim (*mfg);
} else if (Globals::groundSolver == GroundSolverType::BP) {
solver = new BeliefProp (*mfg);
} else if (Globals::groundSolver == GroundSolverType::CBP) {
CountingBp::checkForIdenticalFactors = false; CountingBp::checkForIdenticalFactors = false;
solver = new CountingBp (*mfg); switch (Globals::groundSolver) {
} else { case GroundSolverType::VE: solver = new VarElim (*mfg); break;
assert (false); case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {