This commit is contained in:
Vitor Santos Costa 2012-11-22 12:29:37 +00:00
commit 027632456a
48 changed files with 3049 additions and 615 deletions

View File

@ -49,11 +49,12 @@ Inference Options
PFL supports both ground and lifted inference. The inference algorithm PFL supports both ground and lifted inference. The inference algorithm
can be chosen using the set_solver/1 predicate. The following algorithms can be chosen using the set_solver/1 predicate. The following algorithms
are supported: are supported:
- fove: lifted variable elimination with arbitrary constraints (GC-FOVE) - lve: generalized counting first-order variable elimination (GC-FOVE)
- hve: (ground) variable elimination - hve: (ground) variable elimination
- lbp: lifted first-order belief propagation - lbp: lifted first-order belief propagation
- cbp: counting belief propagation - cbp: counting belief propagation
- bp: (ground) belief propagation - bp: (ground) belief propagation
- lkc: lifted first-order knowledge compilation
For example, if we want to use ground variable elimination to solve some For example, if we want to use ground variable elimination to solve some
query, we need to call first the following goal: query, we need to call first the following goal:

View File

@ -3,7 +3,7 @@
source city.sh source city.sh
source ../benchs.sh source ../benchs.sh
SOLVER="fove" SOLVER="lve"
function run_all_graphs function run_all_graphs
{ {
@ -32,5 +32,5 @@ function run_all_graphs
} }
prepare_new_run prepare_new_run
run_all_graphs "fove " run_all_graphs "lve "

View File

@ -3,7 +3,7 @@
source cw.sh source cw.sh
source ../benchs.sh source ../benchs.sh
SOLVER="fove" SOLVER="lve"
function run_all_graphs function run_all_graphs
{ {
@ -26,6 +26,6 @@ function run_all_graphs
} }
prepare_new_run prepare_new_run
run_all_graphs "fove " run_all_graphs "lve "

View File

@ -3,7 +3,7 @@
cd workshop_attrs cd workshop_attrs
source hve_tests.sh source hve_tests.sh
source bp_tests.sh source bp_tests.sh
source fove_tests.sh source lve_tests.sh
source lbp_tests.sh source lbp_tests.sh
source cbp_tests.sh source cbp_tests.sh
cd .. cd ..
@ -11,7 +11,7 @@ cd ..
cd comp_workshops cd comp_workshops
source hve_tests.sh source hve_tests.sh
source bp_tests.sh source bp_tests.sh
source fove_tests.sh source lve_tests.sh
source lbp_tests.sh source lbp_tests.sh
source cbp_tests.sh source cbp_tests.sh
cd .. cd ..
@ -19,7 +19,7 @@ cd ..
cd city cd city
source hve_tests.sh source hve_tests.sh
source bp_tests.sh source bp_tests.sh
source fove_tests.sh source lve_tests.sh
source lbp_tests.sh source lbp_tests.sh
source cbp_tests.sh source cbp_tests.sh
cd .. cd ..
@ -27,7 +27,7 @@ cd ..
cd smokers cd smokers
source hve_tests.sh source hve_tests.sh
source bp_tests.sh source bp_tests.sh
source fove_tests.sh source lve_tests.sh
source lbp_tests.sh source lbp_tests.sh
source cbp_tests.sh source cbp_tests.sh
cd .. cd ..

View File

@ -3,7 +3,7 @@
source sm.sh source sm.sh
source ../benchs.sh source ../benchs.sh
SOLVER="fove" SOLVER="lve"
function run_all_graphs function run_all_graphs
{ {
@ -26,6 +26,6 @@ function run_all_graphs
} }
prepare_new_run prepare_new_run
run_all_graphs "fove " run_all_graphs "lve "

View File

@ -3,7 +3,7 @@
source sm.sh source sm.sh
source ../benchs.sh source ../benchs.sh
SOLVER="fove" SOLVER="lve"
function run_all_graphs function run_all_graphs
{ {
@ -30,6 +30,6 @@ function run_all_graphs
} }
prepare_new_run prepare_new_run
run_all_graphs "fove " run_all_graphs "lve "

View File

@ -3,7 +3,7 @@
source wa.sh source wa.sh
source ../benchs.sh source ../benchs.sh
SOLVER="fove" SOLVER="lve"
function run_all_graphs function run_all_graphs
{ {
@ -32,6 +32,6 @@ function run_all_graphs
} }
prepare_new_run prepare_new_run
run_all_graphs "fove " run_all_graphs "lve "

View File

@ -39,8 +39,9 @@ set_solver(ve) :- !, set_clpbn_flag(solver,ve).
set_solver(bdd) :- !, set_clpbn_flag(solver,bdd). set_solver(bdd) :- !, set_clpbn_flag(solver,bdd).
set_solver(jt) :- !, set_clpbn_flag(solver,jt). set_solver(jt) :- !, set_clpbn_flag(solver,jt).
set_solver(gibbs) :- !, set_clpbn_flag(solver,gibbs). set_solver(gibbs) :- !, set_clpbn_flag(solver,gibbs).
set_solver(fove) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, fove). set_solver(lve) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lve).
set_solver(lbp) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lbp). set_solver(lbp) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lbp).
set_solver(lkc) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lkc).
set_solver(hve) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, ve). set_solver(hve) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, ve).
set_solver(bp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, bp). set_solver(bp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, bp).
set_solver(cbp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, cbp). set_solver(cbp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, cbp).

View File

@ -18,7 +18,7 @@ total_students(256).
:- ensure_loaded(parschema). :- ensure_loaded(parschema).
:- yap_flag(unknown,error). :- yap_flag(unknown,error).
%:- clpbn_horus:set_solver(fove). %:- clpbn_horus:set_solver(lve).
%:- clpbn_horus:set_solver(hve). %:- clpbn_horus:set_solver(hve).
:- clpbn_horus:set_solver(bp). :- clpbn_horus:set_solver(bp).
%:- clpbn_horus:set_solver(bdd). %:- clpbn_horus:set_solver(bdd).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)). :- use_module(library(pfl)).
%:- set_solver(fove). %:- set_solver(lve).
%:- set_solver(hve). %:- set_solver(hve).
%:- set_solver(bp). %:- set_solver(bp).
%:- set_solver(cbp). %:- set_solver(cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)). :- use_module(library(pfl)).
%:- set_solver(fove). %:- set_solver(lve).
%:- set_solver(hve). %:- set_solver(hve).
%:- set_solver(bp). %:- set_solver(bp).
%:- set_solver(cbp). %:- set_solver(cbp).
@ -29,7 +29,7 @@ bayes car_color(P)::[t,f], hair_color(P) ; car_color_table(P); [people(P,_)].
bayes height(P)::[t,f], gender(P) ; height_table(P) ; [people(P,_)]. bayes height(P)::[t,f], gender(P) ; height_table(P) ; [people(P,_)].
bayes shoe_size(P):[t,f], height(P) ; shoe_size_table(P); [people(P,_)]. bayes shoe_size(P)::[t,f], height(P) ; shoe_size_table(P); [people(P,_)].
bayes guilty(P)::[y,n] ; guilty_table(P) ; [people(P,_)]. bayes guilty(P)::[y,n] ; guilty_table(P) ; [people(P,_)].

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)). :- use_module(library(pfl)).
%:- set_solver(fove). %:- set_solver(lve).
%:- set_solver(hve). %:- set_solver(hve).
%:- set_solver(bp). %:- set_solver(bp).
%:- set_solver(cbp). %:- set_solver(cbp).

View File

@ -1,7 +1,7 @@
:- use_module(library(pfl)). :- use_module(library(pfl)).
:- set_pfl_flag(solver,fove). :- set_pfl_flag(solver,lve).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve). %:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp). %:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp). %:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)). :- use_module(library(pfl)).
:- set_solver(fove). :- set_solver(lve).
%:- set_solver(hve). %:- set_solver(hve).
%:- set_solver(bp). %:- set_solver(bp).
%:- set_solver(cbp). %:- set_solver(cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)). :- use_module(library(pfl)).
%:- set_solver(fove). %:- set_solver(lve).
%:- set_solver(hve). %:- set_solver(hve).
%:- set_solver(bp). %:- set_solver(bp).
%:- set_solver(cbp). %:- set_solver(cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)). :- use_module(library(pfl)).
%:- set_solver(fove). %:- set_solver(lve).
%:- set_solver(hve). %:- set_solver(hve).
%:- set_solver(bp). %:- set_solver(bp).
%:- set_solver(cbp). %:- set_solver(cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)). :- use_module(library(pfl)).
%:- set_solver(fove). %:- set_solver(lve).
%:- set_solver(hve). %:- set_solver(hve).
%:- set_solver(bp). %:- set_solver(bp).
%:- set_solver(cbp). %:- set_solver(cbp).

View File

@ -12,7 +12,7 @@
#include "Horus.h" #include "Horus.h"
BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg) BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
{ {
runned_ = false; runned_ = false;
} }
@ -377,7 +377,8 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
Params Params
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{ {
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds); return GroundSolver::getJointByConditioning (
GroundSolverType::BP, fg, jointVarIds);
} }

View File

