This commit is contained in:
Vitor Santos Costa 2012-11-22 12:29:37 +00:00
commit 027632456a
48 changed files with 3049 additions and 615 deletions

View File

@ -49,11 +49,12 @@ Inference Options
PFL supports both ground and lifted inference. The inference algorithm
can be chosen using the set_solver/1 predicate. The following algorithms
are supported:
- fove: lifted variable elimination with arbitrary constraints (GC-FOVE)
- lve: generalized counting first-order variable elimination (GC-FOVE)
- hve: (ground) variable elimination
- lbp: lifted first-order belief propagation
- cbp: counting belief propagation
- bp: (ground) belief propagation
- lkc: lifted first-order knowledge compilation
For example, if we want to use ground variable elimination to solve some
query, we need to call first the following goal:

View File

@ -3,7 +3,7 @@
source city.sh
source ../benchs.sh
SOLVER="fove"
SOLVER="lve"
function run_all_graphs
{
@ -32,5 +32,5 @@ function run_all_graphs
}
prepare_new_run
run_all_graphs "fove "
run_all_graphs "lve "

View File

@ -3,7 +3,7 @@
source cw.sh
source ../benchs.sh
SOLVER="fove"
SOLVER="lve"
function run_all_graphs
{
@ -26,6 +26,6 @@ function run_all_graphs
}
prepare_new_run
run_all_graphs "fove "
run_all_graphs "lve "

View File

@ -3,7 +3,7 @@
cd workshop_attrs
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lve_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..
@ -11,7 +11,7 @@ cd ..
cd comp_workshops
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lve_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..
@ -19,7 +19,7 @@ cd ..
cd city
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lve_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..
@ -27,7 +27,7 @@ cd ..
cd smokers
source hve_tests.sh
source bp_tests.sh
source fove_tests.sh
source lve_tests.sh
source lbp_tests.sh
source cbp_tests.sh
cd ..

View File

@ -3,7 +3,7 @@
source sm.sh
source ../benchs.sh
SOLVER="fove"
SOLVER="lve"
function run_all_graphs
{
@ -26,6 +26,6 @@ function run_all_graphs
}
prepare_new_run
run_all_graphs "fove "
run_all_graphs "lve "

View File

@ -3,7 +3,7 @@
source sm.sh
source ../benchs.sh
SOLVER="fove"
SOLVER="lve"
function run_all_graphs
{
@ -30,6 +30,6 @@ function run_all_graphs
}
prepare_new_run
run_all_graphs "fove "
run_all_graphs "lve "

View File

@ -3,7 +3,7 @@
source wa.sh
source ../benchs.sh
SOLVER="fove"
SOLVER="lve"
function run_all_graphs
{
@ -32,6 +32,6 @@ function run_all_graphs
}
prepare_new_run
run_all_graphs "fove "
run_all_graphs "lve "

View File

@ -39,8 +39,9 @@ set_solver(ve) :- !, set_clpbn_flag(solver,ve).
set_solver(bdd) :- !, set_clpbn_flag(solver,bdd).
set_solver(jt) :- !, set_clpbn_flag(solver,jt).
set_solver(gibbs) :- !, set_clpbn_flag(solver,gibbs).
set_solver(fove) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, fove).
set_solver(lve) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lve).
set_solver(lbp) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lbp).
set_solver(lkc) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lkc).
set_solver(hve) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, ve).
set_solver(bp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, bp).
set_solver(cbp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, cbp).

View File

@ -18,7 +18,7 @@ total_students(256).
:- ensure_loaded(parschema).
:- yap_flag(unknown,error).
%:- clpbn_horus:set_solver(fove).
%:- clpbn_horus:set_solver(lve).
%:- clpbn_horus:set_solver(hve).
:- clpbn_horus:set_solver(bp).
%:- clpbn_horus:set_solver(bdd).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(lve).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(lve).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).
@ -29,7 +29,7 @@ bayes car_color(P)::[t,f], hair_color(P) ; car_color_table(P); [people(P,_)].
bayes height(P)::[t,f], gender(P) ; height_table(P) ; [people(P,_)].
bayes shoe_size(P):[t,f], height(P) ; shoe_size_table(P); [people(P,_)].
bayes shoe_size(P)::[t,f], height(P) ; shoe_size_table(P); [people(P,_)].
bayes guilty(P)::[y,n] ; guilty_table(P) ; [people(P,_)].

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(lve).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).

View File

@ -1,7 +1,7 @@
:- use_module(library(pfl)).
:- set_pfl_flag(solver,fove).
:- set_pfl_flag(solver,lve).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)).
:- set_solver(fove).
:- set_solver(lve).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(lve).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(lve).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).

View File

@ -1,6 +1,6 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(lve).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).

View File

@ -12,7 +12,7 @@
#include "Horus.h"
BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg)
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
{
runned_ = false;
}
@ -377,7 +377,8 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
Params
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
return GroundSolver::getJointByConditioning (
GroundSolverType::BP, fg, jointVarIds);
}

View File

