diff --git a/packages/CLPBN/horus/HorusYap.cpp b/packages/CLPBN/horus/HorusYap.cpp index 2fa0008fb..b82cf1d76 100644 --- a/packages/CLPBN/horus/HorusYap.cpp +++ b/packages/CLPBN/horus/HorusYap.cpp @@ -276,13 +276,24 @@ readParameters (YAP_Term paramL) int runLiftedSolver (void) { - // TODO one solver instatiation should be used - // to solve several inference tasks LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1); - YAP_Term taskList = YAP_ARG2; - vector results; ParfactorList pfListCopy (*network->first); LiftedOperations::absorveEvidence (pfListCopy, *network->second); + + LiftedSolver* solver = 0; + switch (Globals::liftedSolver) { + case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break; + case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break; + case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break; + } + + if (Globals::verbosity > 0) { + solver->printSolverFlags(); + cout << endl; + } + + YAP_Term taskList = YAP_ARG2; + vector results; while (taskList != YAP_TermNil()) { Grounds queryVars; YAP_Term jointList = YAP_HeadOfTerm (taskList); @@ -308,33 +319,12 @@ runLiftedSolver (void) } jointList = YAP_TailOfTerm (jointList); } - if (Globals::liftedSolver == LiftedSolverType::LVE) { - LiftedVe solver (pfListCopy); - if (Globals::verbosity > 0 && taskList == YAP_ARG2) { - solver.printSolverFlags(); - cout << endl; - } - results.push_back (solver.solveQuery (queryVars)); - } else if (Globals::liftedSolver == LiftedSolverType::LBP) { - LiftedBp solver (pfListCopy); - if (Globals::verbosity > 0 && taskList == YAP_ARG2) { - solver.printSolverFlags(); - cout << endl; - } - results.push_back (solver.solveQuery (queryVars)); - } else if (Globals::liftedSolver == LiftedSolverType::LKC) { - LiftedKc solver (pfListCopy); - if (Globals::verbosity > 0 && taskList == YAP_ARG2) { - solver.printSolverFlags(); - cout << endl; - } - results.push_back (solver.solveQuery (queryVars)); - } else { - assert (false); - } + results.push_back (solver->solveQuery (queryVars)); taskList = YAP_TailOfTerm (taskList); } + delete solver; + YAP_Term list = YAP_TermNil(); for (size_t i = results.size(); i-- > 0; ) { const Params& beliefs = results[i]; @@ -358,6 +348,7 @@ int runGroundSolver (void) { FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1); + vector tasks; YAP_Term taskList = YAP_ARG2; while (taskList != YAP_TermNil()) { @@ -369,22 +360,19 @@ runGroundSolver (void) for (size_t i = 0; i < tasks.size(); i++) { Util::addToSet (vids, tasks[i]); } - GroundSolver* solver = 0; + FactorGraph* mfg = fg; if (fg->bayesianFactors()) { mfg = BayesBall::getMinimalFactorGraph ( *fg, VarIds (vids.begin(), vids.end())); } - if (Globals::groundSolver == GroundSolverType::VE) { - solver = new VarElim (*mfg); - } else if (Globals::groundSolver == GroundSolverType::BP) { - solver = new BeliefProp (*mfg); - } else if (Globals::groundSolver == GroundSolverType::CBP) { - CountingBp::checkForIdenticalFactors = false; - solver = new CountingBp (*mfg); - } else { - assert (false); + GroundSolver* solver = 0; + CountingBp::checkForIdenticalFactors = false; + switch (Globals::groundSolver) { + case GroundSolverType::VE: solver = new VarElim (*mfg); break; + case GroundSolverType::BP: solver = new BeliefProp (*mfg); break; + case GroundSolverType::CBP: solver = new CountingBp (*mfg); break; } if (Globals::verbosity > 0) {