@ -5,7 +5,7 @@
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include "Solver.h" #include "GroundSolver.h"
#include "Factor.h" #include "Factor.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Util.h" #include "Util.h"
@ -83,7 +83,7 @@ class SPNodeInfo
}; };
class BeliefProp : public Solver class BeliefProp : public GroundSolver
{ {
public: public:
BeliefProp (const FactorGraph&); BeliefProp (const FactorGraph&);

View File

@ -112,6 +112,8 @@ CTNode::copySubtree (const CTNode* root1)
const CTNode* n1 = stack.back().first; const CTNode* n1 = stack.back().first;
CTNode* n2 = stack.back().second; CTNode* n2 = stack.back().second;
stack.pop_back(); stack.pop_back();
// cout << "n2 childs: " << n2->childs();
// cout << "n1 childs: " << n1->childs();
n2->childs().reserve (n1->nrChilds()); n2->childs().reserve (n1->nrChilds());
stack.reserve (n1->nrChilds()); stack.reserve (n1->nrChilds());
for (CTChilds::const_iterator chIt = n1->childs().begin(); for (CTChilds::const_iterator chIt = n1->childs().begin();
@ -185,11 +187,31 @@ ConstraintTree::ConstraintTree (
ConstraintTree::ConstraintTree (vector<vector<string>> names)
{
assert (names.empty() == false);
assert (names.front().empty() == false);
unsigned nrLvs = names[0].size();
for (size_t i = 0; i < nrLvs; i++) {
logVars_.push_back (LogVar (i));
}
root_ = new CTNode (0, 0);
logVarSet_ = LogVarSet (logVars_);
for (size_t i = 0; i < names.size(); i++) {
Tuple t;
for (size_t j = 0; j < names[i].size(); j++) {
assert (names[i].size() == nrLvs);
t.push_back (LiftedUtils::getSymbol (names[i][j]));
}
addTuple (t);
}
}
ConstraintTree::ConstraintTree (const ConstraintTree& ct) ConstraintTree::ConstraintTree (const ConstraintTree& ct)
{ {
root_ = CTNode::copySubtree (ct.root_); *this = ct;
logVars_ = ct.logVars_;
logVarSet_ = ct.logVarSet_;
} }
@ -367,6 +389,16 @@ ConstraintTree::project (const LogVarSet& X)
ConstraintTree
ConstraintTree::projectedCopy (const LogVarSet& X)
{
ConstraintTree copy = *this;
copy.project (X);
return copy;
}
void void
ConstraintTree::remove (const LogVarSet& X) ConstraintTree::remove (const LogVarSet& X)
{ {
@ -865,6 +897,19 @@ ConstraintTree::copyLogVar (LogVar X_1, LogVar X_2)
ConstraintTree&
ConstraintTree::operator= (const ConstraintTree& ct)
{
if (this != &ct) {
root_ = CTNode::copySubtree (ct.root_);
logVars_ = ct.logVars_;
logVarSet_ = ct.logVarSet_;
}
return *this;
}
unsigned unsigned
ConstraintTree::countTuples (const CTNode* n) const ConstraintTree::countTuples (const CTNode* n) const
{ {

View File

@ -108,6 +108,8 @@ class ConstraintTree
ConstraintTree (const LogVars&); ConstraintTree (const LogVars&);
ConstraintTree (const LogVars&, const Tuples&); ConstraintTree (const LogVars&, const Tuples&);
ConstraintTree (vector<vector<string>> names);
ConstraintTree (const ConstraintTree&); ConstraintTree (const ConstraintTree&);
@ -157,6 +159,8 @@ class ConstraintTree
void applySubstitution (const Substitution&); void applySubstitution (const Substitution&);
void project (const LogVarSet&); void project (const LogVarSet&);
ConstraintTree projectedCopy (const LogVarSet&);
void remove (const LogVarSet&); void remove (const LogVarSet&);
@ -197,6 +201,8 @@ class ConstraintTree
ConstraintTrees ground (LogVar); ConstraintTrees ground (LogVar);
void copyLogVar (LogVar,LogVar); void copyLogVar (LogVar,LogVar);
ConstraintTree& operator= (const ConstraintTree& ct);
private: private:
unsigned countTuples (const CTNode*) const; unsigned countTuples (const CTNode*) const;

View File

@ -6,7 +6,7 @@ bool CountingBp::checkForIdenticalFactors = true;
CountingBp::CountingBp (const FactorGraph& fg) CountingBp::CountingBp (const FactorGraph& fg)
: Solver (fg), freeColor_(0) : GroundSolver (fg), freeColor_(0)
{ {
findIdenticalFactors(); findIdenticalFactors();
setInitialColors(); setInitialColors();
@ -74,8 +74,8 @@ CountingBp::solveQuery (VarIds queryVids)
cout << endl; cout << endl;
} }
if (idx == facNodes.size()) { if (idx == facNodes.size()) {
res = Solver::getJointByConditioning ( res = GroundSolver::getJointByConditioning (
GroundSolver::CBP, fg, queryVids); GroundSolverType::CBP, fg, queryVids);
} else { } else {
VarIds reprArgs; VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) { for (size_t i = 0; i < queryVids.size(); i++) {

View File

@ -3,7 +3,7 @@
#include <unordered_map> #include <unordered_map>
#include "Solver.h" #include "GroundSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Util.h" #include "Util.h"
#include "Horus.h" #include "Horus.h"
@ -102,7 +102,7 @@ class FacCluster
}; };
class CountingBp : public Solver class CountingBp : public GroundSolver
{ {
public: public:
CountingBp (const FactorGraph& fg); CountingBp (const FactorGraph& fg);

View File

@ -1,4 +1,4 @@
#include "Solver.h" #include "GroundSolver.h"
#include "Util.h" #include "Util.h"
#include "BeliefProp.h" #include "BeliefProp.h"
#include "CountingBp.h" #include "CountingBp.h"
@ -6,7 +6,7 @@
void void
Solver::printAnswer (const VarIds& vids) GroundSolver::printAnswer (const VarIds& vids)
{ {
Vars unobservedVars; Vars unobservedVars;
VarIds unobservedVids; VarIds unobservedVids;
@ -32,7 +32,7 @@ Solver::printAnswer (const VarIds& vids)
void void
Solver::printAllPosterioris (void) GroundSolver::printAllPosterioris (void)
{ {
VarNodes vars = fg.varNodes(); VarNodes vars = fg.varNodes();
std::sort (vars.begin(), vars.end(), sortByVarId()); std::sort (vars.begin(), vars.end(), sortByVarId());
@ -44,8 +44,8 @@ Solver::printAllPosterioris (void)
Params Params
Solver::getJointByConditioning ( GroundSolver::getJointByConditioning (
GroundSolver solverType, GroundSolverType solverType,
FactorGraph fg, FactorGraph fg,
const VarIds& jointVarIds) const const VarIds& jointVarIds) const
{ {
@ -55,11 +55,11 @@ Solver::getJointByConditioning (
jointVars.push_back (fg.getVarNode (jointVarIds[i])); jointVars.push_back (fg.getVarNode (jointVarIds[i]));
} }
Solver* solver = 0; GroundSolver* solver = 0;
switch (solverType) { switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break; case GroundSolverType::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break; case GroundSolverType::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break; case GroundSolverType::VE: solver = new VarElim (fg); break;
} }
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]}); Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()}; VarIds observedVids = {jointVars[0]->varId()};
@ -80,9 +80,9 @@ Solver::getJointByConditioning (
} }
delete solver; delete solver;
switch (solverType) { switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break; case GroundSolverType::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break; case GroundSolverType::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break; case GroundSolverType::VE: solver = new VarElim (fg); break;
} }
Params beliefs = solver->solveQuery ({jointVarIds[i]}); Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) { for (size_t k = 0; k < beliefs.size(); k++) {

View File

@ -1,5 +1,5 @@
#ifndef HORUS_SOLVER_H #ifndef HORUS_GROUNDSOLVER_H
#define HORUS_SOLVER_H #define HORUS_GROUNDSOLVER_H
#include <iomanip> #include <iomanip>
@ -10,12 +10,12 @@
using namespace std; using namespace std;
class Solver class GroundSolver
{ {
public: public:
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { } GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { }
virtual ~Solver() { } // ensure that subclass destructor is called virtual ~GroundSolver() { } // ensure that subclass destructor is called
virtual Params solveQuery (VarIds queryVids) = 0; virtual Params solveQuery (VarIds queryVids) = 0;
@ -25,12 +25,12 @@ class Solver
void printAllPosterioris (void); void printAllPosterioris (void);
Params getJointByConditioning (GroundSolver, Params getJointByConditioning (GroundSolverType,
FactorGraph, const VarIds& jointVarIds) const; FactorGraph, const VarIds& jointVarIds) const;
protected: protected:
const FactorGraph& fg; const FactorGraph& fg;
}; };
#endif // HORUS_SOLVER_H #endif // HORUS_GROUNDSOLVER_H

View File

@ -28,14 +28,15 @@ typedef vector<unsigned> Ranges;
typedef unsigned long long ullong; typedef unsigned long long ullong;
enum LiftedSolver enum LiftedSolverType
{ {
FOVE, // first order variable elimination LVE, // generalized counting first-order variable elimination (GC-FOVE)
LBP, // lifted belief propagation LBP, // lifted first-order belief propagation
LKC // lifted first-order knowledge compilation
}; };
enum GroundSolver enum GroundSolverType
{ {
VE, // variable elimination VE, // variable elimination
BP, // belief propagation BP, // belief propagation
@ -50,8 +51,8 @@ extern bool logDomain;
// level of debug information // level of debug information
extern unsigned verbosity; extern unsigned verbosity;
extern LiftedSolver liftedSolver; extern LiftedSolverType liftedSolver;
extern GroundSolver groundSolver; extern GroundSolverType groundSolver;
}; };

View File

@ -160,15 +160,15 @@ readQueryAndEvidence (
void void
runSolver (const FactorGraph& fg, const VarIds& queryIds) runSolver (const FactorGraph& fg, const VarIds& queryIds)
{ {
Solver* solver = 0; GroundSolver* solver = 0;
switch (Globals::groundSolver) { switch (Globals::groundSolver) {
case GroundSolver::VE: case GroundSolverType::VE:
solver = new VarElim (fg); solver = new VarElim (fg);
break; break;
case GroundSolver::BP: case GroundSolverType::BP:
solver = new BeliefProp (fg); solver = new BeliefProp (fg);
break; break;
case GroundSolver::CBP: case GroundSolverType::CBP:
solver = new CountingBp (fg); solver = new CountingBp (fg);
break; break;
default: default:

View File

@ -9,11 +9,13 @@
#include "ParfactorList.h" #include "ParfactorList.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "LiftedOperations.h"
#include "LiftedVe.h" #include "LiftedVe.h"
#include "VarElim.h" #include "VarElim.h"
#include "LiftedBp.h" #include "LiftedBp.h"
#include "CountingBp.h" #include "CountingBp.h"
#include "BeliefProp.h" #include "BeliefProp.h"
#include "LiftedKc.h"
#include "ElimGraph.h" #include "ElimGraph.h"
#include "BayesBall.h" #include "BayesBall.h"
@ -22,25 +24,15 @@ using namespace std;
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork; typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
Params readParameters (YAP_Term); Parfactor* readParfactor (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term);
void readLiftedEvidence (YAP_Term, ObservedFormulas&); void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term); vector<unsigned> readUnsignedList (YAP_Term list);
Params readParameters (YAP_Term);
vector<unsigned> YAP_Term fillAnswersPrologList (vector<Params>& results);
readUnsignedList (YAP_Term list)
{
vector<unsigned> vec;
while (list != YAP_TermNil()) {
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
list = YAP_TailOfTerm (list);
}
return vec;
}
@ -76,138 +68,13 @@ createLiftedNetwork (void)
readLiftedEvidence (YAP_ARG2, *(obsFormulas)); readLiftedEvidence (YAP_ARG2, *(obsFormulas));
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas); LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
YAP_Int p = (YAP_Int) (net); YAP_Int p = (YAP_Int) (net);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3); return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
} }
Parfactor*
readParfactor (YAP_Term pfTerm)
{
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: constraint has free variables" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
return new Parfactor (formulas, params, tuples, distId);
}
void
readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{
while (observedList != YAP_TermNil()) {
YAP_Term pair = YAP_HeadOfTerm (observedList);
YAP_Term ground = YAP_ArgOfTerm (1, pair);
Symbol functor;
Symbols args;
if (YAP_IsAtomTerm (ground)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
functor = LiftedUtils::getSymbol (name);
} else {
assert (YAP_IsApplTerm (ground));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, ground);
assert (YAP_IsAtomTerm (ti));
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
args.push_back (LiftedUtils::getSymbol (arg));
}
}
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
bool found = false;
for (size_t i = 0; i < obsFormulas.size(); i++) {
if (obsFormulas[i].functor() == functor &&
obsFormulas[i].arity() == args.size() &&
obsFormulas[i].evidence() == evidence) {
obsFormulas[i].addTuple (args);
found = true;
}
}
if (found == false) {
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
}
observedList = YAP_TailOfTerm (observedList);
}
}
int int
createGroundNetwork (void) createGroundNetwork (void)
{ {
@ -253,31 +120,27 @@ createGroundNetwork (void)
Params
readParameters (YAP_Term paramL)
{
Params params;
assert (YAP_IsPairTerm (paramL));
while (paramL != YAP_TermNil()) {
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
paramL = YAP_TailOfTerm (paramL);
}
if (Globals::logDomain) {
Util::log (params);
}
return params;
}
int int
runLiftedSolver (void) runLiftedSolver (void)
{ {
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1); LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
ParfactorList pfListCopy (*network->first);
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
LiftedSolver* solver = 0;
switch (Globals::liftedSolver) {
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
vector<Params> results; vector<Params> results;
ParfactorList pfListCopy (*network->first);
LiftedVe::absorveEvidence (pfListCopy, *network->second);
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
Grounds queryVars; Grounds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList); YAP_Term jointList = YAP_HeadOfTerm (taskList);
@ -303,41 +166,13 @@ runLiftedSolver (void)
} }
jointList = YAP_TailOfTerm (jointList); jointList = YAP_TailOfTerm (jointList);
} }
if (Globals::liftedSolver == LiftedSolver::FOVE) { results.push_back (solver->solveQuery (queryVars));
LiftedVe solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolver::LBP) {
LiftedBp solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else {
assert (false);
}
taskList = YAP_TailOfTerm (taskList); taskList = YAP_TailOfTerm (taskList);
} }
YAP_Term list = YAP_TermNil(); delete solver;
for (size_t i = results.size(); i-- > 0; ) {
const Params& beliefs = results[i]; return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
YAP_Term queryBeliefsL = YAP_TermNil();
for (size_t j = beliefs.size(); j-- > 0; ) {
YAP_Int sl1 = YAP_InitSlot (list);
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
list = YAP_GetFromSlot (sl1);
YAP_RecoverSlots (1);
}
list = YAP_MkPairTerm (queryBeliefsL, list);
}
return YAP_Unify (list, YAP_ARG3);
} }
@ -346,6 +181,7 @@ int
runGroundSolver (void) runGroundSolver (void)
{ {
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1); FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
vector<VarIds> tasks; vector<VarIds> tasks;
YAP_Term taskList = YAP_ARG2; YAP_Term taskList = YAP_ARG2;
while (taskList != YAP_TermNil()) { while (taskList != YAP_TermNil()) {
@ -357,22 +193,19 @@ runGroundSolver (void)
for (size_t i = 0; i < tasks.size(); i++) { for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]); Util::addToSet (vids, tasks[i]);
} }
Solver* solver = 0;
FactorGraph* mfg = fg; FactorGraph* mfg = fg;
if (fg->bayesianFactors()) { if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph ( mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(), vids.end())); *fg, VarIds (vids.begin(), vids.end()));
} }
if (Globals::groundSolver == GroundSolver::VE) { GroundSolver* solver = 0;
solver = new VarElim (*mfg); CountingBp::checkForIdenticalFactors = false;
} else if (Globals::groundSolver == GroundSolver::BP) { switch (Globals::groundSolver) {
solver = new BeliefProp (*mfg); case GroundSolverType::VE: solver = new VarElim (*mfg); break;
} else if (Globals::groundSolver == GroundSolver::CBP) { case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
CountingBp::checkForIdenticalFactors = false; case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
solver = new CountingBp (*mfg);
} else {
assert (false);
} }
if (Globals::verbosity > 0) { if (Globals::verbosity > 0) {
@ -391,20 +224,7 @@ runGroundSolver (void)
delete mfg; delete mfg;
} }
YAP_Term list = YAP_TermNil(); return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
for (size_t i = results.size(); i-- > 0; ) {
const Params& beliefs = results[i];
YAP_Term queryBeliefsL = YAP_TermNil();
for (size_t j = beliefs.size(); j-- > 0; ) {
YAP_Int sl1 = YAP_InitSlot (list);
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
list = YAP_GetFromSlot (sl1);
YAP_RecoverSlots (1);
}
list = YAP_MkPairTerm (queryBeliefsL, list);
}
return YAP_Unify (list, YAP_ARG3);
} }
@ -535,6 +355,183 @@ freeLiftedNetwork (void)
Parfactor*
readParfactor (YAP_Term pfTerm)
{
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: constraint has free variables" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
return new Parfactor (formulas, params, tuples, distId);
}
void
readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{
while (observedList != YAP_TermNil()) {
YAP_Term pair = YAP_HeadOfTerm (observedList);
YAP_Term ground = YAP_ArgOfTerm (1, pair);
Symbol functor;
Symbols args;
if (YAP_IsAtomTerm (ground)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
functor = LiftedUtils::getSymbol (name);
} else {
assert (YAP_IsApplTerm (ground));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, ground);
assert (YAP_IsAtomTerm (ti));
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
args.push_back (LiftedUtils::getSymbol (arg));
}
}
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
bool found = false;
for (size_t i = 0; i < obsFormulas.size(); i++) {
if (obsFormulas[i].functor() == functor &&
obsFormulas[i].arity() == args.size() &&
obsFormulas[i].evidence() == evidence) {
obsFormulas[i].addTuple (args);
found = true;
}
}
if (found == false) {
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
}
observedList = YAP_TailOfTerm (observedList);
}
}
vector<unsigned>
readUnsignedList (YAP_Term list)
{
vector<unsigned> vec;
while (list != YAP_TermNil()) {
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
list = YAP_TailOfTerm (list);
}
return vec;
}
Params
readParameters (YAP_Term paramL)
{
Params params;
assert (YAP_IsPairTerm (paramL));
while (paramL != YAP_TermNil()) {
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
paramL = YAP_TailOfTerm (paramL);
}
if (Globals::logDomain) {
Util::log (params);
}
return params;
}
YAP_Term
fillAnswersPrologList (vector<Params>& results)
{
YAP_Term list = YAP_TermNil();
for (size_t i = results.size(); i-- > 0; ) {
const Params& beliefs = results[i];
YAP_Term queryBeliefsL = YAP_TermNil();
for (size_t j = beliefs.size(); j-- > 0; ) {
YAP_Int sl1 = YAP_InitSlot (list);
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
list = YAP_GetFromSlot (sl1);
YAP_RecoverSlots (1);
}
list = YAP_MkPairTerm (queryBeliefsL, list);
}
return list;
}
extern "C" void extern "C" void
init_predicates (void) init_predicates (void)
{ {

View File

@ -1,11 +1,11 @@
#include "LiftedBp.h" #include "LiftedBp.h"
#include "WeightedBp.h" #include "WeightedBp.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "LiftedVe.h" #include "LiftedOperations.h"
LiftedBp::LiftedBp (const ParfactorList& pfList) LiftedBp::LiftedBp (const ParfactorList& parfactorList)
: pfList_(pfList) : LiftedSolver (parfactorList)
{ {
refineParfactors(); refineParfactors();
createFactorGraph(); createFactorGraph();
@ -82,6 +82,7 @@ LiftedBp::printSolverFlags (void) const
void void
LiftedBp::refineParfactors (void) LiftedBp::refineParfactors (void)
{ {
pfList_ = parfactorList;
while (iterate() == false); while (iterate() == false);
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
@ -101,7 +102,7 @@ LiftedBp::iterate (void)
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars(); LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) { if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = LiftedVe::countNormalize (*it, lvs); Parfactors pfs = LiftedOperations::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it); it = pfList_.removeAndDelete (it);
pfList_.add (pfs); pfList_.add (pfs);
return false; return false;
@ -189,12 +190,12 @@ LiftedBp::rangeOfGround (const Ground& gr)
Params Params
LiftedBp::getJointByConditioning ( LiftedBp::getJointByConditioning (
const ParfactorList& pfList, const ParfactorList& pfList,
const Grounds& grounds) const Grounds& query)
{ {
LiftedBp solver (pfList); LiftedBp solver (pfList);
Params prevBeliefs = solver.solveQuery ({grounds[0]}); Params prevBeliefs = solver.solveQuery ({query[0]});
Grounds obsGrounds = {grounds[0]}; Grounds obsGrounds = {query[0]};
for (size_t i = 1; i < grounds.size(); i++) { for (size_t i = 1; i < query.size(); i++) {
Params newBeliefs; Params newBeliefs;
vector<ObservedFormula> obsFs; vector<ObservedFormula> obsFs;
Ranges obsRanges; Ranges obsRanges;
@ -209,16 +210,16 @@ LiftedBp::getJointByConditioning (
obsFs[j].setEvidence (indexer[j]); obsFs[j].setEvidence (indexer[j]);
} }
ParfactorList tempPfList (pfList); ParfactorList tempPfList (pfList);
LiftedVe::absorveEvidence (tempPfList, obsFs); LiftedOperations::absorveEvidence (tempPfList, obsFs);
LiftedBp solver (tempPfList); LiftedBp solver (tempPfList);
Params beliefs = solver.solveQuery ({grounds[i]}); Params beliefs = solver.solveQuery ({query[i]});
for (size_t k = 0; k < beliefs.size(); k++) { for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]); newBeliefs.push_back (beliefs[k]);
} }
++ indexer; ++ indexer;
} }
int count = -1; int count = -1;
unsigned range = rangeOfGround (grounds[i]); unsigned range = rangeOfGround (query[i]);
for (size_t j = 0; j < newBeliefs.size(); j++) { for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % range == 0) { if (j % range == 0) {
count ++; count ++;
@ -226,7 +227,7 @@ LiftedBp::getJointByConditioning (
newBeliefs[j] *= prevBeliefs[count]; newBeliefs[j] *= prevBeliefs[count];
} }
prevBeliefs = newBeliefs; prevBeliefs = newBeliefs;
obsGrounds.push_back (grounds[i]); obsGrounds.push_back (query[i]);
} }
return prevBeliefs; return prevBeliefs;
} }