@ -5,7 +5,7 @@
#include <vector>
#include <sstream>
#include "Solver.h"
#include "GroundSolver.h"
#include "Factor.h"
#include "FactorGraph.h"
#include "Util.h"
@ -83,7 +83,7 @@ class SPNodeInfo
};
class BeliefProp : public Solver
class BeliefProp : public GroundSolver
{
public:
BeliefProp (const FactorGraph&);

View File

@ -112,6 +112,8 @@ CTNode::copySubtree (const CTNode* root1)
const CTNode* n1 = stack.back().first;
CTNode* n2 = stack.back().second;
stack.pop_back();
// cout << "n2 childs: " << n2->childs();
// cout << "n1 childs: " << n1->childs();
n2->childs().reserve (n1->nrChilds());
stack.reserve (n1->nrChilds());
for (CTChilds::const_iterator chIt = n1->childs().begin();
@ -185,11 +187,31 @@ ConstraintTree::ConstraintTree (
ConstraintTree::ConstraintTree (vector<vector<string>> names)
{
assert (names.empty() == false);
assert (names.front().empty() == false);
unsigned nrLvs = names[0].size();
for (size_t i = 0; i < nrLvs; i++) {
logVars_.push_back (LogVar (i));
}
root_ = new CTNode (0, 0);
logVarSet_ = LogVarSet (logVars_);
for (size_t i = 0; i < names.size(); i++) {
Tuple t;
for (size_t j = 0; j < names[i].size(); j++) {
assert (names[i].size() == nrLvs);
t.push_back (LiftedUtils::getSymbol (names[i][j]));
}
addTuple (t);
}
}
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
{
root_ = CTNode::copySubtree (ct.root_);
logVars_ = ct.logVars_;
logVarSet_ = ct.logVarSet_;
*this = ct;
}
@ -367,6 +389,16 @@ ConstraintTree::project (const LogVarSet& X)
ConstraintTree
ConstraintTree::projectedCopy (const LogVarSet& X)
{
ConstraintTree copy = *this;
copy.project (X);
return copy;
}
void
ConstraintTree::remove (const LogVarSet& X)
{
@ -865,6 +897,19 @@ ConstraintTree::copyLogVar (LogVar X_1, LogVar X_2)
ConstraintTree&
ConstraintTree::operator= (const ConstraintTree& ct)
{
if (this != &ct) {
root_ = CTNode::copySubtree (ct.root_);
logVars_ = ct.logVars_;
logVarSet_ = ct.logVarSet_;
}
return *this;
}
unsigned
ConstraintTree::countTuples (const CTNode* n) const
{

View File

@ -108,6 +108,8 @@ class ConstraintTree
ConstraintTree (const LogVars&);
ConstraintTree (const LogVars&, const Tuples&);
ConstraintTree (vector<vector<string>> names);
ConstraintTree (const ConstraintTree&);
@ -157,6 +159,8 @@ class ConstraintTree
void applySubstitution (const Substitution&);
void project (const LogVarSet&);
ConstraintTree projectedCopy (const LogVarSet&);
void remove (const LogVarSet&);
@ -197,6 +201,8 @@ class ConstraintTree
ConstraintTrees ground (LogVar);
void copyLogVar (LogVar,LogVar);
ConstraintTree& operator= (const ConstraintTree& ct);
private:
unsigned countTuples (const CTNode*) const;

View File

@ -6,7 +6,7 @@ bool CountingBp::checkForIdenticalFactors = true;
CountingBp::CountingBp (const FactorGraph& fg)
: Solver (fg), freeColor_(0)
: GroundSolver (fg), freeColor_(0)
{
findIdenticalFactors();
setInitialColors();
@ -74,8 +74,8 @@ CountingBp::solveQuery (VarIds queryVids)
cout << endl;
}
if (idx == facNodes.size()) {
res = Solver::getJointByConditioning (
GroundSolver::CBP, fg, queryVids);
res = GroundSolver::getJointByConditioning (
GroundSolverType::CBP, fg, queryVids);
} else {
VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) {

View File

@ -3,7 +3,7 @@
#include <unordered_map>
#include "Solver.h"
#include "GroundSolver.h"
#include "FactorGraph.h"
#include "Util.h"
#include "Horus.h"
@ -102,7 +102,7 @@ class FacCluster
};
class CountingBp : public Solver
class CountingBp : public GroundSolver
{
public:
CountingBp (const FactorGraph& fg);

View File

@ -1,4 +1,4 @@
#include "Solver.h"
#include "GroundSolver.h"
#include "Util.h"
#include "BeliefProp.h"
#include "CountingBp.h"
@ -6,7 +6,7 @@
void
Solver::printAnswer (const VarIds& vids)
GroundSolver::printAnswer (const VarIds& vids)
{
Vars unobservedVars;
VarIds unobservedVids;
@ -32,7 +32,7 @@ Solver::printAnswer (const VarIds& vids)
void
Solver::printAllPosterioris (void)
GroundSolver::printAllPosterioris (void)
{
VarNodes vars = fg.varNodes();
std::sort (vars.begin(), vars.end(), sortByVarId());
@ -44,8 +44,8 @@ Solver::printAllPosterioris (void)
Params
Solver::getJointByConditioning (
GroundSolver solverType,
GroundSolver::getJointByConditioning (
GroundSolverType solverType,
FactorGraph fg,
const VarIds& jointVarIds) const
{
@ -55,11 +55,11 @@ Solver::getJointByConditioning (
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
Solver* solver = 0;
GroundSolver* solver = 0;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
case GroundSolverType::VE: solver = new VarElim (fg); break;
}
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()};
@ -80,9 +80,9 @@ Solver::getJointByConditioning (
}
delete solver;
switch (solverType) {
case GroundSolver::BP: solver = new BeliefProp (fg); break;
case GroundSolver::CBP: solver = new CountingBp (fg); break;
case GroundSolver::VE: solver = new VarElim (fg); break;
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
case GroundSolverType::VE: solver = new VarElim (fg); break;
}
Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {

View File

@ -1,5 +1,5 @@
#ifndef HORUS_SOLVER_H
#define HORUS_SOLVER_H
#ifndef HORUS_GROUNDSOLVER_H
#define HORUS_GROUNDSOLVER_H
#include <iomanip>
@ -10,12 +10,12 @@
using namespace std;
class Solver
class GroundSolver
{
public:
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { }
virtual ~Solver() { } // ensure that subclass destructor is called
virtual ~GroundSolver() { } // ensure that subclass destructor is called
virtual Params solveQuery (VarIds queryVids) = 0;
@ -25,12 +25,12 @@ class Solver
void printAllPosterioris (void);
Params getJointByConditioning (GroundSolver,
Params getJointByConditioning (GroundSolverType,
FactorGraph, const VarIds& jointVarIds) const;
protected:
const FactorGraph& fg;
};
#endif // HORUS_SOLVER_H
#endif // HORUS_GROUNDSOLVER_H

View File

@ -28,14 +28,15 @@ typedef vector<unsigned> Ranges;
typedef unsigned long long ullong;
enum LiftedSolver
enum LiftedSolverType
{
FOVE, // first order variable elimination
LBP, // lifted belief propagation
LVE, // generalized counting first-order variable elimination (GC-FOVE)
LBP, // lifted first-order belief propagation
LKC // lifted first-order knowledge compilation
};
enum GroundSolver
enum GroundSolverType
{
VE, // variable elimination
BP, // belief propagation
@ -50,8 +51,8 @@ extern bool logDomain;
// level of debug information
extern unsigned verbosity;
extern LiftedSolver liftedSolver;
extern GroundSolver groundSolver;
extern LiftedSolverType liftedSolver;
extern GroundSolverType groundSolver;
};

View File

@ -160,15 +160,15 @@ readQueryAndEvidence (
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
Solver* solver = 0;
GroundSolver* solver = 0;
switch (Globals::groundSolver) {
case GroundSolver::VE:
case GroundSolverType::VE:
solver = new VarElim (fg);
break;
case GroundSolver::BP:
case GroundSolverType::BP:
solver = new BeliefProp (fg);
break;
case GroundSolver::CBP:
case GroundSolverType::CBP:
solver = new CountingBp (fg);
break;
default:

View File

@ -9,11 +9,13 @@
#include "ParfactorList.h"
#include "FactorGraph.h"
#include "LiftedOperations.h"
#include "LiftedVe.h"
#include "VarElim.h"
#include "LiftedBp.h"
#include "CountingBp.h"
#include "BeliefProp.h"
#include "LiftedKc.h"
#include "ElimGraph.h"
#include "BayesBall.h"
@ -22,25 +24,15 @@ using namespace std;
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
Params readParameters (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term);
Parfactor* readParfactor (YAP_Term);
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term list);
Params readParameters (YAP_Term);
vector<unsigned>
readUnsignedList (YAP_Term list)
{
vector<unsigned> vec;
while (list != YAP_TermNil()) {
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
list = YAP_TailOfTerm (list);
}
return vec;
}
YAP_Term fillAnswersPrologList (vector<Params>& results);
@ -76,138 +68,13 @@ createLiftedNetwork (void)
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
YAP_Int p = (YAP_Int) (net);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
}
Parfactor*
readParfactor (YAP_Term pfTerm)
{
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: constraint has free variables" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
return new Parfactor (formulas, params, tuples, distId);
}
void
readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{
while (observedList != YAP_TermNil()) {
YAP_Term pair = YAP_HeadOfTerm (observedList);
YAP_Term ground = YAP_ArgOfTerm (1, pair);
Symbol functor;
Symbols args;
if (YAP_IsAtomTerm (ground)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
functor = LiftedUtils::getSymbol (name);
} else {
assert (YAP_IsApplTerm (ground));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, ground);
assert (YAP_IsAtomTerm (ti));
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
args.push_back (LiftedUtils::getSymbol (arg));
}
}
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
bool found = false;
for (size_t i = 0; i < obsFormulas.size(); i++) {
if (obsFormulas[i].functor() == functor &&
obsFormulas[i].arity() == args.size() &&
obsFormulas[i].evidence() == evidence) {
obsFormulas[i].addTuple (args);
found = true;
}
}
if (found == false) {
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
}
observedList = YAP_TailOfTerm (observedList);
}
}
int
createGroundNetwork (void)
{
@ -253,31 +120,27 @@ createGroundNetwork (void)
Params
readParameters (YAP_Term paramL)
{
Params params;
assert (YAP_IsPairTerm (paramL));
while (paramL != YAP_TermNil()) {
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
paramL = YAP_TailOfTerm (paramL);
}
if (Globals::logDomain) {
Util::log (params);
}
return params;
}
int
runLiftedSolver (void)
{
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
ParfactorList pfListCopy (*network->first);
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
LiftedSolver* solver = 0;
switch (Globals::liftedSolver) {
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
}
YAP_Term taskList = YAP_ARG2;
vector<Params> results;
ParfactorList pfListCopy (*network->first);
LiftedVe::absorveEvidence (pfListCopy, *network->second);
while (taskList != YAP_TermNil()) {
Grounds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList);
@ -303,41 +166,13 @@ runLiftedSolver (void)
}
jointList = YAP_TailOfTerm (jointList);
}
if (Globals::liftedSolver == LiftedSolver::FOVE) {
LiftedVe solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else if (Globals::liftedSolver == LiftedSolver::LBP) {
LiftedBp solver (pfListCopy);
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
solver.printSolverFlags();
cout << endl;
}
results.push_back (solver.solveQuery (queryVars));
} else {
assert (false);
}
results.push_back (solver->solveQuery (queryVars));
taskList = YAP_TailOfTerm (taskList);
}
YAP_Term list = YAP_TermNil();
for (size_t i = results.size(); i-- > 0; ) {
const Params& beliefs = results[i];
YAP_Term queryBeliefsL = YAP_TermNil();
for (size_t j = beliefs.size(); j-- > 0; ) {
YAP_Int sl1 = YAP_InitSlot (list);
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
list = YAP_GetFromSlot (sl1);
YAP_RecoverSlots (1);
}
list = YAP_MkPairTerm (queryBeliefsL, list);
}
return YAP_Unify (list, YAP_ARG3);
delete solver;
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
}
@ -346,6 +181,7 @@ int
runGroundSolver (void)
{
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
vector<VarIds> tasks;
YAP_Term taskList = YAP_ARG2;
while (taskList != YAP_TermNil()) {
@ -357,22 +193,19 @@ runGroundSolver (void)
for (size_t i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->bayesianFactors()) {
mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(), vids.end()));
}
if (Globals::groundSolver == GroundSolver::VE) {
solver = new VarElim (*mfg);
} else if (Globals::groundSolver == GroundSolver::BP) {
solver = new BeliefProp (*mfg);
} else if (Globals::groundSolver == GroundSolver::CBP) {
CountingBp::checkForIdenticalFactors = false;
solver = new CountingBp (*mfg);
} else {
assert (false);
GroundSolver* solver = 0;
CountingBp::checkForIdenticalFactors = false;
switch (Globals::groundSolver) {
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
}
if (Globals::verbosity > 0) {
@ -391,20 +224,7 @@ runGroundSolver (void)
delete mfg;
}
YAP_Term list = YAP_TermNil();
for (size_t i = results.size(); i-- > 0; ) {
const Params& beliefs = results[i];
YAP_Term queryBeliefsL = YAP_TermNil();
for (size_t j = beliefs.size(); j-- > 0; ) {
YAP_Int sl1 = YAP_InitSlot (list);
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
list = YAP_GetFromSlot (sl1);
YAP_RecoverSlots (1);
}
list = YAP_MkPairTerm (queryBeliefsL, list);
}
return YAP_Unify (list, YAP_ARG3);
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
}
@ -535,6 +355,183 @@ freeLiftedNetwork (void)
Parfactor*
readParfactor (YAP_Term pfTerm)
{
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: constraint has free variables" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
return new Parfactor (formulas, params, tuples, distId);
}
void
readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{
while (observedList != YAP_TermNil()) {
YAP_Term pair = YAP_HeadOfTerm (observedList);
YAP_Term ground = YAP_ArgOfTerm (1, pair);
Symbol functor;
Symbols args;
if (YAP_IsAtomTerm (ground)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
functor = LiftedUtils::getSymbol (name);
} else {
assert (YAP_IsApplTerm (ground));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, ground);
assert (YAP_IsAtomTerm (ti));
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
args.push_back (LiftedUtils::getSymbol (arg));
}
}
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
bool found = false;
for (size_t i = 0; i < obsFormulas.size(); i++) {
if (obsFormulas[i].functor() == functor &&
obsFormulas[i].arity() == args.size() &&
obsFormulas[i].evidence() == evidence) {
obsFormulas[i].addTuple (args);
found = true;
}
}
if (found == false) {
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
}
observedList = YAP_TailOfTerm (observedList);
}
}
vector<unsigned>
readUnsignedList (YAP_Term list)
{
vector<unsigned> vec;
while (list != YAP_TermNil()) {
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
list = YAP_TailOfTerm (list);
}
return vec;
}
Params
readParameters (YAP_Term paramL)
{
Params params;
assert (YAP_IsPairTerm (paramL));
while (paramL != YAP_TermNil()) {
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
paramL = YAP_TailOfTerm (paramL);
}
if (Globals::logDomain) {
Util::log (params);
}
return params;
}
YAP_Term
fillAnswersPrologList (vector<Params>& results)
{
YAP_Term list = YAP_TermNil();
for (size_t i = results.size(); i-- > 0; ) {
const Params& beliefs = results[i];
YAP_Term queryBeliefsL = YAP_TermNil();
for (size_t j = beliefs.size(); j-- > 0; ) {
YAP_Int sl1 = YAP_InitSlot (list);
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
list = YAP_GetFromSlot (sl1);
YAP_RecoverSlots (1);
}
list = YAP_MkPairTerm (queryBeliefsL, list);
}
return list;
}
extern "C" void
init_predicates (void)
{

View File

@ -1,11 +1,11 @@
#include "LiftedBp.h"
#include "WeightedBp.h"
#include "FactorGraph.h"
#include "LiftedVe.h"
#include "LiftedOperations.h"
LiftedBp::LiftedBp (const ParfactorList& pfList)
: pfList_(pfList)
LiftedBp::LiftedBp (const ParfactorList& parfactorList)
: LiftedSolver (parfactorList)
{
refineParfactors();
createFactorGraph();
@ -82,6 +82,7 @@ LiftedBp::printSolverFlags (void) const
void
LiftedBp::refineParfactors (void)
{
pfList_ = parfactorList;
while (iterate() == false);
if (Globals::verbosity > 2) {
@ -101,7 +102,7 @@ LiftedBp::iterate (void)
for (size_t i = 0; i < args.size(); i++) {
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
if ((*it)->constr()->isCountNormalized (lvs) == false) {
Parfactors pfs = LiftedVe::countNormalize (*it, lvs);
Parfactors pfs = LiftedOperations::countNormalize (*it, lvs);
it = pfList_.removeAndDelete (it);
pfList_.add (pfs);
return false;
@ -189,12 +190,12 @@ LiftedBp::rangeOfGround (const Ground& gr)
Params
LiftedBp::getJointByConditioning (
const ParfactorList& pfList,
const Grounds& grounds)
const Grounds& query)
{
LiftedBp solver (pfList);
Params prevBeliefs = solver.solveQuery ({grounds[0]});
Grounds obsGrounds = {grounds[0]};
for (size_t i = 1; i < grounds.size(); i++) {
Params prevBeliefs = solver.solveQuery ({query[0]});
Grounds obsGrounds = {query[0]};
for (size_t i = 1; i < query.size(); i++) {
Params newBeliefs;
vector<ObservedFormula> obsFs;
Ranges obsRanges;
@ -209,16 +210,16 @@ LiftedBp::getJointByConditioning (
obsFs[j].setEvidence (indexer[j]);
}
ParfactorList tempPfList (pfList);
LiftedVe::absorveEvidence (tempPfList, obsFs);
LiftedOperations::absorveEvidence (tempPfList, obsFs);
LiftedBp solver (tempPfList);
Params beliefs = solver.solveQuery ({grounds[i]});
Params beliefs = solver.solveQuery ({query[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
unsigned range = rangeOfGround (grounds[i]);
unsigned range = rangeOfGround (query[i]);
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % range == 0) {
count ++;
@ -226,7 +227,7 @@ LiftedBp::getJointByConditioning (
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
obsGrounds.push_back (grounds[i]);
obsGrounds.push_back (query[i]);
}
return prevBeliefs;
}

View File

@ -1,12 +1,13 @@
#ifndef HORUS_LIFTEDBP_H
#define HORUS_LIFTEDBP_H
#include "LiftedSolver.h"
#include "ParfactorList.h"
class FactorGraph;
class WeightedBp;
class LiftedBp
class LiftedBp : public LiftedSolver
{
public:
LiftedBp (const ParfactorList& pfList);
@ -39,3 +40,4 @@ class LiftedBp
};
#endif // HORUS_LIFTEDBP_H

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,274 @@
#ifndef HORUS_LIFTEDCIRCUIT_H
#define HORUS_LIFTEDCIRCUIT_H
#include <stack>
#include "LiftedWCNF.h"
enum CircuitNodeType {
OR_NODE,
AND_NODE,
SET_OR_NODE,
SET_AND_NODE,
INC_EXC_NODE,
LEAF_NODE,
SMOOTH_NODE,
TRUE_NODE,
COMPILATION_FAILED_NODE
};
class CircuitNode
{
public:
CircuitNode (const Clauses& clauses, string explanation = "")
: clauses_(clauses), explanation_(explanation) { }
const Clauses& clauses (void) const { return clauses_; }
Clauses clauses (void) { return clauses_; }
virtual double weight (void) const = 0;
string explanation (void) const { return explanation_; }
private:
Clauses clauses_;
string explanation_;
};
class OrNode : public CircuitNode
{
public:
OrNode (const Clauses& clauses, string explanation = "")
: CircuitNode (clauses, explanation),
leftBranch_(0), rightBranch_(0) { }
CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class AndNode : public CircuitNode
{
public:
AndNode (const Clauses& clauses, string explanation = "")
: CircuitNode (clauses, explanation),
leftBranch_(0), rightBranch_(0) { }
AndNode (
const Clauses& clauses,
CircuitNode* leftBranch,
CircuitNode* rightBranch,
string explanation = "")
: CircuitNode (clauses, explanation),
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
AndNode (
CircuitNode* leftBranch,
CircuitNode* rightBranch,
string explanation = "")
: CircuitNode ({}, explanation),
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class SetOrNode : public CircuitNode
{
public:
SetOrNode (unsigned nrGroundings, const Clauses& clauses)
: CircuitNode (clauses, " AC"), follow_(0),
nrGroundings_(nrGroundings) { }
CircuitNode** follow (void) { return &follow_; }
static unsigned nrPositives (void) { return nrGrsStack.top().first; }
static unsigned nrNegatives (void) { return nrGrsStack.top().second; }
double weight (void) const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
static stack<pair<unsigned, unsigned>> nrGrsStack;
};
class SetAndNode : public CircuitNode
{
public:
SetAndNode (unsigned nrGroundings, const Clauses& clauses)
: CircuitNode (clauses, " IPG"), follow_(0),
nrGroundings_(nrGroundings) { }
CircuitNode** follow (void) { return &follow_; }
double weight (void) const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
};
class IncExcNode : public CircuitNode
{
public:
IncExcNode (const Clauses& clauses, string explanation)
: CircuitNode (clauses, explanation), plus1Branch_(0),
plus2Branch_(0), minusBranch_(0) { }
CircuitNode** plus1Branch (void) { return &plus1Branch_; }
CircuitNode** plus2Branch (void) { return &plus2Branch_; }
CircuitNode** minusBranch (void) { return &minusBranch_; }
double weight (void) const;
private:
CircuitNode* plus1Branch_;
CircuitNode* plus2Branch_;
CircuitNode* minusBranch_;
};
class LeafNode : public CircuitNode
{
public:
LeafNode (const Clause& clause, const LiftedWCNF& lwcnf)
: CircuitNode (Clauses() = {clause}), lwcnf_(lwcnf) { }
double weight (void) const;
private:
const LiftedWCNF& lwcnf_;
};
class SmoothNode : public CircuitNode
{
public:
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
: CircuitNode (clauses), lwcnf_(lwcnf) { }
double weight (void) const;
private:
const LiftedWCNF& lwcnf_;
};
class TrueNode : public CircuitNode
{
public:
TrueNode (void) : CircuitNode ({}) { }
double weight (void) const;
};
class CompilationFailedNode : public CircuitNode
{
public:
CompilationFailedNode (const Clauses& clauses)
: CircuitNode (clauses) { }
double weight (void) const;
};
class LiftedCircuit
{
public:
LiftedCircuit (const LiftedWCNF* lwcnf);
double getWeightedModelCount (void) const;
void exportToGraphViz (const char*);
private:
void compile (CircuitNode** follow, Clauses& clauses);
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
LogVars& rootLogVars);
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
bool tryGrounding (CircuitNode** follow, Clauses& clauses);
void shatterCountedLogVars (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
bool independentClause (Clause& clause, Clauses& otherClauses) const;
bool independentLiteral (const Literal& lit,
const Literals& otherLits) const;
LitLvTypesSet smoothCircuit (CircuitNode* node);
void createSmoothNode (const LitLvTypesSet& lids,
CircuitNode** prev);
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
bool containsTypes (const LogVarTypes& typesA,
const LogVarTypes& typesB) const;
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
void exportToGraphViz (CircuitNode* node, ofstream&);
void printClauses (const CircuitNode* node, ofstream&,
string extraOptions = "");
string escapeNode (const CircuitNode* node) const;
CircuitNode* root_;
const LiftedWCNF* lwcnf_;
};
#endif // HORUS_LIFTEDCIRCUIT_H

View File

@ -0,0 +1,75 @@
#include "LiftedKc.h"
#include "LiftedWCNF.h"
#include "LiftedCircuit.h"
#include "LiftedOperations.h"
#include "Indexer.h"
LiftedKc::~LiftedKc (void)
{
delete lwcnf_;
delete circuit_;
}
Params
LiftedKc::solveQuery (const Grounds& query)
{
pfList_ = parfactorList;
LiftedOperations::shatterAgainstQuery (pfList_, query);
LiftedOperations::runWeakBayesBall (pfList_, query);
lwcnf_ = new LiftedWCNF (pfList_);
circuit_ = new LiftedCircuit (lwcnf_);
vector<PrvGroup> groups;
Ranges ranges;
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
size_t idx = (*it)->indexOfGround (query[i]);
if (idx != (*it)->nrArguments()) {
groups.push_back ((*it)->argument (idx).group());
ranges.push_back ((*it)->range (idx));
break;
}
++ it;
}
}
assert (groups.size() == query.size());
Params params;
Indexer indexer (ranges);
while (indexer.valid()) {
for (size_t i = 0; i < groups.size(); i++) {
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
for (size_t j = 0; j < litIds.size(); j++) {
if (indexer[i] == j) {
lwcnf_->addWeight (litIds[j], LogAware::one(),
LogAware::one());
} else {
lwcnf_->addWeight (litIds[j], LogAware::zero(),
LogAware::one());
}
}
}
params.push_back (circuit_->getWeightedModelCount());
++ indexer;
}
LogAware::normalize (params);
if (Globals::logDomain) {
Util::exp (params);
}
return params;
}
void
LiftedKc::printSolverFlags (void) const
{
stringstream ss;
ss << "lifted kc [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}

View File

@ -0,0 +1,30 @@
#ifndef HORUS_LIFTEDKC_H
#define HORUS_LIFTEDKC_H
#include "LiftedSolver.h"
#include "ParfactorList.h"
class LiftedWCNF;
class LiftedCircuit;
class LiftedKc : public LiftedSolver
{
public:
LiftedKc (const ParfactorList& pfList)
: LiftedSolver(pfList) { }
~LiftedKc (void);
Params solveQuery (const Grounds&);
void printSolverFlags (void) const;
private:
LiftedWCNF* lwcnf_;
LiftedCircuit* circuit_;
ParfactorList pfList_;
};
#endif // HORUS_LIFTEDKC_H

View File

@ -0,0 +1,271 @@
#include "LiftedOperations.h"
void
LiftedOperations::shatterAgainstQuery (
ParfactorList& pfList,
const Grounds& query)
{
for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
bool found = false;
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
if ((*it)->containsGround (query[i])) {
found = true;
std::pair<ConstraintTree*, ConstraintTree*> split;
LogVars queryLvs (
(*it)->constr()->logVars().begin(),
(*it)->constr()->logVars().begin() + query[i].arity());
split = (*it)->constr()->split (query[i].args());
ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second;
newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) {
newPfs.push_back (new Parfactor (*it, exclCt));
} else {
delete exclCt;
}
it = pfList.removeAndDelete (it);
} else {
++ it;
}
}
if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2) {
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (size_t i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
void
LiftedOperations::runWeakBayesBall (
ParfactorList& pfList,
const Grounds& query)
{
queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
PrvGroup group = (*it)->findGroup (query[i]);
if (group != numeric_limits<PrvGroup>::max()) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
PrvGroup group = todo.front();
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList.begin();
bool foundNotRequired = false;
while (it != pfList.end()) {
if (Util::contains (requiredPfs, *it) == false) {
if (Globals::verbosity > 2) {
if (foundNotRequired == false) {
Util::printHeader ("PARFACTORS TO DISCARD");
foundNotRequired = true;
}
(*it)->print();
}
it = pfList.removeAndDelete (it);
} else {
++ it;
}
}
}
void
LiftedOperations::absorveEvidence (
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
{
for (size_t i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
Parfactor* pf = *it;
it = pfList.remove (it);
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
if (absorvedPfs.empty() == false) {
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
// just remove pf;
} else {
Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (size_t i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << obsFormulas[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
Parfactors
LiftedOperations::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
Parfactors normPfs;
if (set.empty()) {
normPfs.push_back (new Parfactor (*g));
} else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (size_t i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
}
return normPfs;
}
Parfactor
LiftedOperations::calcGroundMultiplication (Parfactor pf)
{
LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons();
Parfactors newPfs = {new Parfactor (pf)};
for (size_t i = 0; i < lvs.size(); i++) {
Parfactors pfs = newPfs;
newPfs.clear();
for (size_t j = 0; j < pfs.size(); j++) {
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
if (countedLv) {
pfs[j]->fullExpand (lvs[i]);
newPfs.push_back (pfs[j]);
} else {
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
for (size_t k = 0; k < cts.size(); k++) {
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
}
delete pfs[j];
}
}
}
ParfactorList pfList (newPfs);
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
}
return Parfactor (*groundShatteredPfs[0]);
}
Parfactors
LiftedOperations::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{
Parfactors absorvedPfs;
const ProbFormulas& formulas = g->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (obsFormula.functor() == formulas[i].functor() &&
obsFormula.arity() == formulas[i].arity()) {
if (obsFormula.isAtom()) {
if (formulas.size() > 1) {
g->absorveEvidence (formulas[i], obsFormula.evidence());
} else {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
}
g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res;
res = g->constr()->split (
formulas[i].logVars(),
&(obsFormula.constr()),
obsFormula.constr().logVars());
ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second;
if (commCt->empty() == false) {
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactor tempPf (g, commCt);
Parfactors countNormPfs = LiftedOperations::countNormalize (
&tempPf, excl);
for (size_t j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence());
absorvedPfs.push_back (countNormPfs[j]);
}
} else {
delete commCt;
}
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else {
delete commCt;
delete exclCt;
}
}
}
return absorvedPfs;
}

View File

@ -0,0 +1,26 @@
#ifndef HORUS_LIFTEDOPERATIONS_H
#define HORUS_LIFTEDOPERATIONS_H
#include "ParfactorList.h"
class LiftedOperations
{
public:
static void shatterAgainstQuery (
ParfactorList& pfList, const Grounds& query);
static void runWeakBayesBall (
ParfactorList& pfList, const Grounds&);
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
static Parfactor calcGroundMultiplication (Parfactor pf);
private:
static Parfactors absorve (ObservedFormula&, Parfactor*);
};
#endif // HORUS_LIFTEDOPERATIONS_H

View File

@ -0,0 +1,27 @@
#ifndef HORUS_LIFTEDSOLVER_H
#define HORUS_LIFTEDSOLVER_H
#include "ParfactorList.h"
#include "Horus.h"
using namespace std;
class LiftedSolver
{
public:
LiftedSolver (const ParfactorList& pfList)
: parfactorList(pfList) { }
virtual ~LiftedSolver() { } // ensure that subclass destructor is called
virtual Params solveQuery (const Grounds& query) = 0;
virtual void printSolverFlags (void) const = 0;
protected:
const ParfactorList& parfactorList;
};
#endif // HORUS_LIFTEDSOLVER_H

View File

@ -2,6 +2,7 @@
#include <set>
#include "LiftedVe.h"
#include "LiftedOperations.h"
#include "Histogram.h"
#include "Util.h"
@ -221,7 +222,7 @@ SumOutOperator::apply (void)
product->sumOutIndex (fIdx);
pfList_.addShattered (product);
} else {
Parfactors pfs = LiftedVe::countNormalize (product, excl);
Parfactors pfs = LiftedOperations::countNormalize (product, excl);
for (size_t i = 0; i < pfs.size(); i++) {
pfs[i]->sumOutIndex (fIdx);
pfList_.add (pfs[i]);
@ -375,7 +376,7 @@ CountingOperator::apply (void)
} else {
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
Parfactors pfs = LiftedVe::countNormalize (pf, X_);
Parfactors pfs = LiftedOperations::countNormalize (pf, X_);
for (size_t i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
@ -419,7 +420,7 @@ CountingOperator::toString (void)
ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_);
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (size_t i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getLabel() << endl;
@ -508,8 +509,6 @@ GroundOperator::getLogCost (void)
void
GroundOperator::apply (void)
{
// TODO if we update the correct groups
// we can skip shattering
ParfactorList::iterator pfIter;
pfIter = getParfactorsWithGroup (pfList_, group_).front();
Parfactor* pf = *pfIter;
@ -632,6 +631,7 @@ Params
LiftedVe::solveQuery (const Grounds& query)
{
assert (query.empty() == false);
pfList_ = parfactorList;
runSolver (query);
(*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params();
@ -647,7 +647,7 @@ void
LiftedVe::printSolverFlags (void) const
{
stringstream ss;
ss << "fove [" ;
ss << "lve [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
@ -655,103 +655,12 @@ LiftedVe::printSolverFlags (void) const
void
LiftedVe::absorveEvidence (
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
{
for (size_t i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
Parfactor* pf = *it;
it = pfList.remove (it);
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
if (absorvedPfs.empty() == false) {
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
// just remove pf;
} else {
Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (size_t i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << obsFormulas[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
Parfactors
LiftedVe::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
Parfactors normPfs;
if (set.empty()) {
normPfs.push_back (new Parfactor (*g));
} else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (size_t i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
}
return normPfs;
}
Parfactor
LiftedVe::calcGroundMultiplication (Parfactor pf)
{
LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons();
Parfactors newPfs = {new Parfactor (pf)};
for (size_t i = 0; i < lvs.size(); i++) {
Parfactors pfs = newPfs;
newPfs.clear();
for (size_t j = 0; j < pfs.size(); j++) {
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
if (countedLv) {
pfs[j]->fullExpand (lvs[i]);
newPfs.push_back (pfs[j]);
} else {
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
for (size_t k = 0; k < cts.size(); k++) {
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
}
delete pfs[j];
}
}
}
ParfactorList pfList (newPfs);
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
}
return Parfactor (*groundShatteredPfs[0]);
}
void
LiftedVe::runSolver (const Grounds& query)
{
largestCost_ = std::log (0);
shatterAgainstQuery (query);
runWeakBayesBall (query);
LiftedOperations::shatterAgainstQuery (pfList_, query);
LiftedOperations::runWeakBayesBall (pfList_, query);
while (true) {
if (Globals::verbosity > 2) {
Util::printDashedLine();
@ -817,177 +726,3 @@ LiftedVe::getBestOperation (const Grounds& query)
return bestOp;
}
void
LiftedVe::runWeakBayesBall (const Grounds& query)
{
queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
PrvGroup group = (*it)->findGroup (query[i]);
if (group != numeric_limits<PrvGroup>::max()) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
PrvGroup group = todo.front();
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList_.begin();
bool foundNotRequired = false;
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false) {
if (Globals::verbosity > 2) {
if (foundNotRequired == false) {
Util::printHeader ("PARFACTORS TO DISCARD");
foundNotRequired = true;
}
(*it)->print();
}
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
}
void
LiftedVe::shatterAgainstQuery (const Grounds& query)
{
for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
bool found = false;
Parfactors newPfs;
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (query[i])) {
found = true;
std::pair<ConstraintTree*, ConstraintTree*> split;
LogVars queryLvs (
(*it)->constr()->logVars().begin(),
(*it)->constr()->logVars().begin() + query[i].arity());
split = (*it)->constr()->split (query[i].args());
ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second;
newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) {
newPfs.push_back (new Parfactor (*it, exclCt));
} else {
delete exclCt;
}
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList_.add (newPfs);
}
if (Globals::verbosity > 2) {
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (size_t i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList_.print();
}
}
Parfactors
LiftedVe::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{
Parfactors absorvedPfs;
const ProbFormulas& formulas = g->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (obsFormula.functor() == formulas[i].functor() &&
obsFormula.arity() == formulas[i].arity()) {
if (obsFormula.isAtom()) {
if (formulas.size() > 1) {
g->absorveEvidence (formulas[i], obsFormula.evidence());
} else {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
}
g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res;
res = g->constr()->split (
formulas[i].logVars(),
&(obsFormula.constr()),
obsFormula.constr().logVars());
ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second;
if (commCt->empty() == false) {
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactor tempPf (g, commCt);
Parfactors countNormPfs = countNormalize (&tempPf, excl);
for (size_t j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence());
absorvedPfs.push_back (countNormPfs[j]);
}
} else {
delete commCt;
}
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else {
delete commCt;
delete exclCt;
}
}
}
return absorvedPfs;
}

View File

@ -1,13 +1,15 @@
#ifndef HORUS_LIFTEDVE_H
#define HORUS_LIFTEDVE_H
#include "LiftedSolver.h"
#include "ParfactorList.h"
class LiftedOperator
{
public:
virtual ~LiftedOperator (void) { }
virtual double getLogCost (void) = 0;
virtual void apply (void) = 0;
@ -43,9 +45,9 @@ class ProductOperator : public LiftedOperator
private:
static bool validOp (Parfactor*, Parfactor*);
ParfactorList::iterator g1_;
ParfactorList::iterator g2_;
ParfactorList& pfList_;
ParfactorList::iterator g1_;
ParfactorList::iterator g2_;
ParfactorList& pfList_;
};
@ -123,43 +125,30 @@ class GroundOperator : public LiftedOperator
private:
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
PrvGroup group_;
unsigned lvIndex_;
ParfactorList& pfList_;
PrvGroup group_;
unsigned lvIndex_;
ParfactorList& pfList_;
};
class LiftedVe
class LiftedVe : public LiftedSolver
{
public:
LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { }
LiftedVe (const ParfactorList& pfList)
: LiftedSolver(pfList) { }
Params solveQuery (const Grounds&);
void printSolverFlags (void) const;
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
static Parfactor calcGroundMultiplication (Parfactor pf);
private:
void runSolver (const Grounds&);
LiftedOperator* getBestOperation (const Grounds&);
void runWeakBayesBall (const Grounds&);
void shatterAgainstQuery (const Grounds&);
static Parfactors absorve (ObservedFormula&, Parfactor*);
ParfactorList pfList_;
double largestCost_;
ParfactorList pfList_;
double largestCost_;
};
#endif // HORUS_LIFTEDVE_H

View File

@ -0,0 +1,610 @@
#include "LiftedWCNF.h"
#include "ConstraintTree.h"
#include "Indexer.h"
bool
Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const
{
if (logVars_.size() == 0) {
return true;
}
LogVarSet lvs (logVars_);
lvs -= ipgLogVars;
return constr.singletons().contains (lvs);
}
size_t
Literal::indexOfLogVar (LogVar X) const
{
return Util::indexOf (logVars_, X);
}
string
Literal::toString (
LogVarSet ipgLogVars,
LogVarSet posCountedLvs,
LogVarSet negCountedLvs) const
{
stringstream ss;
negated_ ? ss << "¬" : ss << "" ;
// if (negated_ == false) {
// posWeight_ < 0.0 ? ss << "λ" : ss << "Θ" ;
// } else {
// negWeight_ < 0.0 ? ss << "λ" : ss << "Θ" ;
// }
ss << "λ" ;
ss << lid_ ;
if (logVars_.empty() == false) {
ss << "(" ;
for (size_t i = 0; i < logVars_.size(); i++) {
if (i != 0) ss << ",";
if (posCountedLvs.contains (logVars_[i])) {
ss << "+" << logVars_[i];
} else if (negCountedLvs.contains (logVars_[i])) {
ss << "-" << logVars_[i];
} else if (ipgLogVars.contains (logVars_[i])) {
LogVar X = logVars_[i];
const string labels[] = {
"a", "b", "c", "d", "e", "f",
"g", "h", "i", "j", "k", "m" };
(X >= 12) ? ss << "x_" << X : ss << labels[X];
} else {
ss << logVars_[i];
}
}
ss << ")" ;
}
return ss.str();
}
std::ostream& operator<< (ostream &os, const Literal& lit)
{
os << lit.toString();
return os;
}
void
Clause::addLiteralComplemented (const Literal& lit)
{
assert (constr_.logVarSet().contains (lit.logVars()));
literals_.push_back (lit);
literals_.back().complement();
}
bool
Clause::containsLiteral (LiteralId lid) const
{
for (size_t i = 0; i < literals_.size(); i++) {
if (literals_[i].lid() == lid) {
return true;
}
}
return false;
}
bool
Clause::containsPositiveLiteral (
LiteralId lid,
const LogVarTypes& types) const
{
for (size_t i = 0; i < literals_.size(); i++) {
if (literals_[i].lid() == lid
&& literals_[i].isPositive()
&& logVarTypes (i) == types) {
return true;
}
}
return false;
}
bool
Clause::containsNegativeLiteral (
LiteralId lid,
const LogVarTypes& types) const
{
for (size_t i = 0; i < literals_.size(); i++) {
if (literals_[i].lid() == lid
&& literals_[i].isNegative()
&& logVarTypes (i) == types) {
return true;
}
}
return false;
}
void
Clause::removeLiterals (LiteralId lid)
{
size_t i = 0;
while (i != literals_.size()) {
if (literals_[i].lid() == lid) {
removeLiteral (i);
} else {
i ++;
}
}
}
void
Clause::removePositiveLiterals (
LiteralId lid,
const LogVarTypes& types)
{
size_t i = 0;
while (i != literals_.size()) {
if (literals_[i].lid() == lid
&& literals_[i].isPositive()
&& logVarTypes (i) == types) {
removeLiteral (i);
} else {
i ++;
}
}
}
void
Clause::removeNegativeLiterals (
LiteralId lid,
const LogVarTypes& types)
{
size_t i = 0;
while (i != literals_.size()) {
if (literals_[i].lid() == lid
&& literals_[i].isNegative()
&& logVarTypes (i) == types) {
removeLiteral (i);
} else {
i ++;
}
}
}
bool
Clause::isCountedLogVar (LogVar X) const
{
assert (constr_.logVarSet().contains (X));
return posCountedLvs_.contains (X)
|| negCountedLvs_.contains (X);
}
bool
Clause::isPositiveCountedLogVar (LogVar X) const
{
assert (constr_.logVarSet().contains (X));
return posCountedLvs_.contains (X);
}
bool
Clause::isNegativeCountedLogVar (LogVar X) const
{
assert (constr_.logVarSet().contains (X));
return negCountedLvs_.contains (X);
}
bool
Clause::isIpgLogVar (LogVar X) const
{
assert (constr_.logVarSet().contains (X));
return ipgLvs_.contains (X);
}
TinySet<LiteralId>
Clause::lidSet (void) const
{
TinySet<LiteralId> lidSet;
for (size_t i = 0; i < literals_.size(); i++) {
lidSet.insert (literals_[i].lid());
}
return lidSet;
}
LogVarSet
Clause::ipgCandidates (void) const
{
LogVarSet candidates;
LogVarSet allLvs = constr_.logVarSet();
allLvs -= ipgLvs_;
allLvs -= posCountedLvs_;
allLvs -= negCountedLvs_;
for (size_t i = 0; i < allLvs.size(); i++) {
bool valid = true;
for (size_t j = 0; j < literals_.size(); j++) {
if (Util::contains (literals_[j].logVars(), allLvs[i]) == false) {
valid = false;
break;
}
}
if (valid) {
candidates.insert (allLvs[i]);
}
}
return candidates;
}
LogVarTypes
Clause::logVarTypes (size_t litIdx) const
{
LogVarTypes types;
const LogVars lvs = literals_[litIdx].logVars();
for (size_t i = 0; i < lvs.size(); i++) {
if (posCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::POS_LV);
} else if (negCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::NEG_LV);
} else {
types.push_back (LogVarType::FULL_LV);
}
}
return types;
}
void
Clause::removeLiteral (size_t litIdx)
{
LogVarSet lvsToRemove = literals_[litIdx].logVarSet()
- getLogVarSetExcluding (litIdx);
ipgLvs_ -= lvsToRemove;
posCountedLvs_ -= lvsToRemove;
negCountedLvs_ -= lvsToRemove;
constr_.remove (lvsToRemove);
literals_.erase (literals_.begin() + litIdx);
}
bool
Clause::independentClauses (Clause& c1, Clause& c2)
{
const Literals& lits1 = c1.literals();
const Literals& lits2 = c2.literals();
for (size_t i = 0; i < lits1.size(); i++) {
for (size_t j = 0; j < lits2.size(); j++) {
if (lits1[i].lid() == lits2[j].lid()
&& c1.logVarTypes (i) == c2.logVarTypes (j)) {
return false;
}
}
}
return true;
}
void
Clause::printClauses (const Clauses& clauses)
{
for (size_t i = 0; i < clauses.size(); i++) {
cout << clauses[i] << endl;
}
}
std::ostream& operator<< (ostream &os, const Clause& clause)
{
for (unsigned i = 0; i < clause.literals_.size(); i++) {
if (i != 0) os << " v " ;
os << clause.literals_[i].toString (clause.ipgLvs_,
clause.posCountedLvs_, clause.negCountedLvs_);
}
if (clause.constr_.empty() == false) {
ConstraintTree copy (clause.constr_);
copy.moveToTop (copy.logVarSet().elements());
os << " | " << copy.tupleSet();
}
return os;
}
LogVarSet
Clause::getLogVarSetExcluding (size_t idx) const
{
LogVarSet lvs;
for (size_t i = 0; i < literals_.size(); i++) {
if (i != idx) {
lvs |= literals_[i].logVars();
}
}
return lvs;
}
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
: freeLiteralId_(0), pfList_(pfList)
{
addIndicatorClauses (pfList);
addParameterClauses (pfList);
/*
// INCLUSION-EXCLUSION TEST
vector<vector<string>> names = {
// {"a1","b1"},{"a2","b2"},{"a1","b3"}
{"b1","a1"},{"b2","a2"},{"b3","a1"}
};
Clause c1 (names);
c1.addLiteral (Literal (0, LogVars() = {0}));
c1.addLiteral (Literal (1, LogVars() = {1}));
clauses_.push_back(c1);
freeLiteralId_ ++ ;
freeLiteralId_ ++ ;
*/
/*
// ATOM-COUNTING TEST
vector<vector<string>> names = {
{"p1","p1"},{"p1","p2"},{"p1","p3"},
{"p2","p1"},{"p2","p2"},{"p2","p3"},
{"p3","p1"},{"p3","p2"},{"p3","p3"}
};
Clause c1 (names);
c1.addLiteral (Literal (0, LogVars() = {0}));
c1.addLiteralComplemented (Literal (1, {0,1}));
clauses_.push_back(c1);
Clause c2 (names);
c2.addLiteral (Literal (0, LogVars()={0}));
c2.addLiteralComplemented (Literal (1, {1,0}));
clauses_.push_back(c2);
addWeight (0, LogAware::log(3.0), LogAware::log(4.0));
addWeight (1, LogAware::log(2.0), LogAware::log(5.0));
freeLiteralId_ = 2;
*/
cout << "FORMULA INDICATORS:" << endl;
printFormulaIndicators();
cout << endl;
cout << "WEIGHTS:" << endl;
printWeights();
cout << endl;
cout << "CLAUSES:" << endl;
printClauses();
cout << endl;
}
LiftedWCNF::~LiftedWCNF (void)
{
}
void
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
{
weights_[lid] = make_pair (posW, negW);
}
double
LiftedWCNF::posWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
return it != weights_.end() ? it->second.first : LogAware::one();
}
double
LiftedWCNF::negWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
return it != weights_.end() ? it->second.second : LogAware::one();
}
vector<LiteralId>
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
{
assert (Util::contains (map_, prvGroup));
return map_[prvGroup];
}
Clause
LiftedWCNF::createClause (LiteralId lid) const
{
for (size_t i = 0; i < clauses_.size(); i++) {
const Literals& literals = clauses_[i].literals();
for (size_t j = 0; j < literals.size(); j++) {
if (literals[j].lid() == lid) {
ConstraintTree ct = clauses_[i].constr();
ct.project (literals[j].logVars());
Clause clause (ct);
clause.addLiteral (literals[j]);
return clause;
}
}
}
abort(); // we should not reach this point
return Clause (ConstraintTree({}));
}
LiteralId
LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
{
assert (Util::contains (map_, prvGroup));
return map_[prvGroup][range];
}
void
LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
{
ParfactorList::const_iterator it = pfList.begin();
set<PrvGroup> allGroups;
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (Util::contains (allGroups, formulas[i].group()) == false) {
allGroups.insert (formulas[i].group());
ConstraintTree tempConstr = *(*it)->constr();
tempConstr.project (formulas[i].logVars());
Clause clause (tempConstr);
vector<LiteralId> lids;
for (size_t j = 0; j < formulas[i].range(); j++) {
clause.addLiteral (Literal (freeLiteralId_, formulas[i].logVars()));
lids.push_back (freeLiteralId_);
freeLiteralId_ ++;
}
clauses_.push_back (clause);
for (size_t j = 0; j < formulas[i].range() - 1; j++) {
for (size_t k = j + 1; k < formulas[i].range(); k++) {
ConstraintTree tempConstr2 = *(*it)->constr();
tempConstr2.project (formulas[i].logVars());
Clause clause2 (tempConstr2);
clause2.addLiteralComplemented (Literal (clause.literals()[j]));
clause2.addLiteralComplemented (Literal (clause.literals()[k]));
clauses_.push_back (clause2);
}
}
map_[formulas[i].group()] = lids;
}
}
++ it;
}
}
void
LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
{
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
Indexer indexer ((*it)->ranges());
vector<PrvGroup> groups = (*it)->getAllGroups();
while (indexer.valid()) {
LiteralId paramVarLid = freeLiteralId_;
// λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un
//
// ¬λu1 ... ¬λun v θxi|u1,...,un -> clause1
// ¬θxi|u1,...,un v λu1 -> tempClause
// ¬θxi|u1,...,un v λu2 -> tempClause
double posWeight = (**it)[indexer];
addWeight (paramVarLid, posWeight, 1.0);
Clause clause1 (*(*it)->constr());
for (unsigned i = 0; i < groups.size(); i++) {
LiteralId lid = getLiteralId (groups[i], indexer[i]);
clause1.addLiteralComplemented (
Literal (lid, (*it)->argument(i).logVars()));
ConstraintTree ct = *(*it)->constr();
Clause tempClause (ct);
tempClause.addLiteralComplemented (Literal (
paramVarLid, (*it)->constr()->logVars()));
tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars()));
clauses_.push_back (tempClause);
}
clause1.addLiteral (Literal (paramVarLid, (*it)->constr()->logVars()));
clauses_.push_back (clause1);
freeLiteralId_ ++;
++ indexer;
}
++ it;
}
}
void
LiftedWCNF::printFormulaIndicators (void) const
{
set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& formulas = (*it)->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (Util::contains (allGroups, formulas[i].group()) == false) {
allGroups.insert (formulas[i].group());
cout << formulas[i] << " | " ;
ConstraintTree tempCt = *(*it)->constr();
tempCt.project (formulas[i].logVars());
cout << tempCt.tupleSet();
cout << " indicators => " ;
vector<LiteralId> indicators =
(map_.find (formulas[i].group()))->second;
cout << indicators << endl;
}
}
++ it;
}
}
void
LiftedWCNF::printWeights (void) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.begin();
while (it != weights_.end()) {
cout << "λ" << it->first << " weights: " ;
cout << it->second.first << " " << it->second.second;
cout << endl;
++ it;
}
}
void
LiftedWCNF::printClauses (void) const
{
for (unsigned i = 0; i < clauses_.size(); i++) {
cout << clauses_[i] << endl;
}
}

View File

@ -0,0 +1,226 @@
#ifndef HORUS_LIFTEDWCNF_H
#define HORUS_LIFTEDWCNF_H
#include "ParfactorList.h"
using namespace std;
typedef long LiteralId;
class ConstraintTree;
enum LogVarType
{
FULL_LV,
POS_LV,
NEG_LV
};
typedef vector<LogVarType> LogVarTypes;
class Literal
{
public:
Literal (LiteralId lid, const LogVars& lvs) :
lid_(lid), logVars_(lvs), negated_(false) { }
Literal (const Literal& lit, bool negated) :
lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
LiteralId lid (void) const { return lid_; }
LogVars logVars (void) const { return logVars_; }
size_t nrLogVars (void) const { return logVars_.size(); }
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
void complement (void) { negated_ = !negated_; }
bool isPositive (void) const { return negated_ == false; }
bool isNegative (void) const { return negated_; }
bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const;
size_t indexOfLogVar (LogVar X) const;
string toString (LogVarSet ipgLogVars = LogVarSet(),
LogVarSet posCountedLvs = LogVarSet(),
LogVarSet negCountedLvs = LogVarSet()) const;
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
private:
LiteralId lid_;
LogVars logVars_;
bool negated_;
};
typedef vector<Literal> Literals;
class Clause
{
public:
Clause (const ConstraintTree& ct = ConstraintTree({})) : constr_(ct) { }
Clause (vector<vector<string>> names) : constr_(ConstraintTree (names)) { }
void addLiteral (const Literal& l) { literals_.push_back (l); }
const Literals& literals (void) const { return literals_; }
const ConstraintTree& constr (void) const { return constr_; }
ConstraintTree constr (void) { return constr_; }
bool isUnit (void) const { return literals_.size() == 1; }
LogVarSet ipgLogVars (void) const { return ipgLvs_; }
void addIpgLogVar (LogVar X) { ipgLvs_.insert (X); }
void addPosCountedLogVar (LogVar X) { posCountedLvs_.insert (X); }
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; }
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; }
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); }
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); }
void addLiteralComplemented (const Literal& lit);
bool containsLiteral (LiteralId lid) const;
bool containsPositiveLiteral (LiteralId lid, const LogVarTypes&) const;
bool containsNegativeLiteral (LiteralId lid, const LogVarTypes&) const;
void removeLiterals (LiteralId lid);
void removePositiveLiterals (LiteralId lid, const LogVarTypes&);
void removeNegativeLiterals (LiteralId lid, const LogVarTypes&);
bool isCountedLogVar (LogVar X) const;
bool isPositiveCountedLogVar (LogVar X) const;
bool isNegativeCountedLogVar (LogVar X) const;
bool isIpgLogVar (LogVar X) const;
TinySet<LiteralId> lidSet (void) const;
LogVarSet ipgCandidates (void) const;
LogVarTypes logVarTypes (size_t litIdx) const;
void removeLiteral (size_t litIdx);
static bool independentClauses (Clause& c1, Clause& c2);
static void printClauses (const vector<Clause>& clauses);
friend std::ostream& operator<< (ostream &os, const Clause& clause);
private:
LogVarSet getLogVarSetExcluding (size_t idx) const;
Literals literals_;
LogVarSet ipgLvs_;
LogVarSet posCountedLvs_;
LogVarSet negCountedLvs_;
ConstraintTree constr_;
};
typedef vector<Clause> Clauses;
class LitLvTypes
{
public:
struct CompareLitLvTypes
{
bool operator() (
const LitLvTypes& types1,
const LitLvTypes& types2) const
{
if (types1.lid_ < types2.lid_) {
return true;
}
return types1.lvTypes_ < types2.lvTypes_;
}
};
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
lid_(lid), lvTypes_(lvTypes) { }
LiteralId lid (void) const { return lid_; }
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
void setAllFullLogVars (void) {
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
private:
LiteralId lid_;
LogVarTypes lvTypes_;
};
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
class LiftedWCNF
{
public:
LiftedWCNF (const ParfactorList& pfList);
~LiftedWCNF (void);
const Clauses& clauses (void) const { return clauses_; }
void addWeight (LiteralId lid, double posW, double negW);
double posWeight (LiteralId lid) const;
double negWeight (LiteralId lid) const;
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
Clause createClause (LiteralId lid) const;
void printFormulaIndicators (void) const;
void printWeights (void) const;
void printClauses (void) const;
private:
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
void addIndicatorClauses (const ParfactorList& pfList);
void addParameterClauses (const ParfactorList& pfList);
Clauses clauses_;
LiteralId freeLiteralId_;
const ParfactorList& pfList_;
unordered_map<PrvGroup, vector<LiteralId>> map_;
unordered_map<LiteralId, std::pair<double,double>> weights_;
};
#endif // HORUS_LIFTEDWCNF_H

View File

@ -23,10 +23,10 @@ CC=@CC@
CXX=@CXX@
# normal
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
# debug
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
#
@ -57,12 +57,17 @@ HEADERS = \
$(srcdir)/Horus.h \
$(srcdir)/Indexer.h \
$(srcdir)/LiftedBp.h \
$(srcdir)/LiftedCircuit.h \
$(srcdir)/LiftedKc.h \
$(srcdir)/LiftedOperations.h \
$(srcdir)/LiftedSolver.h \
$(srcdir)/LiftedUtils.h \
$(srcdir)/LiftedVe.h \
$(srcdir)/LiftedWCNF.h \
$(srcdir)/Parfactor.h \
$(srcdir)/ParfactorList.h \
$(srcdir)/ProbFormula.h \
$(srcdir)/Solver.h \
$(srcdir)/GroundSolver.h \
$(srcdir)/TinySet.h \
$(srcdir)/Util.h \
$(srcdir)/Var.h \
@ -82,16 +87,20 @@ CPP_SOURCES = \
$(srcdir)/HorusCli.cpp \
$(srcdir)/HorusYap.cpp \
$(srcdir)/LiftedBp.cpp \
$(srcdir)/LiftedCircuit.cpp \
$(srcdir)/LiftedKc.cpp \
$(srcdir)/LiftedOperations.cpp \
$(srcdir)/LiftedUtils.cpp \
$(srcdir)/LiftedVe.cpp \
$(srcdir)/LiftedWCNF.cpp \
$(srcdir)/Parfactor.cpp \
$(srcdir)/ParfactorList.cpp \
$(srcdir)/ProbFormula.cpp \
$(srcdir)/Solver.cpp \
$(srcdir)/GroundSolver.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/Var.cpp \
$(srcdir)/VarElim.cpp \
$(srcdir)/WeightedBp.cpp \
$(srcdir)/WeightedBp.cpp
OBJS = \
BayesBall.o \
@ -105,12 +114,16 @@ OBJS = \
Histogram.o \
HorusYap.o \
LiftedBp.o \
LiftedCircuit.o \
LiftedKc.o \
LiftedOperations.o \
LiftedUtils.o \
LiftedVe.o \
LiftedWCNF.o \
ProbFormula.o \
Parfactor.o \
ParfactorList.o \
Solver.o \
GroundSolver.o \
Util.o \
Var.o \
VarElim.o \
@ -125,7 +138,7 @@ HCLI_OBJS = \
Factor.o \
FactorGraph.o \
HorusCli.o \
Solver.o \
GroundSolver.o \
Util.o \
Var.o \
VarElim.o \

View File

@ -9,8 +9,7 @@ ParfactorList::ParfactorList (const ParfactorList& pfList)
while (it != pfList.end()) {
addShattered (new Parfactor (**it));
++ it;
}
}
}
@ -126,6 +125,27 @@ ParfactorList::print (void) const
ParfactorList&
ParfactorList::operator= (const ParfactorList& pfList)
{
if (this != &pfList) {
ParfactorList::const_iterator it0 = pfList_.begin();
while (it0 != pfList_.end()) {
delete *it0;
++ it0;
}
pfList_.clear();
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
addShattered (new Parfactor (**it));
++ it;
}
}
return *this;
}
bool
ParfactorList::isShattered (const Parfactor* g) const
{

View File

@ -56,6 +56,8 @@ class ParfactorList
bool isAllShattered (void) const;
void print (void) const;
ParfactorList& operator= (const ParfactorList& pfList);
private:

View File

@ -1,3 +1,2 @@
- Find a way to decrease the time required to find an
elimination order for variable elimination
- Handle formulas like f(X,X)

View File

@ -153,6 +153,11 @@ class TinySet
{
return vec_[i];
}
T& operator[] (typename vector<T>::size_type i)
{
return vec_[i];
}
T front (void) const
{

View File

@ -13,9 +13,9 @@ bool logDomain = false;
unsigned verbosity = 0;
LiftedSolver liftedSolver = LiftedSolver::FOVE;
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
GroundSolver groundSolver = GroundSolver::VE;
GroundSolverType groundSolver = GroundSolverType::VE;
};
@ -210,10 +210,12 @@ setHorusFlag (string key, string value)
ss << value;
ss >> Globals::verbosity;
} else if (key == "lifted_solver") {
if ( value == "fove") {
Globals::liftedSolver = LiftedSolver::FOVE;
if ( value == "lve") {
Globals::liftedSolver = LiftedSolverType::LVE;
} else if (value == "lbp") {
Globals::liftedSolver = LiftedSolver::LBP;
Globals::liftedSolver = LiftedSolverType::LBP;
} else if (value == "lkc") {
Globals::liftedSolver = LiftedSolverType::LKC;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
@ -221,11 +223,11 @@ setHorusFlag (string key, string value)
}
} else if (key == "ground_solver") {
if ( value == "ve") {
Globals::groundSolver = GroundSolver::VE;
Globals::groundSolver = GroundSolverType::VE;
} else if (value == "bp") {
Globals::groundSolver = GroundSolver::BP;
Globals::groundSolver = GroundSolverType::BP;
} else if (value == "cbp") {
Globals::groundSolver = GroundSolver::CBP;
Globals::groundSolver = GroundSolverType::CBP;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;

View File

@ -3,7 +3,7 @@
#include "unordered_map"
#include "Solver.h"
#include "GroundSolver.h"
#include "FactorGraph.h"
#include "Horus.h"
@ -11,10 +11,10 @@
using namespace std;
class VarElim : public Solver
class VarElim : public GroundSolver
{
public:
VarElim (const FactorGraph& fg) : Solver (fg) { }
VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
~VarElim (void);