View File

@ -1,12 +1,13 @@
#ifndef HORUS_LIFTEDBP_H #ifndef HORUS_LIFTEDBP_H
#define HORUS_LIFTEDBP_H #define HORUS_LIFTEDBP_H
#include "LiftedSolver.h"
#include "ParfactorList.h" #include "ParfactorList.h"
class FactorGraph; class FactorGraph;
class WeightedBp; class WeightedBp;
class LiftedBp class LiftedBp : public LiftedSolver
{ {
public: public:
LiftedBp (const ParfactorList& pfList); LiftedBp (const ParfactorList& pfList);
@ -39,3 +40,4 @@ class LiftedBp
}; };
#endif // HORUS_LIFTEDBP_H #endif // HORUS_LIFTEDBP_H

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,274 @@
#ifndef HORUS_LIFTEDCIRCUIT_H
#define HORUS_LIFTEDCIRCUIT_H
#include <stack>
#include "LiftedWCNF.h"
enum CircuitNodeType {
OR_NODE,
AND_NODE,
SET_OR_NODE,
SET_AND_NODE,
INC_EXC_NODE,
LEAF_NODE,
SMOOTH_NODE,
TRUE_NODE,
COMPILATION_FAILED_NODE
};
class CircuitNode
{
public:
CircuitNode (const Clauses& clauses, string explanation = "")
: clauses_(clauses), explanation_(explanation) { }
const Clauses& clauses (void) const { return clauses_; }
Clauses clauses (void) { return clauses_; }
virtual double weight (void) const = 0;
string explanation (void) const { return explanation_; }
private:
Clauses clauses_;
string explanation_;
};
class OrNode : public CircuitNode
{
public:
OrNode (const Clauses& clauses, string explanation = "")
: CircuitNode (clauses, explanation),
leftBranch_(0), rightBranch_(0) { }
CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class AndNode : public CircuitNode
{
public:
AndNode (const Clauses& clauses, string explanation = "")
: CircuitNode (clauses, explanation),
leftBranch_(0), rightBranch_(0) { }
AndNode (
const Clauses& clauses,
CircuitNode* leftBranch,
CircuitNode* rightBranch,
string explanation = "")
: CircuitNode (clauses, explanation),
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
AndNode (
CircuitNode* leftBranch,
CircuitNode* rightBranch,
string explanation = "")
: CircuitNode ({}, explanation),
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class SetOrNode : public CircuitNode
{
public:
SetOrNode (unsigned nrGroundings, const Clauses& clauses)
: CircuitNode (clauses, " AC"), follow_(0),
nrGroundings_(nrGroundings) { }
CircuitNode** follow (void) { return &follow_; }
static unsigned nrPositives (void) { return nrGrsStack.top().first; }
static unsigned nrNegatives (void) { return nrGrsStack.top().second; }
double weight (void) const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
static stack<pair<unsigned, unsigned>> nrGrsStack;
};
class SetAndNode : public CircuitNode
{
public:
SetAndNode (unsigned nrGroundings, const Clauses& clauses)
: CircuitNode (clauses, " IPG"), follow_(0),
nrGroundings_(nrGroundings) { }
CircuitNode** follow (void) { return &follow_; }
double weight (void) const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
};
class IncExcNode : public CircuitNode
{
public:
IncExcNode (const Clauses& clauses, string explanation)
: CircuitNode (clauses, explanation), plus1Branch_(0),
plus2Branch_(0), minusBranch_(0) { }
CircuitNode** plus1Branch (void) { return &plus1Branch_; }
CircuitNode** plus2Branch (void) { return &plus2Branch_; }
CircuitNode** minusBranch (void) { return &minusBranch_; }
double weight (void) const;
private:
CircuitNode* plus1Branch_;
CircuitNode* plus2Branch_;
CircuitNode* minusBranch_;
};
class LeafNode : public CircuitNode
{
public:
LeafNode (const Clause& clause, const LiftedWCNF& lwcnf)
: CircuitNode (Clauses() = {clause}), lwcnf_(lwcnf) { }
double weight (void) const;
private:
const LiftedWCNF& lwcnf_;
};
class SmoothNode : public CircuitNode
{
public:
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
: CircuitNode (clauses), lwcnf_(lwcnf) { }
double weight (void) const;
private:
const LiftedWCNF& lwcnf_;
};
class TrueNode : public CircuitNode
{
public:
TrueNode (void) : CircuitNode ({}) { }
double weight (void) const;
};
class CompilationFailedNode : public CircuitNode
{
public:
CompilationFailedNode (const Clauses& clauses)
: CircuitNode (clauses) { }
double weight (void) const;
};
class LiftedCircuit
{
public:
LiftedCircuit (const LiftedWCNF* lwcnf);
double getWeightedModelCount (void) const;
void exportToGraphViz (const char*);
private:
void compile (CircuitNode** follow, Clauses& clauses);
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
LogVars& rootLogVars);
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
bool tryGrounding (CircuitNode** follow, Clauses& clauses);
void shatterCountedLogVars (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
bool independentClause (Clause& clause, Clauses& otherClauses) const;
bool independentLiteral (const Literal& lit,
const Literals& otherLits) const;
LitLvTypesSet smoothCircuit (CircuitNode* node);
void createSmoothNode (const LitLvTypesSet& lids,
CircuitNode** prev);
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
bool containsTypes (const LogVarTypes& typesA,
const LogVarTypes& typesB) const;
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
void exportToGraphViz (CircuitNode* node, ofstream&);
void printClauses (const CircuitNode* node, ofstream&,
string extraOptions = "");
string escapeNode (const CircuitNode* node) const;
CircuitNode* root_;
const LiftedWCNF* lwcnf_;
};
#endif // HORUS_LIFTEDCIRCUIT_H

View File

@ -0,0 +1,75 @@
#include "LiftedKc.h"
#include "LiftedWCNF.h"
#include "LiftedCircuit.h"
#include "LiftedOperations.h"
#include "Indexer.h"
LiftedKc::~LiftedKc (void)
{
delete lwcnf_;
delete circuit_;
}
Params
LiftedKc::solveQuery (const Grounds& query)
{
pfList_ = parfactorList;
LiftedOperations::shatterAgainstQuery (pfList_, query);
LiftedOperations::runWeakBayesBall (pfList_, query);
lwcnf_ = new LiftedWCNF (pfList_);
circuit_ = new LiftedCircuit (lwcnf_);
vector<PrvGroup> groups;
Ranges ranges;
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
size_t idx = (*it)->indexOfGround (query[i]);
if (idx != (*it)->nrArguments()) {
groups.push_back ((*it)->argument (idx).group());
ranges.push_back ((*it)->range (idx));
break;
}
++ it;
}
}
assert (groups.size() == query.size());
Params params;
Indexer indexer (ranges);
while (indexer.valid()) {
for (size_t i = 0; i < groups.size(); i++) {
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
for (size_t j = 0; j < litIds.size(); j++) {
if (indexer[i] == j) {
lwcnf_->addWeight (litIds[j], LogAware::one(),
LogAware::one());
} else {
lwcnf_->addWeight (litIds[j], LogAware::zero(),
LogAware::one());
}
}
}
params.push_back (circuit_->getWeightedModelCount());
++ indexer;
}
LogAware::normalize (params);
if (Globals::logDomain) {
Util::exp (params);
}
return params;
}
void
LiftedKc::printSolverFlags (void) const
{
stringstream ss;
ss << "lifted kc [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}

View File

@ -0,0 +1,30 @@
#ifndef HORUS_LIFTEDKC_H
#define HORUS_LIFTEDKC_H
#include "LiftedSolver.h"
#include "ParfactorList.h"
class LiftedWCNF;
class LiftedCircuit;
class LiftedKc : public LiftedSolver
{
public:
LiftedKc (const ParfactorList& pfList)
: LiftedSolver(pfList) { }
~LiftedKc (void);
Params solveQuery (const Grounds&);
void printSolverFlags (void) const;
private:
LiftedWCNF* lwcnf_;
LiftedCircuit* circuit_;
ParfactorList pfList_;
};
#endif // HORUS_LIFTEDKC_H

View File

@ -0,0 +1,271 @@
#include "LiftedOperations.h"
void
LiftedOperations::shatterAgainstQuery (
ParfactorList& pfList,
const Grounds& query)
{
for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
bool found = false;
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
if ((*it)->containsGround (query[i])) {
found = true;
std::pair<ConstraintTree*, ConstraintTree*> split;
LogVars queryLvs (
(*it)->constr()->logVars().begin(),
(*it)->constr()->logVars().begin() + query[i].arity());
split = (*it)->constr()->split (query[i].args());
ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second;
newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) {
newPfs.push_back (new Parfactor (*it, exclCt));
} else {
delete exclCt;
}
it = pfList.removeAndDelete (it);
} else {
++ it;
}
}
if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2) {
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (size_t i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
void
LiftedOperations::runWeakBayesBall (
ParfactorList& pfList,
const Grounds& query)
{
queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
PrvGroup group = (*it)->findGroup (query[i]);
if (group != numeric_limits<PrvGroup>::max()) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
PrvGroup group = todo.front();
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList.begin();
bool foundNotRequired = false;
while (it != pfList.end()) {
if (Util::contains (requiredPfs, *it) == false) {
if (Globals::verbosity > 2) {
if (foundNotRequired == false) {
Util::printHeader ("PARFACTORS TO DISCARD");
foundNotRequired = true;
}
(*it)->print();
}
it = pfList.removeAndDelete (it);
} else {
++ it;
}
}
}
void
LiftedOperations::absorveEvidence (
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
{
for (size_t i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
Parfactor* pf = *it;
it = pfList.remove (it);
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
if (absorvedPfs.empty() == false) {
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
// just remove pf;
} else {
Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (size_t i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << obsFormulas[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
Parfactors
LiftedOperations::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
Parfactors normPfs;
if (set.empty()) {
normPfs.push_back (new Parfactor (*g));
} else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (size_t i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
}
return normPfs;
}
Parfactor
LiftedOperations::calcGroundMultiplication (Parfactor pf)
{
LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons();
Parfactors newPfs = {new Parfactor (pf)};
for (size_t i = 0; i < lvs.size(); i++) {
Parfactors pfs = newPfs;
newPfs.clear();
for (size_t j = 0; j < pfs.size(); j++) {
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
if (countedLv) {
pfs[j]->fullExpand (lvs[i]);
newPfs.push_back (pfs[j]);
} else {
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
for (size_t k = 0; k < cts.size(); k++) {
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
}
delete pfs[j];
}
}
}
ParfactorList pfList (newPfs);
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
}
return Parfactor (*groundShatteredPfs[0]);
}
Parfactors
LiftedOperations::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{
Parfactors absorvedPfs;
const ProbFormulas& formulas = g->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (obsFormula.functor() == formulas[i].functor() &&
obsFormula.arity() == formulas[i].arity()) {
if (obsFormula.isAtom()) {
if (formulas.size() > 1) {
g->absorveEvidence (formulas[i], obsFormula.evidence());
} else {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
}
g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res;
res = g->constr()->split (
formulas[i].logVars(),
&(obsFormula.constr()),
obsFormula.constr().logVars());
ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second;
if (commCt->empty() == false) {
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactor tempPf (g, commCt);
Parfactors countNormPfs = LiftedOperations::countNormalize (
&tempPf, excl);
for (size_t j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence());
absorvedPfs.push_back (countNormPfs[j]);
}
} else {
delete commCt;
}
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else {
delete commCt;
delete exclCt;
}
}
}
return absorvedPfs;
}

View File

@ -0,0 +1,26 @@
#ifndef HORUS_LIFTEDOPERATIONS_H
#define HORUS_LIFTEDOPERATIONS_H
#include "ParfactorList.h"
class LiftedOperations
{
public:
static void shatterAgainstQuery (
ParfactorList& pfList, const Grounds& query);
static void runWeakBayesBall (
ParfactorList& pfList, const Grounds&);
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
static Parfactor calcGroundMultiplication (Parfactor pf);
private:
static Parfactors absorve (ObservedFormula&, Parfactor*);
};
#endif // HORUS_LIFTEDOPERATIONS_H

View File

@ -0,0 +1,27 @@
#ifndef HORUS_LIFTEDSOLVER_H
#define HORUS_LIFTEDSOLVER_H
#include "ParfactorList.h"
#include "Horus.h"
using namespace std;
class LiftedSolver
{
public:
LiftedSolver (const ParfactorList& pfList)
: parfactorList(pfList) { }
virtual ~LiftedSolver() { } // ensure that subclass destructor is called
virtual Params solveQuery (const Grounds& query) = 0;
virtual void printSolverFlags (void) const = 0;
protected:
const ParfactorList& parfactorList;
};
#endif // HORUS_LIFTEDSOLVER_H

View File

@ -2,6 +2,7 @@
#include <set> #include <set>
#include "LiftedVe.h" #include "LiftedVe.h"
#include "LiftedOperations.h"
#include "Histogram.h" #include "Histogram.h"
#include "Util.h" #include "Util.h"
@ -221,7 +222,7 @@ SumOutOperator::apply (void)
product->sumOutIndex (fIdx); product->sumOutIndex (fIdx);
pfList_.addShattered (product); pfList_.addShattered (product);
} else { } else {
Parfactors pfs = LiftedVe::countNormalize (product, excl); Parfactors pfs = LiftedOperations::countNormalize (product, excl);
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
pfs[i]->sumOutIndex (fIdx); pfs[i]->sumOutIndex (fIdx);
pfList_.add (pfs[i]); pfList_.add (pfs[i]);
@ -375,7 +376,7 @@ CountingOperator::apply (void)
} else { } else {
Parfactor* pf = *pfIter_; Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_); pfList_.remove (pfIter_);
Parfactors pfs = LiftedVe::countNormalize (pf, X_); Parfactors pfs = LiftedOperations::countNormalize (pf, X_);
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_); unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCartesianProduct ( bool cartProduct = pfs[i]->constr()->isCartesianProduct (
@ -419,7 +420,7 @@ CountingOperator::toString (void)
ss << "count convert " << X_ << " in " ; ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel(); ss << (*pfIter_)->getLabel();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl; ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_); Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) { if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (size_t i = 0; i < pfs.size(); i++) { for (size_t i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getLabel() << endl; ss << " º " << pfs[i]->getLabel() << endl;
@ -508,8 +509,6 @@ GroundOperator::getLogCost (void)
void void
GroundOperator::apply (void) GroundOperator::apply (void)
{ {
// TODO if we update the correct groups
// we can skip shattering
ParfactorList::iterator pfIter; ParfactorList::iterator pfIter;
pfIter = getParfactorsWithGroup (pfList_, group_).front(); pfIter = getParfactorsWithGroup (pfList_, group_).front();
Parfactor* pf = *pfIter; Parfactor* pf = *pfIter;
@ -632,6 +631,7 @@ Params
LiftedVe::solveQuery (const Grounds& query) LiftedVe::solveQuery (const Grounds& query)
{ {
assert (query.empty() == false); assert (query.empty() == false);
pfList_ = parfactorList;
runSolver (query); runSolver (query);
(*pfList_.begin())->normalize(); (*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params(); Params params = (*pfList_.begin())->params();
@ -647,7 +647,7 @@ void
LiftedVe::printSolverFlags (void) const LiftedVe::printSolverFlags (void) const
{ {
stringstream ss; stringstream ss;
ss << "fove [" ; ss << "lve [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain); ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ; ss << "]" ;
cout << ss.str() << endl; cout << ss.str() << endl;
@ -655,103 +655,12 @@ LiftedVe::printSolverFlags (void) const
void
LiftedVe::absorveEvidence (
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
{
for (size_t i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
Parfactor* pf = *it;
it = pfList.remove (it);
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
if (absorvedPfs.empty() == false) {
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
// just remove pf;
} else {
Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (size_t i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << obsFormulas[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
Parfactors
LiftedVe::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
Parfactors normPfs;
if (set.empty()) {
normPfs.push_back (new Parfactor (*g));
} else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (size_t i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
}
return normPfs;
}
Parfactor
LiftedVe::calcGroundMultiplication (Parfactor pf)
{
LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons();
Parfactors newPfs = {new Parfactor (pf)};
for (size_t i = 0; i < lvs.size(); i++) {
Parfactors pfs = newPfs;
newPfs.clear();
for (size_t j = 0; j < pfs.size(); j++) {
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
if (countedLv) {
pfs[j]->fullExpand (lvs[i]);
newPfs.push_back (pfs[j]);
} else {
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
for (size_t k = 0; k < cts.size(); k++) {
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
}
delete pfs[j];
}
}
}
ParfactorList pfList (newPfs);
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
}
return Parfactor (*groundShatteredPfs[0]);
}
void void
LiftedVe::runSolver (const Grounds& query) LiftedVe::runSolver (const Grounds& query)
{ {
largestCost_ = std::log (0); largestCost_ = std::log (0);
shatterAgainstQuery (query); LiftedOperations::shatterAgainstQuery (pfList_, query);
runWeakBayesBall (query); LiftedOperations::runWeakBayesBall (pfList_, query);
while (true) { while (true) {
if (Globals::verbosity > 2) { if (Globals::verbosity > 2) {
Util::printDashedLine(); Util::printDashedLine();
@ -817,177 +726,3 @@ LiftedVe::getBestOperation (const Grounds& query)
return bestOp; return bestOp;
} }
void
LiftedVe::runWeakBayesBall (const Grounds& query)
{
queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
PrvGroup group = (*it)->findGroup (query[i]);
if (group != numeric_limits<PrvGroup>::max()) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
PrvGroup group = todo.front();
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList_.begin();
bool foundNotRequired = false;
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false) {
if (Globals::verbosity > 2) {
if (foundNotRequired == false) {
Util::printHeader ("PARFACTORS TO DISCARD");
foundNotRequired = true;
}
(*it)->print();
}
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
}
void
LiftedVe::shatterAgainstQuery (const Grounds& query)
{
for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
bool found = false;
Parfactors newPfs;
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (query[i])) {
found = true;
std::pair<ConstraintTree*, ConstraintTree*> split;
LogVars queryLvs (
(*it)->constr()->logVars().begin(),
(*it)->constr()->logVars().begin() + query[i].arity());
split = (*it)->constr()->split (query[i].args());
ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second;
newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) {
newPfs.push_back (new Parfactor (*it, exclCt));
} else {
delete exclCt;
}
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList_.add (newPfs);
}
if (Globals::verbosity > 2) {
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (size_t i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList_.print();
}
}
Parfactors
LiftedVe::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{
Parfactors absorvedPfs;
const ProbFormulas& formulas = g->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (obsFormula.functor() == formulas[i].functor() &&
obsFormula.arity() == formulas[i].arity()) {
if (obsFormula.isAtom()) {
if (formulas.size() > 1) {
g->absorveEvidence (formulas[i], obsFormula.evidence());
} else {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
}
g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res;
res = g->constr()->split (
formulas[i].logVars(),
&(obsFormula.constr()),
obsFormula.constr().logVars());
ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second;
if (commCt->empty() == false) {
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactor tempPf (g, commCt);
Parfactors countNormPfs = countNormalize (&tempPf, excl);
for (size_t j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence());
absorvedPfs.push_back (countNormPfs[j]);
}
} else {
delete commCt;
}
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else {
delete commCt;
delete exclCt;
}
}
}
return absorvedPfs;
}

View File

@ -1,13 +1,15 @@
#ifndef HORUS_LIFTEDVE_H #ifndef HORUS_LIFTEDVE_H
#define HORUS_LIFTEDVE_H #define HORUS_LIFTEDVE_H
#include "LiftedSolver.h"
#include "ParfactorList.h" #include "ParfactorList.h"
class LiftedOperator class LiftedOperator
{ {
public: public:
virtual ~LiftedOperator (void) { }
virtual double getLogCost (void) = 0; virtual double getLogCost (void) = 0;
virtual void apply (void) = 0; virtual void apply (void) = 0;
@ -43,9 +45,9 @@ class ProductOperator : public LiftedOperator
private: private:
static bool validOp (Parfactor*, Parfactor*); static bool validOp (Parfactor*, Parfactor*);
ParfactorList::iterator g1_; ParfactorList::iterator g1_;
ParfactorList::iterator g2_; ParfactorList::iterator g2_;
ParfactorList& pfList_; ParfactorList& pfList_;
}; };
@ -123,43 +125,30 @@ class GroundOperator : public LiftedOperator
private: private:
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void); vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
PrvGroup group_; PrvGroup group_;
unsigned lvIndex_; unsigned lvIndex_;
ParfactorList& pfList_; ParfactorList& pfList_;
}; };
class LiftedVe class LiftedVe : public LiftedSolver
{ {
public: public:
LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { } LiftedVe (const ParfactorList& pfList)
: LiftedSolver(pfList) { }
Params solveQuery (const Grounds&); Params solveQuery (const Grounds&);
void printSolverFlags (void) const; void printSolverFlags (void) const;
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
static Parfactor calcGroundMultiplication (Parfactor pf);
private: private:
void runSolver (const Grounds&); void runSolver (const Grounds&);
LiftedOperator* getBestOperation (const Grounds&); LiftedOperator* getBestOperation (const Grounds&);
void runWeakBayesBall (const Grounds&); ParfactorList pfList_;
double largestCost_;
void shatterAgainstQuery (const Grounds&);
static Parfactors absorve (ObservedFormula&, Parfactor*);
ParfactorList pfList_;
double largestCost_;
}; };
#endif // HORUS_LIFTEDVE_H #endif // HORUS_LIFTEDVE_H

View File

@ -0,0 +1,610 @@
#include "LiftedWCNF.h"
#include "ConstraintTree.h"
#include "Indexer.h"
bool
Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const
{
if (logVars_.size() == 0) {
return true;
}
LogVarSet lvs (logVars_);
lvs -= ipgLogVars;
return constr.singletons().contains (lvs);
}
size_t
Literal::indexOfLogVar (LogVar X) const
{
return Util::indexOf (logVars_, X);
}
string
Literal::toString (
LogVarSet ipgLogVars,
LogVarSet posCountedLvs,
LogVarSet negCountedLvs) const
{
stringstream ss;
negated_ ? ss << "¬" : ss << "" ;
// if (negated_ == false) {
// posWeight_ < 0.0 ? ss << "λ" : ss << "Θ" ;
// } else {
// negWeight_ < 0.0 ? ss << "λ" : ss << "Θ" ;
// }
ss << "λ" ;
ss << lid_ ;
if (logVars_.empty() == false) {
ss << "(" ;
for (size_t i = 0; i < logVars_.size(); i++) {
if (i != 0) ss << ",";
if (posCountedLvs.contains (logVars_[i])) {
ss << "+" << logVars_[i];
} else if (negCountedLvs.contains (logVars_[i])) {
ss << "-" << logVars_[i];
} else if (ipgLogVars.contains (logVars_[i])) {
LogVar X = logVars_[i];
const string labels[] = {
"a", "b", "c", "d", "e", "f",
"g", "h", "i", "j", "k", "m" };
(X >= 12) ? ss << "x_" << X : ss << labels[X];
} else {
ss << logVars_[i];
}
}
ss << ")" ;
}
return ss.str();
}
std::ostream& operator<< (ostream &os, const Literal& lit)
{
os << lit.toString();
return os;
}
void
Clause::addLiteralComplemented (const Literal& lit)
{
assert (constr_.logVarSet().contains (lit.logVars()));
literals_.push_back (lit);
literals_.back().complement();
}
bool
Clause::containsLiteral (LiteralId lid) const
{
for (size_t i = 0; i < literals_.size(); i++) {
if (literals_[i].lid() == lid) {
return true;
}
}
return false;
}
bool
Clause::containsPositiveLiteral (
LiteralId lid,
const LogVarTypes& types) const
{
for (size_t i = 0; i < literals_.size(); i++) {
if (literals_[i].lid() == lid
&& literals_[i].isPositive()
&& logVarTypes (i) == types) {
return true;
}
}
return false;
}
bool
Clause::containsNegativeLiteral (
LiteralId lid,
const LogVarTypes& types) const
{
for (size_t i = 0; i < literals_.size(); i++) {
if (literals_[i].lid() == lid
&& literals_[i].isNegative()
&& logVarTypes (i) == types) {
return true;
}
}
return false;
}
void
Clause::removeLiterals (LiteralId lid)
{
size_t i = 0;
while (i != literals_.size()) {
if (literals_[i].lid() == lid) {
removeLiteral (i);
} else {
i ++;
}
}
}
void
Clause::removePositiveLiterals (
LiteralId lid,
const LogVarTypes& types)
{
size_t i = 0;
while (i != literals_.size()) {
if (literals_[i].lid() == lid
&& literals_[i].isPositive()
&& logVarTypes (i) == types) {
removeLiteral (i);
} else {
i ++;
}
}
}
void
Clause::removeNegativeLiterals (
LiteralId lid,
const LogVarTypes& types)
{
size_t i = 0;
while (i != literals_.size()) {
if (literals_[i].lid() == lid
&& literals_[i].isNegative()
&& logVarTypes (i) == types) {
removeLiteral (i);
} else {
i ++;
}
}
}
bool
Clause::isCountedLogVar (LogVar X) const
{
assert (constr_.logVarSet().contains (X));
return posCountedLvs_.contains (X)
|| negCountedLvs_.contains (X);
}
bool
Clause::isPositiveCountedLogVar (LogVar X) const
{
assert (constr_.logVarSet().contains (X));
return posCountedLvs_.contains (X);
}
bool
Clause::isNegativeCountedLogVar (LogVar X) const
{
assert (constr_.logVarSet().contains (X));
return negCountedLvs_.contains (X);
}
bool
Clause::isIpgLogVar (LogVar X) const
{
assert (constr_.logVarSet().contains (X));
return ipgLvs_.contains (X);
}
TinySet<LiteralId>
Clause::lidSet (void) const
{
TinySet<LiteralId> lidSet;
for (size_t i = 0; i < literals_.size(); i++) {
lidSet.insert (literals_[i].lid());
}
return lidSet;
}
LogVarSet
Clause::ipgCandidates (void) const
{
LogVarSet candidates;
LogVarSet allLvs = constr_.logVarSet();
allLvs -= ipgLvs_;
allLvs -= posCountedLvs_;
allLvs -= negCountedLvs_;
for (size_t i = 0; i < allLvs.size(); i++) {
bool valid = true;
for (size_t j = 0; j < literals_.size(); j++) {
if (Util::contains (literals_[j].logVars(), allLvs[i]) == false) {
valid = false;
break;
}
}
if (valid) {
candidates.insert (allLvs[i]);
}
}
return candidates;
}
LogVarTypes
Clause::logVarTypes (size_t litIdx) const
{
LogVarTypes types;
const LogVars lvs = literals_[litIdx].logVars();
for (size_t i = 0; i < lvs.size(); i++) {
if (posCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::POS_LV);
} else if (negCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::NEG_LV);
} else {
types.push_back (LogVarType::FULL_LV);
}
}
return types;
}
void
Clause::removeLiteral (size_t litIdx)
{
LogVarSet lvsToRemove = literals_[litIdx].logVarSet()
- getLogVarSetExcluding (litIdx);
ipgLvs_ -= lvsToRemove;
posCountedLvs_ -= lvsToRemove;
negCountedLvs_ -= lvsToRemove;
constr_.remove (lvsToRemove);
literals_.erase (literals_.begin() + litIdx);
}
bool
Clause::independentClauses (Clause& c1, Clause& c2)
{
const Literals& lits1 = c1.literals();
const Literals& lits2 = c2.literals();
for (size_t i = 0; i < lits1.size(); i++) {
for (size_t j = 0; j < lits2.size(); j++) {
if (lits1[i].lid() == lits2[j].lid()
&& c1.logVarTypes (i) == c2.logVarTypes (j)) {
return false;
}
}
}
return true;
}
void
Clause::printClauses (const Clauses& clauses)
{
for (size_t i = 0; i < clauses.size(); i++) {
cout << clauses[i] << endl;
}
}
std::ostream& operator<< (ostream &os, const Clause& clause)
{
for (unsigned i = 0; i < clause.literals_.size(); i++) {
if (i != 0) os << " v " ;
os << clause.literals_[i].toString (clause.ipgLvs_,
clause.posCountedLvs_, clause.negCountedLvs_);
}
if (clause.constr_.empty() == false) {
ConstraintTree copy (clause.constr_);
copy.moveToTop (copy.logVarSet().elements());
os << " | " << copy.tupleSet();
}
return os;
}
LogVarSet
Clause::getLogVarSetExcluding (size_t idx) const
{
LogVarSet lvs;
for (size_t i = 0; i < literals_.size(); i++) {
if (i != idx) {
lvs |= literals_[i].logVars();
}
}
return lvs;
}
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
: freeLiteralId_(0), pfList_(pfList)
{
addIndicatorClauses (pfList);
addParameterClauses (pfList);
/*
// INCLUSION-EXCLUSION TEST
vector<vector<string>> names = {
// {"a1","b1"},{"a2","b2"},{"a1","b3"}
{"b1","a1"},{"b2","a2"},{"b3","a1"}
};
Clause c1 (names);
c1.addLiteral (Literal (0, LogVars() = {0}));
c1.addLiteral (Literal (1, LogVars() = {1}));
clauses_.push_back(c1);
freeLiteralId_ ++ ;
freeLiteralId_ ++ ;
*/
/*
// ATOM-COUNTING TEST
vector<vector<string>> names = {
{"p1","p1"},{"p1","p2"},{"p1","p3"},
{"p2","p1"},{"p2","p2"},{"p2","p3"},
{"p3","p1"},{"p3","p2"},{"p3","p3"}
};
Clause c1 (names);
c1.addLiteral (Literal (0, LogVars() = {0}));
c1.addLiteralComplemented (Literal (1, {0,1}));
clauses_.push_back(c1);
Clause c2 (names);
c2.addLiteral (Literal (0, LogVars()={0}));
c2.addLiteralComplemented (Literal (1, {1,0}));
clauses_.push_back(c2);
addWeight (0, LogAware::log(3.0), LogAware::log(4.0));
addWeight (1, LogAware::log(2.0), LogAware::log(5.0));
freeLiteralId_ = 2;
*/
cout << "FORMULA INDICATORS:" << endl;
printFormulaIndicators();
cout << endl;
cout << "WEIGHTS:" << endl;
printWeights();
cout << endl;
cout << "CLAUSES:" << endl;
printClauses();
cout << endl;
}
LiftedWCNF::~LiftedWCNF (void)
{
}
void
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
{
weights_[lid] = make_pair (posW, negW);
}
double
LiftedWCNF::posWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
return it != weights_.end() ? it->second.first : LogAware::one();
}
double
LiftedWCNF::negWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
return it != weights_.end() ? it->second.second : LogAware::one();
}
vector<LiteralId>
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
{
assert (Util::contains (map_, prvGroup));
return map_[prvGroup];
}
Clause
LiftedWCNF::createClause (LiteralId lid) const
{
for (size_t i = 0; i < clauses_.size(); i++) {
const Literals& literals = clauses_[i].literals();
for (size_t j = 0; j < literals.size(); j++) {
if (literals[j].lid() == lid) {
ConstraintTree ct = clauses_[i].constr();
ct.project (literals[j].logVars());
Clause clause (ct);
clause.addLiteral (literals[j]);
return clause;
}
}
}
abort(); // we should not reach this point
return Clause (ConstraintTree({}));
}
LiteralId
LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
{
assert (Util::contains (map_, prvGroup));
return map_[prvGroup][range];
}
void
LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
{
ParfactorList::const_iterator it = pfList.begin();
set<PrvGroup> allGroups;
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (Util::contains (allGroups, formulas[i].group()) == false) {
allGroups.insert (formulas[i].group());
ConstraintTree tempConstr = *(*it)->constr();
tempConstr.project (formulas[i].logVars());
Clause clause (tempConstr);
vector<LiteralId> lids;
for (size_t j = 0; j < formulas[i].range(); j++) {
clause.addLiteral (Literal (freeLiteralId_, formulas[i].logVars()));
lids.push_back (freeLiteralId_);
freeLiteralId_ ++;
}
clauses_.push_back (clause);
for (size_t j = 0; j < formulas[i].range() - 1; j++) {
for (size_t k = j + 1; k < formulas[i].range(); k++) {
ConstraintTree tempConstr2 = *(*it)->constr();
tempConstr2.project (formulas[i].logVars());
Clause clause2 (tempConstr2);
clause2.addLiteralComplemented (Literal (clause.literals()[j]));
clause2.addLiteralComplemented (Literal (clause.literals()[k]));
clauses_.push_back (clause2);
}
}
map_[formulas[i].group()] = lids;
}
}
++ it;
}
}
void
LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
{
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
Indexer indexer ((*it)->ranges());
vector<PrvGroup> groups = (*it)->getAllGroups();
while (indexer.valid()) {
LiteralId paramVarLid = freeLiteralId_;
// λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un
//
// ¬λu1 ... ¬λun v θxi|u1,...,un -> clause1
// ¬θxi|u1,...,un v λu1 -> tempClause
// ¬θxi|u1,...,un v λu2 -> tempClause
double posWeight = (**it)[indexer];
addWeight (paramVarLid, posWeight, 1.0);
Clause clause1 (*(*it)->constr());
for (unsigned i = 0; i < groups.size(); i++) {
LiteralId lid = getLiteralId (groups[i], indexer[i]);
clause1.addLiteralComplemented (
Literal (lid, (*it)->argument(i).logVars()));
ConstraintTree ct = *(*it)->constr();
Clause tempClause (ct);
tempClause.addLiteralComplemented (Literal (
paramVarLid, (*it)->constr()->logVars()));
tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars()));
clauses_.push_back (tempClause);
}
clause1.addLiteral (Literal (paramVarLid, (*it)->constr()->logVars()));
clauses_.push_back (clause1);
freeLiteralId_ ++;
++ indexer;
}
++ it;
}
}
void
LiftedWCNF::printFormulaIndicators (void) const
{
set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& formulas = (*it)->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (Util::contains (allGroups, formulas[i].group()) == false) {
allGroups.insert (formulas[i].group());
cout << formulas[i] << " | " ;
ConstraintTree tempCt = *(*it)->constr();
tempCt.project (formulas[i].logVars());
cout << tempCt.tupleSet();
cout << " indicators => " ;
vector<LiteralId> indicators =
(map_.find (formulas[i].group()))->second;
cout << indicators << endl;
}
}
++ it;
}
}
void
LiftedWCNF::printWeights (void) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.begin();
while (it != weights_.end()) {
cout << "λ" << it->first << " weights: " ;
cout << it->second.first << " " << it->second.second;
cout << endl;
++ it;
}
}
void
LiftedWCNF::printClauses (void) const
{
for (unsigned i = 0; i < clauses_.size(); i++) {
cout << clauses_[i] << endl;
}
}

View File

@ -0,0 +1,226 @@
#ifndef HORUS_LIFTEDWCNF_H
#define HORUS_LIFTEDWCNF_H
#include "ParfactorList.h"
using namespace std;
typedef long LiteralId;
class ConstraintTree;
enum LogVarType
{
FULL_LV,
POS_LV,
NEG_LV
};
typedef vector<LogVarType> LogVarTypes;
class Literal
{
public:
Literal (LiteralId lid, const LogVars& lvs) :
lid_(lid), logVars_(lvs), negated_(false) { }
Literal (const Literal& lit, bool negated) :
lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
LiteralId lid (void) const { return lid_; }
LogVars logVars (void) const { return logVars_; }
size_t nrLogVars (void) const { return logVars_.size(); }
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
void complement (void) { negated_ = !negated_; }
bool isPositive (void) const { return negated_ == false; }
bool isNegative (void) const { return negated_; }
bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const;
size_t indexOfLogVar (LogVar X) const;
string toString (LogVarSet ipgLogVars = LogVarSet(),
LogVarSet posCountedLvs = LogVarSet(),
LogVarSet negCountedLvs = LogVarSet()) const;
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
private:
LiteralId lid_;
LogVars logVars_;
bool negated_;
};
typedef vector<Literal> Literals;
class Clause
{
public:
Clause (const ConstraintTree& ct = ConstraintTree({})) : constr_(ct) { }
Clause (vector<vector<string>> names) : constr_(ConstraintTree (names)) { }
void addLiteral (const Literal& l) { literals_.push_back (l); }
const Literals& literals (void) const { return literals_; }
const ConstraintTree& constr (void) const { return constr_; }
ConstraintTree constr (void) { return constr_; }
bool isUnit (void) const { return literals_.size() == 1; }
LogVarSet ipgLogVars (void) const { return ipgLvs_; }
void addIpgLogVar (LogVar X) { ipgLvs_.insert (X); }
void addPosCountedLogVar (LogVar X) { posCountedLvs_.insert (X); }
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; }
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; }
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); }
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); }
void addLiteralComplemented (const Literal& lit);
bool containsLiteral (LiteralId lid) const;
bool containsPositiveLiteral (LiteralId lid, const LogVarTypes&) const;
bool containsNegativeLiteral (LiteralId lid, const LogVarTypes&) const;
void removeLiterals (LiteralId lid);
void removePositiveLiterals (LiteralId lid, const LogVarTypes&);
void removeNegativeLiterals (LiteralId lid, const LogVarTypes&);
bool isCountedLogVar (LogVar X) const;
bool isPositiveCountedLogVar (LogVar X) const;
bool isNegativeCountedLogVar (LogVar X) const;
bool isIpgLogVar (LogVar X) const;
TinySet<LiteralId> lidSet (void) const;
LogVarSet ipgCandidates (void) const;
LogVarTypes logVarTypes (size_t litIdx) const;
void removeLiteral (size_t litIdx);
static bool independentClauses (Clause& c1, Clause& c2);
static void printClauses (const vector<Clause>& clauses);
friend std::ostream& operator<< (ostream &os, const Clause& clause);
private:
LogVarSet getLogVarSetExcluding (size_t idx) const;
Literals literals_;
LogVarSet ipgLvs_;
LogVarSet posCountedLvs_;
LogVarSet negCountedLvs_;
ConstraintTree constr_;
};
typedef vector<Clause> Clauses;
class LitLvTypes
{
public:
struct CompareLitLvTypes
{
bool operator() (
const LitLvTypes& types1,
const LitLvTypes& types2) const
{
if (types1.lid_ < types2.lid_) {
return true;
}
return types1.lvTypes_ < types2.lvTypes_;
}
};
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
lid_(lid), lvTypes_(lvTypes) { }
LiteralId lid (void) const { return lid_; }
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
void setAllFullLogVars (void) {
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
private:
LiteralId lid_;
LogVarTypes lvTypes_;
};
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
class LiftedWCNF
{
public:
LiftedWCNF (const ParfactorList& pfList);
~LiftedWCNF (void);
const Clauses& clauses (void) const { return clauses_; }
void addWeight (LiteralId lid, double posW, double negW);
double posWeight (LiteralId lid) const;
double negWeight (LiteralId lid) const;
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
Clause createClause (LiteralId lid) const;
void printFormulaIndicators (void) const;
void printWeights (void) const;
void printClauses (void) const;
private:
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
void addIndicatorClauses (const ParfactorList& pfList);
void addParameterClauses (const ParfactorList& pfList);
Clauses clauses_;
LiteralId freeLiteralId_;
const ParfactorList& pfList_;
unordered_map<PrvGroup, vector<LiteralId>> map_;
unordered_map<LiteralId, std::pair<double,double>> weights_;
};
#endif // HORUS_LIFTEDWCNF_H

View File

@ -23,10 +23,10 @@ CC=@CC@
CXX=@CXX@ CXX=@CXX@
# normal # normal
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG #CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
# debug # debug
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
# #
@ -57,12 +57,17 @@ HEADERS = \
$(srcdir)/Horus.h \ $(srcdir)/Horus.h \
$(srcdir)/Indexer.h \ $(srcdir)/Indexer.h \
$(srcdir)/LiftedBp.h \ $(srcdir)/LiftedBp.h \
$(srcdir)/LiftedCircuit.h \
$(srcdir)/LiftedKc.h \
$(srcdir)/LiftedOperations.h \
$(srcdir)/LiftedSolver.h \
$(srcdir)/LiftedUtils.h \ $(srcdir)/LiftedUtils.h \
$(srcdir)/LiftedVe.h \ $(srcdir)/LiftedVe.h \
$(srcdir)/LiftedWCNF.h \
$(srcdir)/Parfactor.h \ $(srcdir)/Parfactor.h \
$(srcdir)/ParfactorList.h \ $(srcdir)/ParfactorList.h \
$(srcdir)/ProbFormula.h \ $(srcdir)/ProbFormula.h \
$(srcdir)/Solver.h \ $(srcdir)/GroundSolver.h \
$(srcdir)/TinySet.h \ $(srcdir)/TinySet.h \
$(srcdir)/Util.h \ $(srcdir)/Util.h \
$(srcdir)/Var.h \ $(srcdir)/Var.h \
@ -82,16 +87,20 @@ CPP_SOURCES = \
$(srcdir)/HorusCli.cpp \ $(srcdir)/HorusCli.cpp \
$(srcdir)/HorusYap.cpp \ $(srcdir)/HorusYap.cpp \
$(srcdir)/LiftedBp.cpp \ $(srcdir)/LiftedBp.cpp \
$(srcdir)/LiftedCircuit.cpp \
$(srcdir)/LiftedKc.cpp \
$(srcdir)/LiftedOperations.cpp \
$(srcdir)/LiftedUtils.cpp \ $(srcdir)/LiftedUtils.cpp \
$(srcdir)/LiftedVe.cpp \ $(srcdir)/LiftedVe.cpp \
$(srcdir)/LiftedWCNF.cpp \
$(srcdir)/Parfactor.cpp \ $(srcdir)/Parfactor.cpp \
$(srcdir)/ParfactorList.cpp \ $(srcdir)/ParfactorList.cpp \
$(srcdir)/ProbFormula.cpp \ $(srcdir)/ProbFormula.cpp \
$(srcdir)/Solver.cpp \ $(srcdir)/GroundSolver.cpp \
$(srcdir)/Util.cpp \ $(srcdir)/Util.cpp \
$(srcdir)/Var.cpp \ $(srcdir)/Var.cpp \
$(srcdir)/VarElim.cpp \ $(srcdir)/VarElim.cpp \
$(srcdir)/WeightedBp.cpp \ $(srcdir)/WeightedBp.cpp
OBJS = \ OBJS = \
BayesBall.o \ BayesBall.o \
@ -105,12 +114,16 @@ OBJS = \
Histogram.o \ Histogram.o \
HorusYap.o \ HorusYap.o \
LiftedBp.o \ LiftedBp.o \
LiftedCircuit.o \
LiftedKc.o \
LiftedOperations.o \
LiftedUtils.o \ LiftedUtils.o \
LiftedVe.o \ LiftedVe.o \
LiftedWCNF.o \
ProbFormula.o \ ProbFormula.o \
Parfactor.o \ Parfactor.o \
ParfactorList.o \ ParfactorList.o \
Solver.o \ GroundSolver.o \
Util.o \ Util.o \
Var.o \ Var.o \
VarElim.o \ VarElim.o \
@ -125,7 +138,7 @@ HCLI_OBJS = \
Factor.o \ Factor.o \
FactorGraph.o \ FactorGraph.o \
HorusCli.o \ HorusCli.o \
Solver.o \ GroundSolver.o \
Util.o \ Util.o \
Var.o \ Var.o \
VarElim.o \ VarElim.o \

View File

@ -9,8 +9,7 @@ ParfactorList::ParfactorList (const ParfactorList& pfList)
while (it != pfList.end()) { while (it != pfList.end()) {
addShattered (new Parfactor (**it)); addShattered (new Parfactor (**it));
++ it; ++ it;
} }
} }
@ -126,6 +125,27 @@ ParfactorList::print (void) const
ParfactorList&
ParfactorList::operator= (const ParfactorList& pfList)
{
if (this != &pfList) {
ParfactorList::const_iterator it0 = pfList_.begin();
while (it0 != pfList_.end()) {
delete *it0;
++ it0;
}
pfList_.clear();
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
addShattered (new Parfactor (**it));
++ it;
}
}
return *this;
}
bool bool
ParfactorList::isShattered (const Parfactor* g) const ParfactorList::isShattered (const Parfactor* g) const
{ {

View File

@ -56,6 +56,8 @@ class ParfactorList
bool isAllShattered (void) const; bool isAllShattered (void) const;
void print (void) const; void print (void) const;
ParfactorList& operator= (const ParfactorList& pfList);
private: private:

View File

@ -1,3 +1,2 @@
- Find a way to decrease the time required to find an - Handle formulas like f(X,X)
elimination order for variable elimination

View File

@ -153,6 +153,11 @@ class TinySet
{ {
return vec_[i]; return vec_[i];
} }
T& operator[] (typename vector<T>::size_type i)
{
return vec_[i];
}
T front (void) const T front (void) const
{ {

View File

@ -13,9 +13,9 @@ bool logDomain = false;
unsigned verbosity = 0; unsigned verbosity = 0;
LiftedSolver liftedSolver = LiftedSolver::FOVE; LiftedSolverType liftedSolver = LiftedSolverType::LVE;
GroundSolver groundSolver = GroundSolver::VE; GroundSolverType groundSolver = GroundSolverType::VE;
}; };
@ -210,10 +210,12 @@ setHorusFlag (string key, string value)
ss << value; ss << value;
ss >> Globals::verbosity; ss >> Globals::verbosity;
} else if (key == "lifted_solver") { } else if (key == "lifted_solver") {
if ( value == "fove") { if ( value == "lve") {
Globals::liftedSolver = LiftedSolver::FOVE; Globals::liftedSolver = LiftedSolverType::LVE;
} else if (value == "lbp") { } else if (value == "lbp") {
Globals::liftedSolver = LiftedSolver::LBP; Globals::liftedSolver = LiftedSolverType::LBP;
} else if (value == "lkc") {
Globals::liftedSolver = LiftedSolverType::LKC;
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;
@ -221,11 +223,11 @@ setHorusFlag (string key, string value)
} }
} else if (key == "ground_solver") { } else if (key == "ground_solver") {
if ( value == "ve") { if ( value == "ve") {
Globals::groundSolver = GroundSolver::VE; Globals::groundSolver = GroundSolverType::VE;
} else if (value == "bp") { } else if (value == "bp") {
Globals::groundSolver = GroundSolver::BP; Globals::groundSolver = GroundSolverType::BP;
} else if (value == "cbp") { } else if (value == "cbp") {
Globals::groundSolver = GroundSolver::CBP; Globals::groundSolver = GroundSolverType::CBP;
} else { } else {
cerr << "warning: invalid value `" << value << "' " ; cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl; cerr << "for `" << key << "'" << endl;

View File

@ -3,7 +3,7 @@
#include "unordered_map" #include "unordered_map"
#include "Solver.h" #include "GroundSolver.h"
#include "FactorGraph.h" #include "FactorGraph.h"
#include "Horus.h" #include "Horus.h"
@ -11,10 +11,10 @@
using namespace std; using namespace std;
class VarElim : public Solver class VarElim : public GroundSolver
{ {
public: public:
VarElim (const FactorGraph& fg) : Solver (fg) { } VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
~VarElim (void); ~VarElim (void);