Merge branch 'master' of https://github.com/tacgomes/yap6.3
This commit is contained in:
commit
027632456a
@ -49,11 +49,12 @@ Inference Options
|
||||
PFL supports both ground and lifted inference. The inference algorithm
|
||||
can be chosen using the set_solver/1 predicate. The following algorithms
|
||||
are supported:
|
||||
- fove: lifted variable elimination with arbitrary constraints (GC-FOVE)
|
||||
- lve: generalized counting first-order variable elimination (GC-FOVE)
|
||||
- hve: (ground) variable elimination
|
||||
- lbp: lifted first-order belief propagation
|
||||
- cbp: counting belief propagation
|
||||
- bp: (ground) belief propagation
|
||||
- lkc: lifted first-order knowledge compilation
|
||||
|
||||
For example, if we want to use ground variable elimination to solve some
|
||||
query, we need to call first the following goal:
|
||||
|
@ -3,7 +3,7 @@
|
||||
source city.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="fove"
|
||||
SOLVER="lve"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
@ -32,5 +32,5 @@ function run_all_graphs
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "fove "
|
||||
run_all_graphs "lve "
|
||||
|
@ -3,7 +3,7 @@
|
||||
source cw.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="fove"
|
||||
SOLVER="lve"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
@ -26,6 +26,6 @@ function run_all_graphs
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "fove "
|
||||
run_all_graphs "lve "
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
cd workshop_attrs
|
||||
source hve_tests.sh
|
||||
source bp_tests.sh
|
||||
source fove_tests.sh
|
||||
source lve_tests.sh
|
||||
source lbp_tests.sh
|
||||
source cbp_tests.sh
|
||||
cd ..
|
||||
@ -11,7 +11,7 @@ cd ..
|
||||
cd comp_workshops
|
||||
source hve_tests.sh
|
||||
source bp_tests.sh
|
||||
source fove_tests.sh
|
||||
source lve_tests.sh
|
||||
source lbp_tests.sh
|
||||
source cbp_tests.sh
|
||||
cd ..
|
||||
@ -19,7 +19,7 @@ cd ..
|
||||
cd city
|
||||
source hve_tests.sh
|
||||
source bp_tests.sh
|
||||
source fove_tests.sh
|
||||
source lve_tests.sh
|
||||
source lbp_tests.sh
|
||||
source cbp_tests.sh
|
||||
cd ..
|
||||
@ -27,7 +27,7 @@ cd ..
|
||||
cd smokers
|
||||
source hve_tests.sh
|
||||
source bp_tests.sh
|
||||
source fove_tests.sh
|
||||
source lve_tests.sh
|
||||
source lbp_tests.sh
|
||||
source cbp_tests.sh
|
||||
cd ..
|
||||
|
@ -3,7 +3,7 @@
|
||||
source sm.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="fove"
|
||||
SOLVER="lve"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
@ -26,6 +26,6 @@ function run_all_graphs
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "fove "
|
||||
run_all_graphs "lve "
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
source sm.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="fove"
|
||||
SOLVER="lve"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
@ -30,6 +30,6 @@ function run_all_graphs
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "fove "
|
||||
run_all_graphs "lve "
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
source wa.sh
|
||||
source ../benchs.sh
|
||||
|
||||
SOLVER="fove"
|
||||
SOLVER="lve"
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
@ -32,6 +32,6 @@ function run_all_graphs
|
||||
}
|
||||
|
||||
prepare_new_run
|
||||
run_all_graphs "fove "
|
||||
run_all_graphs "lve "
|
||||
|
||||
|
@ -39,8 +39,9 @@ set_solver(ve) :- !, set_clpbn_flag(solver,ve).
|
||||
set_solver(bdd) :- !, set_clpbn_flag(solver,bdd).
|
||||
set_solver(jt) :- !, set_clpbn_flag(solver,jt).
|
||||
set_solver(gibbs) :- !, set_clpbn_flag(solver,gibbs).
|
||||
set_solver(fove) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, fove).
|
||||
set_solver(lve) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lve).
|
||||
set_solver(lbp) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lbp).
|
||||
set_solver(lkc) :- !, set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lkc).
|
||||
set_solver(hve) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, ve).
|
||||
set_solver(bp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, bp).
|
||||
set_solver(cbp) :- !, set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, cbp).
|
||||
|
@ -18,7 +18,7 @@ total_students(256).
|
||||
:- ensure_loaded(parschema).
|
||||
|
||||
:- yap_flag(unknown,error).
|
||||
%:- clpbn_horus:set_solver(fove).
|
||||
%:- clpbn_horus:set_solver(lve).
|
||||
%:- clpbn_horus:set_solver(hve).
|
||||
:- clpbn_horus:set_solver(bp).
|
||||
%:- clpbn_horus:set_solver(bdd).
|
||||
|
@ -1,6 +1,6 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_solver(fove).
|
||||
%:- set_solver(lve).
|
||||
%:- set_solver(hve).
|
||||
%:- set_solver(bp).
|
||||
%:- set_solver(cbp).
|
||||
|
@ -1,6 +1,6 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_solver(fove).
|
||||
%:- set_solver(lve).
|
||||
%:- set_solver(hve).
|
||||
%:- set_solver(bp).
|
||||
%:- set_solver(cbp).
|
||||
@ -29,7 +29,7 @@ bayes car_color(P)::[t,f], hair_color(P) ; car_color_table(P); [people(P,_)].
|
||||
|
||||
bayes height(P)::[t,f], gender(P) ; height_table(P) ; [people(P,_)].
|
||||
|
||||
bayes shoe_size(P):[t,f], height(P) ; shoe_size_table(P); [people(P,_)].
|
||||
bayes shoe_size(P)::[t,f], height(P) ; shoe_size_table(P); [people(P,_)].
|
||||
|
||||
bayes guilty(P)::[y,n] ; guilty_table(P) ; [people(P,_)].
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_solver(fove).
|
||||
%:- set_solver(lve).
|
||||
%:- set_solver(hve).
|
||||
%:- set_solver(bp).
|
||||
%:- set_solver(cbp).
|
||||
|
@ -1,7 +1,7 @@
|
||||
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- set_pfl_flag(solver,fove).
|
||||
:- set_pfl_flag(solver,lve).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,cbp).
|
||||
|
@ -1,6 +1,6 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
:- set_solver(fove).
|
||||
:- set_solver(lve).
|
||||
%:- set_solver(hve).
|
||||
%:- set_solver(bp).
|
||||
%:- set_solver(cbp).
|
||||
|
@ -1,6 +1,6 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_solver(fove).
|
||||
%:- set_solver(lve).
|
||||
%:- set_solver(hve).
|
||||
%:- set_solver(bp).
|
||||
%:- set_solver(cbp).
|
||||
|
@ -1,6 +1,6 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_solver(fove).
|
||||
%:- set_solver(lve).
|
||||
%:- set_solver(hve).
|
||||
%:- set_solver(bp).
|
||||
%:- set_solver(cbp).
|
||||
|
@ -1,6 +1,6 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_solver(fove).
|
||||
%:- set_solver(lve).
|
||||
%:- set_solver(hve).
|
||||
%:- set_solver(bp).
|
||||
%:- set_solver(cbp).
|
||||
|
@ -12,7 +12,7 @@
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg) : Solver (fg)
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
||||
{
|
||||
runned_ = false;
|
||||
}
|
||||
@ -377,7 +377,8 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
Params
|
||||
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
return Solver::getJointByConditioning (GroundSolver::BP, fg, jointVarIds);
|
||||
return GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::BP, fg, jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "GroundSolver.h"
|
||||
#include "Factor.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Util.h"
|
||||
@ -83,7 +83,7 @@ class SPNodeInfo
|
||||
};
|
||||
|
||||
|
||||
class BeliefProp : public Solver
|
||||
class BeliefProp : public GroundSolver
|
||||
{
|
||||
public:
|
||||
BeliefProp (const FactorGraph&);
|
||||
|
@ -112,6 +112,8 @@ CTNode::copySubtree (const CTNode* root1)
|
||||
const CTNode* n1 = stack.back().first;
|
||||
CTNode* n2 = stack.back().second;
|
||||
stack.pop_back();
|
||||
// cout << "n2 childs: " << n2->childs();
|
||||
// cout << "n1 childs: " << n1->childs();
|
||||
n2->childs().reserve (n1->nrChilds());
|
||||
stack.reserve (n1->nrChilds());
|
||||
for (CTChilds::const_iterator chIt = n1->childs().begin();
|
||||
@ -185,11 +187,31 @@ ConstraintTree::ConstraintTree (
|
||||
|
||||
|
||||
|
||||
ConstraintTree::ConstraintTree (vector<vector<string>> names)
|
||||
{
|
||||
assert (names.empty() == false);
|
||||
assert (names.front().empty() == false);
|
||||
unsigned nrLvs = names[0].size();
|
||||
for (size_t i = 0; i < nrLvs; i++) {
|
||||
logVars_.push_back (LogVar (i));
|
||||
}
|
||||
root_ = new CTNode (0, 0);
|
||||
logVarSet_ = LogVarSet (logVars_);
|
||||
for (size_t i = 0; i < names.size(); i++) {
|
||||
Tuple t;
|
||||
for (size_t j = 0; j < names[i].size(); j++) {
|
||||
assert (names[i].size() == nrLvs);
|
||||
t.push_back (LiftedUtils::getSymbol (names[i][j]));
|
||||
}
|
||||
addTuple (t);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
||||
{
|
||||
root_ = CTNode::copySubtree (ct.root_);
|
||||
logVars_ = ct.logVars_;
|
||||
logVarSet_ = ct.logVarSet_;
|
||||
*this = ct;
|
||||
}
|
||||
|
||||
|
||||
@ -367,6 +389,16 @@ ConstraintTree::project (const LogVarSet& X)
|
||||
|
||||
|
||||
|
||||
ConstraintTree
|
||||
ConstraintTree::projectedCopy (const LogVarSet& X)
|
||||
{
|
||||
ConstraintTree copy = *this;
|
||||
copy.project (X);
|
||||
return copy;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ConstraintTree::remove (const LogVarSet& X)
|
||||
{
|
||||
@ -865,6 +897,19 @@ ConstraintTree::copyLogVar (LogVar X_1, LogVar X_2)
|
||||
|
||||
|
||||
|
||||
ConstraintTree&
|
||||
ConstraintTree::operator= (const ConstraintTree& ct)
|
||||
{
|
||||
if (this != &ct) {
|
||||
root_ = CTNode::copySubtree (ct.root_);
|
||||
logVars_ = ct.logVars_;
|
||||
logVarSet_ = ct.logVarSet_;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ConstraintTree::countTuples (const CTNode* n) const
|
||||
{
|
||||
|
@ -108,6 +108,8 @@ class ConstraintTree
|
||||
ConstraintTree (const LogVars&);
|
||||
|
||||
ConstraintTree (const LogVars&, const Tuples&);
|
||||
|
||||
ConstraintTree (vector<vector<string>> names);
|
||||
|
||||
ConstraintTree (const ConstraintTree&);
|
||||
|
||||
@ -157,6 +159,8 @@ class ConstraintTree
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
void project (const LogVarSet&);
|
||||
|
||||
ConstraintTree projectedCopy (const LogVarSet&);
|
||||
|
||||
void remove (const LogVarSet&);
|
||||
|
||||
@ -197,6 +201,8 @@ class ConstraintTree
|
||||
ConstraintTrees ground (LogVar);
|
||||
|
||||
void copyLogVar (LogVar,LogVar);
|
||||
|
||||
ConstraintTree& operator= (const ConstraintTree& ct);
|
||||
|
||||
private:
|
||||
unsigned countTuples (const CTNode*) const;
|
||||
|
@ -6,7 +6,7 @@ bool CountingBp::checkForIdenticalFactors = true;
|
||||
|
||||
|
||||
CountingBp::CountingBp (const FactorGraph& fg)
|
||||
: Solver (fg), freeColor_(0)
|
||||
: GroundSolver (fg), freeColor_(0)
|
||||
{
|
||||
findIdenticalFactors();
|
||||
setInitialColors();
|
||||
@ -74,8 +74,8 @@ CountingBp::solveQuery (VarIds queryVids)
|
||||
cout << endl;
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
res = Solver::getJointByConditioning (
|
||||
GroundSolver::CBP, fg, queryVids);
|
||||
res = GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::CBP, fg, queryVids);
|
||||
} else {
|
||||
VarIds reprArgs;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "GroundSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
@ -102,7 +102,7 @@ class FacCluster
|
||||
};
|
||||
|
||||
|
||||
class CountingBp : public Solver
|
||||
class CountingBp : public GroundSolver
|
||||
{
|
||||
public:
|
||||
CountingBp (const FactorGraph& fg);
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include "Solver.h"
|
||||
#include "GroundSolver.h"
|
||||
#include "Util.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
@ -6,7 +6,7 @@
|
||||
|
||||
|
||||
void
|
||||
Solver::printAnswer (const VarIds& vids)
|
||||
GroundSolver::printAnswer (const VarIds& vids)
|
||||
{
|
||||
Vars unobservedVars;
|
||||
VarIds unobservedVids;
|
||||
@ -32,7 +32,7 @@ Solver::printAnswer (const VarIds& vids)
|
||||
|
||||
|
||||
void
|
||||
Solver::printAllPosterioris (void)
|
||||
GroundSolver::printAllPosterioris (void)
|
||||
{
|
||||
VarNodes vars = fg.varNodes();
|
||||
std::sort (vars.begin(), vars.end(), sortByVarId());
|
||||
@ -44,8 +44,8 @@ Solver::printAllPosterioris (void)
|
||||
|
||||
|
||||
Params
|
||||
Solver::getJointByConditioning (
|
||||
GroundSolver solverType,
|
||||
GroundSolver::getJointByConditioning (
|
||||
GroundSolverType solverType,
|
||||
FactorGraph fg,
|
||||
const VarIds& jointVarIds) const
|
||||
{
|
||||
@ -55,11 +55,11 @@ Solver::getJointByConditioning (
|
||||
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
Solver* solver = 0;
|
||||
GroundSolver* solver = 0;
|
||||
switch (solverType) {
|
||||
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
@ -80,9 +80,9 @@ Solver::getJointByConditioning (
|
||||
}
|
||||
delete solver;
|
||||
switch (solverType) {
|
||||
case GroundSolver::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolver::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolver::VE: solver = new VarElim (fg); break;
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_SOLVER_H
|
||||
#define HORUS_SOLVER_H
|
||||
#ifndef HORUS_GROUNDSOLVER_H
|
||||
#define HORUS_GROUNDSOLVER_H
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
@ -10,12 +10,12 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Solver
|
||||
class GroundSolver
|
||||
{
|
||||
public:
|
||||
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||
GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||
|
||||
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||
virtual ~GroundSolver() { } // ensure that subclass destructor is called
|
||||
|
||||
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||
|
||||
@ -25,12 +25,12 @@ class Solver
|
||||
|
||||
void printAllPosterioris (void);
|
||||
|
||||
Params getJointByConditioning (GroundSolver,
|
||||
Params getJointByConditioning (GroundSolverType,
|
||||
FactorGraph, const VarIds& jointVarIds) const;
|
||||
|
||||
protected:
|
||||
const FactorGraph& fg;
|
||||
};
|
||||
|
||||
#endif // HORUS_SOLVER_H
|
||||
#endif // HORUS_GROUNDSOLVER_H
|
||||
|
@ -28,14 +28,15 @@ typedef vector<unsigned> Ranges;
|
||||
typedef unsigned long long ullong;
|
||||
|
||||
|
||||
enum LiftedSolver
|
||||
enum LiftedSolverType
|
||||
{
|
||||
FOVE, // first order variable elimination
|
||||
LBP, // lifted belief propagation
|
||||
LVE, // generalized counting first-order variable elimination (GC-FOVE)
|
||||
LBP, // lifted first-order belief propagation
|
||||
LKC // lifted first-order knowledge compilation
|
||||
};
|
||||
|
||||
|
||||
enum GroundSolver
|
||||
enum GroundSolverType
|
||||
{
|
||||
VE, // variable elimination
|
||||
BP, // belief propagation
|
||||
@ -50,8 +51,8 @@ extern bool logDomain;
|
||||
// level of debug information
|
||||
extern unsigned verbosity;
|
||||
|
||||
extern LiftedSolver liftedSolver;
|
||||
extern GroundSolver groundSolver;
|
||||
extern LiftedSolverType liftedSolver;
|
||||
extern GroundSolverType groundSolver;
|
||||
|
||||
};
|
||||
|
||||
|
@ -160,15 +160,15 @@ readQueryAndEvidence (
|
||||
void
|
||||
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
{
|
||||
Solver* solver = 0;
|
||||
GroundSolver* solver = 0;
|
||||
switch (Globals::groundSolver) {
|
||||
case GroundSolver::VE:
|
||||
case GroundSolverType::VE:
|
||||
solver = new VarElim (fg);
|
||||
break;
|
||||
case GroundSolver::BP:
|
||||
case GroundSolverType::BP:
|
||||
solver = new BeliefProp (fg);
|
||||
break;
|
||||
case GroundSolver::CBP:
|
||||
case GroundSolverType::CBP:
|
||||
solver = new CountingBp (fg);
|
||||
break;
|
||||
default:
|
||||
|
@ -9,11 +9,13 @@
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "LiftedOperations.h"
|
||||
#include "LiftedVe.h"
|
||||
#include "VarElim.h"
|
||||
#include "LiftedBp.h"
|
||||
#include "CountingBp.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "LiftedKc.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesBall.h"
|
||||
|
||||
@ -22,25 +24,15 @@ using namespace std;
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
|
||||
Params readParameters (YAP_Term);
|
||||
|
||||
vector<unsigned> readUnsignedList (YAP_Term);
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
|
||||
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
vector<unsigned> readUnsignedList (YAP_Term list);
|
||||
|
||||
Params readParameters (YAP_Term);
|
||||
|
||||
vector<unsigned>
|
||||
readUnsignedList (YAP_Term list)
|
||||
{
|
||||
vector<unsigned> vec;
|
||||
while (list != YAP_TermNil()) {
|
||||
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
|
||||
list = YAP_TailOfTerm (list);
|
||||
}
|
||||
return vec;
|
||||
}
|
||||
YAP_Term fillAnswersPrologList (vector<Params>& results);
|
||||
|
||||
|
||||
|
||||
@ -76,138 +68,13 @@ createLiftedNetwork (void)
|
||||
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
|
||||
|
||||
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
|
||||
|
||||
YAP_Int p = (YAP_Int) (net);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor*
|
||||
readParfactor (YAP_Term pfTerm)
|
||||
{
|
||||
// read dist id
|
||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||
|
||||
// read the ranges
|
||||
Ranges ranges;
|
||||
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
|
||||
while (rangeList != YAP_TermNil()) {
|
||||
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||
ranges.push_back (range);
|
||||
rangeList = YAP_TailOfTerm (rangeList);
|
||||
}
|
||||
|
||||
// read parametric random vars
|
||||
ProbFormulas formulas;
|
||||
unsigned count = 0;
|
||||
unordered_map<YAP_Term, LogVar> lvMap;
|
||||
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||
while (pvList != YAP_TermNil()) {
|
||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||
} else {
|
||||
LogVars logVars;
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
if (it != lvMap.end()) {
|
||||
logVars.push_back (it->second);
|
||||
} else {
|
||||
unsigned newLv = lvMap.size();
|
||||
lvMap[ti] = newLv;
|
||||
logVars.push_back (newLv);
|
||||
}
|
||||
}
|
||||
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
||||
}
|
||||
count ++;
|
||||
pvList = YAP_TailOfTerm (pvList);
|
||||
}
|
||||
|
||||
// read the parameters
|
||||
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||
|
||||
// read the constraint
|
||||
Tuples tuples;
|
||||
if (lvMap.size() >= 1) {
|
||||
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
|
||||
while (tupleList != YAP_TermNil()) {
|
||||
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||
assert (YAP_IsApplTerm (term));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
assert (lvMap.size() == arity);
|
||||
Tuple tuple (arity);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||
if (YAP_IsAtomTerm (ti) == false) {
|
||||
cerr << "error: constraint has free variables" << endl;
|
||||
abort();
|
||||
}
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||
}
|
||||
tuples.push_back (tuple);
|
||||
tupleList = YAP_TailOfTerm (tupleList);
|
||||
}
|
||||
}
|
||||
return new Parfactor (formulas, params, tuples, distId);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
readLiftedEvidence (
|
||||
YAP_Term observedList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
while (observedList != YAP_TermNil()) {
|
||||
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||
Symbol functor;
|
||||
Symbols args;
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
} else {
|
||||
assert (YAP_IsApplTerm (ground));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||
assert (YAP_IsAtomTerm (ti));
|
||||
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
args.push_back (LiftedUtils::getSymbol (arg));
|
||||
}
|
||||
}
|
||||
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||
bool found = false;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
if (obsFormulas[i].functor() == functor &&
|
||||
obsFormulas[i].arity() == args.size() &&
|
||||
obsFormulas[i].evidence() == evidence) {
|
||||
obsFormulas[i].addTuple (args);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||
}
|
||||
observedList = YAP_TailOfTerm (observedList);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
createGroundNetwork (void)
|
||||
{
|
||||
@ -253,31 +120,27 @@ createGroundNetwork (void)
|
||||
|
||||
|
||||
|
||||
Params
|
||||
readParameters (YAP_Term paramL)
|
||||
{
|
||||
Params params;
|
||||
assert (YAP_IsPairTerm (paramL));
|
||||
while (paramL != YAP_TermNil()) {
|
||||
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||
paramL = YAP_TailOfTerm (paramL);
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
runLiftedSolver (void)
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
ParfactorList pfListCopy (*network->first);
|
||||
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
|
||||
|
||||
LiftedSolver* solver = 0;
|
||||
switch (Globals::liftedSolver) {
|
||||
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
|
||||
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
|
||||
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<Params> results;
|
||||
ParfactorList pfListCopy (*network->first);
|
||||
LiftedVe::absorveEvidence (pfListCopy, *network->second);
|
||||
while (taskList != YAP_TermNil()) {
|
||||
Grounds queryVars;
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
@ -303,41 +166,13 @@ runLiftedSolver (void)
|
||||
}
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
if (Globals::liftedSolver == LiftedSolver::FOVE) {
|
||||
LiftedVe solver (pfListCopy);
|
||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||
solver.printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
results.push_back (solver.solveQuery (queryVars));
|
||||
} else if (Globals::liftedSolver == LiftedSolver::LBP) {
|
||||
LiftedBp solver (pfListCopy);
|
||||
if (Globals::verbosity > 0 && taskList == YAP_ARG2) {
|
||||
solver.printSolverFlags();
|
||||
cout << endl;
|
||||
}
|
||||
results.push_back (solver.solveQuery (queryVars));
|
||||
} else {
|
||||
assert (false);
|
||||
}
|
||||
results.push_back (solver->solveQuery (queryVars));
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (size_t i = results.size(); i-- > 0; ) {
|
||||
const Params& beliefs = results[i];
|
||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||
for (size_t j = beliefs.size(); j-- > 0; ) {
|
||||
YAP_Int sl1 = YAP_InitSlot (list);
|
||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||
list = YAP_GetFromSlot (sl1);
|
||||
YAP_RecoverSlots (1);
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
}
|
||||
|
||||
return YAP_Unify (list, YAP_ARG3);
|
||||
delete solver;
|
||||
|
||||
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
@ -346,6 +181,7 @@ int
|
||||
runGroundSolver (void)
|
||||
{
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
|
||||
vector<VarIds> tasks;
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
@ -357,22 +193,19 @@ runGroundSolver (void)
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
Util::addToSet (vids, tasks[i]);
|
||||
}
|
||||
Solver* solver = 0;
|
||||
|
||||
FactorGraph* mfg = fg;
|
||||
if (fg->bayesianFactors()) {
|
||||
mfg = BayesBall::getMinimalFactorGraph (
|
||||
*fg, VarIds (vids.begin(), vids.end()));
|
||||
}
|
||||
|
||||
if (Globals::groundSolver == GroundSolver::VE) {
|
||||
solver = new VarElim (*mfg);
|
||||
} else if (Globals::groundSolver == GroundSolver::BP) {
|
||||
solver = new BeliefProp (*mfg);
|
||||
} else if (Globals::groundSolver == GroundSolver::CBP) {
|
||||
CountingBp::checkForIdenticalFactors = false;
|
||||
solver = new CountingBp (*mfg);
|
||||
} else {
|
||||
assert (false);
|
||||
GroundSolver* solver = 0;
|
||||
CountingBp::checkForIdenticalFactors = false;
|
||||
switch (Globals::groundSolver) {
|
||||
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
|
||||
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
@ -391,20 +224,7 @@ runGroundSolver (void)
|
||||
delete mfg;
|
||||
}
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (size_t i = results.size(); i-- > 0; ) {
|
||||
const Params& beliefs = results[i];
|
||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||
for (size_t j = beliefs.size(); j-- > 0; ) {
|
||||
YAP_Int sl1 = YAP_InitSlot (list);
|
||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||
list = YAP_GetFromSlot (sl1);
|
||||
YAP_RecoverSlots (1);
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
}
|
||||
return YAP_Unify (list, YAP_ARG3);
|
||||
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
@ -535,6 +355,183 @@ freeLiftedNetwork (void)
|
||||
|
||||
|
||||
|
||||
Parfactor*
|
||||
readParfactor (YAP_Term pfTerm)
|
||||
{
|
||||
// read dist id
|
||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||
|
||||
// read the ranges
|
||||
Ranges ranges;
|
||||
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
|
||||
while (rangeList != YAP_TermNil()) {
|
||||
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||
ranges.push_back (range);
|
||||
rangeList = YAP_TailOfTerm (rangeList);
|
||||
}
|
||||
|
||||
// read parametric random vars
|
||||
ProbFormulas formulas;
|
||||
unsigned count = 0;
|
||||
unordered_map<YAP_Term, LogVar> lvMap;
|
||||
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||
while (pvList != YAP_TermNil()) {
|
||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||
} else {
|
||||
LogVars logVars;
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
if (it != lvMap.end()) {
|
||||
logVars.push_back (it->second);
|
||||
} else {
|
||||
unsigned newLv = lvMap.size();
|
||||
lvMap[ti] = newLv;
|
||||
logVars.push_back (newLv);
|
||||
}
|
||||
}
|
||||
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
||||
}
|
||||
count ++;
|
||||
pvList = YAP_TailOfTerm (pvList);
|
||||
}
|
||||
|
||||
// read the parameters
|
||||
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||
|
||||
// read the constraint
|
||||
Tuples tuples;
|
||||
if (lvMap.size() >= 1) {
|
||||
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
|
||||
while (tupleList != YAP_TermNil()) {
|
||||
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||
assert (YAP_IsApplTerm (term));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
assert (lvMap.size() == arity);
|
||||
Tuple tuple (arity);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||
if (YAP_IsAtomTerm (ti) == false) {
|
||||
cerr << "error: constraint has free variables" << endl;
|
||||
abort();
|
||||
}
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||
}
|
||||
tuples.push_back (tuple);
|
||||
tupleList = YAP_TailOfTerm (tupleList);
|
||||
}
|
||||
}
|
||||
return new Parfactor (formulas, params, tuples, distId);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
readLiftedEvidence (
|
||||
YAP_Term observedList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
while (observedList != YAP_TermNil()) {
|
||||
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||
Symbol functor;
|
||||
Symbols args;
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
} else {
|
||||
assert (YAP_IsApplTerm (ground));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||
assert (YAP_IsAtomTerm (ti));
|
||||
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
args.push_back (LiftedUtils::getSymbol (arg));
|
||||
}
|
||||
}
|
||||
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||
bool found = false;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
if (obsFormulas[i].functor() == functor &&
|
||||
obsFormulas[i].arity() == args.size() &&
|
||||
obsFormulas[i].evidence() == evidence) {
|
||||
obsFormulas[i].addTuple (args);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||
}
|
||||
observedList = YAP_TailOfTerm (observedList);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
readUnsignedList (YAP_Term list)
|
||||
{
|
||||
vector<unsigned> vec;
|
||||
while (list != YAP_TermNil()) {
|
||||
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
|
||||
list = YAP_TailOfTerm (list);
|
||||
}
|
||||
return vec;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
readParameters (YAP_Term paramL)
|
||||
{
|
||||
Params params;
|
||||
assert (YAP_IsPairTerm (paramL));
|
||||
while (paramL != YAP_TermNil()) {
|
||||
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||
paramL = YAP_TailOfTerm (paramL);
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
YAP_Term
|
||||
fillAnswersPrologList (vector<Params>& results)
|
||||
{
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (size_t i = results.size(); i-- > 0; ) {
|
||||
const Params& beliefs = results[i];
|
||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||
for (size_t j = beliefs.size(); j-- > 0; ) {
|
||||
YAP_Int sl1 = YAP_InitSlot (list);
|
||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||
list = YAP_GetFromSlot (sl1);
|
||||
YAP_RecoverSlots (1);
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
|
||||
|
||||
extern "C" void
|
||||
init_predicates (void)
|
||||
{
|
||||
|
@ -1,11 +1,11 @@
|
||||
#include "LiftedBp.h"
|
||||
#include "WeightedBp.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "LiftedVe.h"
|
||||
#include "LiftedOperations.h"
|
||||
|
||||
|
||||
LiftedBp::LiftedBp (const ParfactorList& pfList)
|
||||
: pfList_(pfList)
|
||||
LiftedBp::LiftedBp (const ParfactorList& parfactorList)
|
||||
: LiftedSolver (parfactorList)
|
||||
{
|
||||
refineParfactors();
|
||||
createFactorGraph();
|
||||
@ -82,6 +82,7 @@ LiftedBp::printSolverFlags (void) const
|
||||
void
|
||||
LiftedBp::refineParfactors (void)
|
||||
{
|
||||
pfList_ = parfactorList;
|
||||
while (iterate() == false);
|
||||
|
||||
if (Globals::verbosity > 2) {
|
||||
@ -101,7 +102,7 @@ LiftedBp::iterate (void)
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
LogVarSet lvs = (*it)->logVarSet() - args[i].logVars();
|
||||
if ((*it)->constr()->isCountNormalized (lvs) == false) {
|
||||
Parfactors pfs = LiftedVe::countNormalize (*it, lvs);
|
||||
Parfactors pfs = LiftedOperations::countNormalize (*it, lvs);
|
||||
it = pfList_.removeAndDelete (it);
|
||||
pfList_.add (pfs);
|
||||
return false;
|
||||
@ -189,12 +190,12 @@ LiftedBp::rangeOfGround (const Ground& gr)
|
||||
Params
|
||||
LiftedBp::getJointByConditioning (
|
||||
const ParfactorList& pfList,
|
||||
const Grounds& grounds)
|
||||
const Grounds& query)
|
||||
{
|
||||
LiftedBp solver (pfList);
|
||||
Params prevBeliefs = solver.solveQuery ({grounds[0]});
|
||||
Grounds obsGrounds = {grounds[0]};
|
||||
for (size_t i = 1; i < grounds.size(); i++) {
|
||||
Params prevBeliefs = solver.solveQuery ({query[0]});
|
||||
Grounds obsGrounds = {query[0]};
|
||||
for (size_t i = 1; i < query.size(); i++) {
|
||||
Params newBeliefs;
|
||||
vector<ObservedFormula> obsFs;
|
||||
Ranges obsRanges;
|
||||
@ -209,16 +210,16 @@ LiftedBp::getJointByConditioning (
|
||||
obsFs[j].setEvidence (indexer[j]);
|
||||
}
|
||||
ParfactorList tempPfList (pfList);
|
||||
LiftedVe::absorveEvidence (tempPfList, obsFs);
|
||||
LiftedOperations::absorveEvidence (tempPfList, obsFs);
|
||||
LiftedBp solver (tempPfList);
|
||||
Params beliefs = solver.solveQuery ({grounds[i]});
|
||||
Params beliefs = solver.solveQuery ({query[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
++ indexer;
|
||||
}
|
||||
int count = -1;
|
||||
unsigned range = rangeOfGround (grounds[i]);
|
||||
unsigned range = rangeOfGround (query[i]);
|
||||
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % range == 0) {
|
||||
count ++;
|
||||
@ -226,7 +227,7 @@ LiftedBp::getJointByConditioning (
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
obsGrounds.push_back (grounds[i]);
|
||||
obsGrounds.push_back (query[i]);
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
@ -1,12 +1,13 @@
|
||||
#ifndef HORUS_LIFTEDBP_H
|
||||
#define HORUS_LIFTEDBP_H
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
class FactorGraph;
|
||||
class WeightedBp;
|
||||
|
||||
class LiftedBp
|
||||
class LiftedBp : public LiftedSolver
|
||||
{
|
||||
public:
|
||||
LiftedBp (const ParfactorList& pfList);
|
||||
@ -39,3 +40,4 @@ class LiftedBp
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDBP_H
|
||||
|
||||
|
1075
packages/CLPBN/horus/LiftedCircuit.cpp
Normal file
1075
packages/CLPBN/horus/LiftedCircuit.cpp
Normal file
File diff suppressed because it is too large
Load Diff
274
packages/CLPBN/horus/LiftedCircuit.h
Normal file
274
packages/CLPBN/horus/LiftedCircuit.h
Normal file
@ -0,0 +1,274 @@
|
||||
#ifndef HORUS_LIFTEDCIRCUIT_H
|
||||
#define HORUS_LIFTEDCIRCUIT_H
|
||||
|
||||
#include <stack>
|
||||
|
||||
#include "LiftedWCNF.h"
|
||||
|
||||
|
||||
enum CircuitNodeType {
|
||||
OR_NODE,
|
||||
AND_NODE,
|
||||
SET_OR_NODE,
|
||||
SET_AND_NODE,
|
||||
INC_EXC_NODE,
|
||||
LEAF_NODE,
|
||||
SMOOTH_NODE,
|
||||
TRUE_NODE,
|
||||
COMPILATION_FAILED_NODE
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CircuitNode
|
||||
{
|
||||
public:
|
||||
CircuitNode (const Clauses& clauses, string explanation = "")
|
||||
: clauses_(clauses), explanation_(explanation) { }
|
||||
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
|
||||
Clauses clauses (void) { return clauses_; }
|
||||
|
||||
virtual double weight (void) const = 0;
|
||||
|
||||
string explanation (void) const { return explanation_; }
|
||||
|
||||
private:
|
||||
Clauses clauses_;
|
||||
string explanation_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class OrNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
OrNode (const Clauses& clauses, string explanation = "")
|
||||
: CircuitNode (clauses, explanation),
|
||||
leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class AndNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
AndNode (const Clauses& clauses, string explanation = "")
|
||||
: CircuitNode (clauses, explanation),
|
||||
leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
AndNode (
|
||||
const Clauses& clauses,
|
||||
CircuitNode* leftBranch,
|
||||
CircuitNode* rightBranch,
|
||||
string explanation = "")
|
||||
: CircuitNode (clauses, explanation),
|
||||
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
|
||||
|
||||
AndNode (
|
||||
CircuitNode* leftBranch,
|
||||
CircuitNode* rightBranch,
|
||||
string explanation = "")
|
||||
: CircuitNode ({}, explanation),
|
||||
leftBranch_(leftBranch), rightBranch_(rightBranch) { }
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetOrNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SetOrNode (unsigned nrGroundings, const Clauses& clauses)
|
||||
: CircuitNode (clauses, " AC"), follow_(0),
|
||||
nrGroundings_(nrGroundings) { }
|
||||
|
||||
CircuitNode** follow (void) { return &follow_; }
|
||||
|
||||
static unsigned nrPositives (void) { return nrGrsStack.top().first; }
|
||||
|
||||
static unsigned nrNegatives (void) { return nrGrsStack.top().second; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
|
||||
static stack<pair<unsigned, unsigned>> nrGrsStack;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetAndNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SetAndNode (unsigned nrGroundings, const Clauses& clauses)
|
||||
: CircuitNode (clauses, " IPG"), follow_(0),
|
||||
nrGroundings_(nrGroundings) { }
|
||||
|
||||
CircuitNode** follow (void) { return &follow_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class IncExcNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
IncExcNode (const Clauses& clauses, string explanation)
|
||||
: CircuitNode (clauses, explanation), plus1Branch_(0),
|
||||
plus2Branch_(0), minusBranch_(0) { }
|
||||
|
||||
CircuitNode** plus1Branch (void) { return &plus1Branch_; }
|
||||
CircuitNode** plus2Branch (void) { return &plus2Branch_; }
|
||||
CircuitNode** minusBranch (void) { return &minusBranch_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* plus1Branch_;
|
||||
CircuitNode* plus2Branch_;
|
||||
CircuitNode* minusBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LeafNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
LeafNode (const Clause& clause, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode (Clauses() = {clause}), lwcnf_(lwcnf) { }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SmoothNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode (clauses), lwcnf_(lwcnf) { }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class TrueNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
TrueNode (void) : CircuitNode ({}) { }
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CompilationFailedNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
CompilationFailedNode (const Clauses& clauses)
|
||||
: CircuitNode (clauses) { }
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedCircuit
|
||||
{
|
||||
public:
|
||||
LiftedCircuit (const LiftedWCNF* lwcnf);
|
||||
|
||||
double getWeightedModelCount (void) const;
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
|
||||
void compile (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
|
||||
LogVars& rootLogVars);
|
||||
|
||||
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryGrounding (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
void shatterCountedLogVars (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
|
||||
|
||||
bool independentClause (Clause& clause, Clauses& otherClauses) const;
|
||||
|
||||
bool independentLiteral (const Literal& lit,
|
||||
const Literals& otherLits) const;
|
||||
|
||||
LitLvTypesSet smoothCircuit (CircuitNode* node);
|
||||
|
||||
void createSmoothNode (const LitLvTypesSet& lids,
|
||||
CircuitNode** prev);
|
||||
|
||||
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
|
||||
|
||||
bool containsTypes (const LogVarTypes& typesA,
|
||||
const LogVarTypes& typesB) const;
|
||||
|
||||
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
|
||||
|
||||
void exportToGraphViz (CircuitNode* node, ofstream&);
|
||||
|
||||
void printClauses (const CircuitNode* node, ofstream&,
|
||||
string extraOptions = "");
|
||||
|
||||
string escapeNode (const CircuitNode* node) const;
|
||||
|
||||
CircuitNode* root_;
|
||||
const LiftedWCNF* lwcnf_;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDCIRCUIT_H
|
||||
|
75
packages/CLPBN/horus/LiftedKc.cpp
Normal file
75
packages/CLPBN/horus/LiftedKc.cpp
Normal file
@ -0,0 +1,75 @@
|
||||
#include "LiftedKc.h"
|
||||
#include "LiftedWCNF.h"
|
||||
#include "LiftedCircuit.h"
|
||||
#include "LiftedOperations.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
LiftedKc::~LiftedKc (void)
|
||||
{
|
||||
delete lwcnf_;
|
||||
delete circuit_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
LiftedKc::solveQuery (const Grounds& query)
|
||||
{
|
||||
pfList_ = parfactorList;
|
||||
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||
LiftedOperations::runWeakBayesBall (pfList_, query);
|
||||
lwcnf_ = new LiftedWCNF (pfList_);
|
||||
circuit_ = new LiftedCircuit (lwcnf_);
|
||||
vector<PrvGroup> groups;
|
||||
Ranges ranges;
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
size_t idx = (*it)->indexOfGround (query[i]);
|
||||
if (idx != (*it)->nrArguments()) {
|
||||
groups.push_back ((*it)->argument (idx).group());
|
||||
ranges.push_back ((*it)->range (idx));
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
assert (groups.size() == query.size());
|
||||
Params params;
|
||||
Indexer indexer (ranges);
|
||||
while (indexer.valid()) {
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
||||
for (size_t j = 0; j < litIds.size(); j++) {
|
||||
if (indexer[i] == j) {
|
||||
lwcnf_->addWeight (litIds[j], LogAware::one(),
|
||||
LogAware::one());
|
||||
} else {
|
||||
lwcnf_->addWeight (litIds[j], LogAware::zero(),
|
||||
LogAware::one());
|
||||
}
|
||||
}
|
||||
}
|
||||
params.push_back (circuit_->getWeightedModelCount());
|
||||
++ indexer;
|
||||
}
|
||||
LogAware::normalize (params);
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (params);
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedKc::printSolverFlags (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "lifted kc [" ;
|
||||
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
}
|
||||
|
30
packages/CLPBN/horus/LiftedKc.h
Normal file
30
packages/CLPBN/horus/LiftedKc.h
Normal file
@ -0,0 +1,30 @@
|
||||
#ifndef HORUS_LIFTEDKC_H
|
||||
#define HORUS_LIFTEDKC_H
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
class LiftedWCNF;
|
||||
class LiftedCircuit;
|
||||
|
||||
|
||||
class LiftedKc : public LiftedSolver
|
||||
{
|
||||
public:
|
||||
LiftedKc (const ParfactorList& pfList)
|
||||
: LiftedSolver(pfList) { }
|
||||
|
||||
~LiftedKc (void);
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
|
||||
private:
|
||||
LiftedWCNF* lwcnf_;
|
||||
LiftedCircuit* circuit_;
|
||||
ParfactorList pfList_;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDKC_H
|
||||
|
271
packages/CLPBN/horus/LiftedOperations.cpp
Normal file
271
packages/CLPBN/horus/LiftedOperations.cpp
Normal file
@ -0,0 +1,271 @@
|
||||
#include "LiftedOperations.h"
|
||||
|
||||
|
||||
void
|
||||
LiftedOperations::shatterAgainstQuery (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
if (query[i].isAtom()) {
|
||||
continue;
|
||||
}
|
||||
bool found = false;
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
if ((*it)->containsGround (query[i])) {
|
||||
found = true;
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split;
|
||||
LogVars queryLvs (
|
||||
(*it)->constr()->logVars().begin(),
|
||||
(*it)->constr()->logVars().begin() + query[i].arity());
|
||||
split = (*it)->constr()->split (query[i].args());
|
||||
ConstraintTree* commCt = split.first;
|
||||
ConstraintTree* exclCt = split.second;
|
||||
newPfs.push_back (new Parfactor (*it, commCt));
|
||||
if (exclCt->empty() == false) {
|
||||
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
it = pfList.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
cerr << "error: could not find a parfactor with ground " ;
|
||||
cerr << "`" << query[i] << "'" << endl;
|
||||
exit (0);
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedOperations::runWeakBayesBall (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
queue<PrvGroup> todo; // groups to process
|
||||
set<PrvGroup> done; // processed or in queue
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
PrvGroup group = (*it)->findGroup (query[i]);
|
||||
if (group != numeric_limits<PrvGroup>::max()) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
PrvGroup group = todo.front();
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
done.insert (groups[i]);
|
||||
}
|
||||
}
|
||||
requiredPfs.insert (*it);
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
todo.pop();
|
||||
}
|
||||
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
bool foundNotRequired = false;
|
||||
while (it != pfList.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false) {
|
||||
if (Globals::verbosity > 2) {
|
||||
if (foundNotRequired == false) {
|
||||
Util::printHeader ("PARFACTORS TO DISCARD");
|
||||
foundNotRequired = true;
|
||||
}
|
||||
(*it)->print();
|
||||
}
|
||||
it = pfList.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedOperations::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
Parfactor* pf = *it;
|
||||
it = pfList.remove (it);
|
||||
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||
if (absorvedPfs.empty() == false) {
|
||||
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||
// just remove pf;
|
||||
} else {
|
||||
Util::addToVector (newPfs, absorvedPfs);
|
||||
}
|
||||
delete pf;
|
||||
} else {
|
||||
it = pfList.insertShattered (it, pf);
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
cout << " -> " << obsFormulas[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactors
|
||||
LiftedOperations::countNormalize (
|
||||
Parfactor* g,
|
||||
const LogVarSet& set)
|
||||
{
|
||||
Parfactors normPfs;
|
||||
if (set.empty()) {
|
||||
normPfs.push_back (new Parfactor (*g));
|
||||
} else {
|
||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||
for (size_t i = 0; i < normCts.size(); i++) {
|
||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||
}
|
||||
}
|
||||
return normPfs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor
|
||||
LiftedOperations::calcGroundMultiplication (Parfactor pf)
|
||||
{
|
||||
LogVarSet lvs = pf.constr()->logVarSet();
|
||||
lvs -= pf.constr()->singletons();
|
||||
Parfactors newPfs = {new Parfactor (pf)};
|
||||
for (size_t i = 0; i < lvs.size(); i++) {
|
||||
Parfactors pfs = newPfs;
|
||||
newPfs.clear();
|
||||
for (size_t j = 0; j < pfs.size(); j++) {
|
||||
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
|
||||
if (countedLv) {
|
||||
pfs[j]->fullExpand (lvs[i]);
|
||||
newPfs.push_back (pfs[j]);
|
||||
} else {
|
||||
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
|
||||
for (size_t k = 0; k < cts.size(); k++) {
|
||||
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
|
||||
}
|
||||
delete pfs[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
ParfactorList pfList (newPfs);
|
||||
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
|
||||
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
|
||||
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
|
||||
}
|
||||
return Parfactor (*groundShatteredPfs[0]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactors
|
||||
LiftedOperations::absorve (
|
||||
ObservedFormula& obsFormula,
|
||||
Parfactor* g)
|
||||
{
|
||||
Parfactors absorvedPfs;
|
||||
const ProbFormulas& formulas = g->arguments();
|
||||
for (size_t i = 0; i < formulas.size(); i++) {
|
||||
if (obsFormula.functor() == formulas[i].functor() &&
|
||||
obsFormula.arity() == formulas[i].arity()) {
|
||||
|
||||
if (obsFormula.isAtom()) {
|
||||
if (formulas.size() > 1) {
|
||||
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||
} else {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
g->constr()->moveToTop (formulas[i].logVars());
|
||||
std::pair<ConstraintTree*, ConstraintTree*> res;
|
||||
res = g->constr()->split (
|
||||
formulas[i].logVars(),
|
||||
&(obsFormula.constr()),
|
||||
obsFormula.constr().logVars());
|
||||
ConstraintTree* commCt = res.first;
|
||||
ConstraintTree* exclCt = res.second;
|
||||
|
||||
if (commCt->empty() == false) {
|
||||
if (formulas.size() > 1) {
|
||||
LogVarSet excl = g->exclusiveLogVars (i);
|
||||
Parfactor tempPf (g, commCt);
|
||||
Parfactors countNormPfs = LiftedOperations::countNormalize (
|
||||
&tempPf, excl);
|
||||
for (size_t j = 0; j < countNormPfs.size(); j++) {
|
||||
countNormPfs[j]->absorveEvidence (
|
||||
formulas[i], obsFormula.evidence());
|
||||
absorvedPfs.push_back (countNormPfs[j]);
|
||||
}
|
||||
} else {
|
||||
delete commCt;
|
||||
}
|
||||
if (exclCt->empty() == false) {
|
||||
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
if (absorvedPfs.empty()) {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
}
|
||||
}
|
||||
}
|
||||
return absorvedPfs;
|
||||
}
|
||||
|
26
packages/CLPBN/horus/LiftedOperations.h
Normal file
26
packages/CLPBN/horus/LiftedOperations.h
Normal file
@ -0,0 +1,26 @@
|
||||
#ifndef HORUS_LIFTEDOPERATIONS_H
|
||||
#define HORUS_LIFTEDOPERATIONS_H
|
||||
|
||||
#include "ParfactorList.h"
|
||||
|
||||
class LiftedOperations
|
||||
{
|
||||
public:
|
||||
static void shatterAgainstQuery (
|
||||
ParfactorList& pfList, const Grounds& query);
|
||||
|
||||
static void runWeakBayesBall (
|
||||
ParfactorList& pfList, const Grounds&);
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
|
||||
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||
|
||||
private:
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDOPERATIONS_H
|
27
packages/CLPBN/horus/LiftedSolver.h
Normal file
27
packages/CLPBN/horus/LiftedSolver.h
Normal file
@ -0,0 +1,27 @@
|
||||
#ifndef HORUS_LIFTEDSOLVER_H
|
||||
#define HORUS_LIFTEDSOLVER_H
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class LiftedSolver
|
||||
{
|
||||
public:
|
||||
LiftedSolver (const ParfactorList& pfList)
|
||||
: parfactorList(pfList) { }
|
||||
|
||||
virtual ~LiftedSolver() { } // ensure that subclass destructor is called
|
||||
|
||||
virtual Params solveQuery (const Grounds& query) = 0;
|
||||
|
||||
virtual void printSolverFlags (void) const = 0;
|
||||
|
||||
protected:
|
||||
const ParfactorList& parfactorList;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDSOLVER_H
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <set>
|
||||
|
||||
#include "LiftedVe.h"
|
||||
#include "LiftedOperations.h"
|
||||
#include "Histogram.h"
|
||||
#include "Util.h"
|
||||
|
||||
@ -221,7 +222,7 @@ SumOutOperator::apply (void)
|
||||
product->sumOutIndex (fIdx);
|
||||
pfList_.addShattered (product);
|
||||
} else {
|
||||
Parfactors pfs = LiftedVe::countNormalize (product, excl);
|
||||
Parfactors pfs = LiftedOperations::countNormalize (product, excl);
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
pfs[i]->sumOutIndex (fIdx);
|
||||
pfList_.add (pfs[i]);
|
||||
@ -375,7 +376,7 @@ CountingOperator::apply (void)
|
||||
} else {
|
||||
Parfactor* pf = *pfIter_;
|
||||
pfList_.remove (pfIter_);
|
||||
Parfactors pfs = LiftedVe::countNormalize (pf, X_);
|
||||
Parfactors pfs = LiftedOperations::countNormalize (pf, X_);
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||
bool cartProduct = pfs[i]->constr()->isCartesianProduct (
|
||||
@ -419,7 +420,7 @@ CountingOperator::toString (void)
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
Parfactors pfs = LiftedVe::countNormalize (*pfIter_, X_);
|
||||
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
ss << " º " << pfs[i]->getLabel() << endl;
|
||||
@ -508,8 +509,6 @@ GroundOperator::getLogCost (void)
|
||||
void
|
||||
GroundOperator::apply (void)
|
||||
{
|
||||
// TODO if we update the correct groups
|
||||
// we can skip shattering
|
||||
ParfactorList::iterator pfIter;
|
||||
pfIter = getParfactorsWithGroup (pfList_, group_).front();
|
||||
Parfactor* pf = *pfIter;
|
||||
@ -632,6 +631,7 @@ Params
|
||||
LiftedVe::solveQuery (const Grounds& query)
|
||||
{
|
||||
assert (query.empty() == false);
|
||||
pfList_ = parfactorList;
|
||||
runSolver (query);
|
||||
(*pfList_.begin())->normalize();
|
||||
Params params = (*pfList_.begin())->params();
|
||||
@ -647,7 +647,7 @@ void
|
||||
LiftedVe::printSolverFlags (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "fove [" ;
|
||||
ss << "lve [" ;
|
||||
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
@ -655,103 +655,12 @@ LiftedVe::printSolverFlags (void) const
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
{
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
Parfactor* pf = *it;
|
||||
it = pfList.remove (it);
|
||||
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||
if (absorvedPfs.empty() == false) {
|
||||
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||
// just remove pf;
|
||||
} else {
|
||||
Util::addToVector (newPfs, absorvedPfs);
|
||||
}
|
||||
delete pf;
|
||||
} else {
|
||||
it = pfList.insertShattered (it, pf);
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
cout << " -> " << obsFormulas[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactors
|
||||
LiftedVe::countNormalize (
|
||||
Parfactor* g,
|
||||
const LogVarSet& set)
|
||||
{
|
||||
Parfactors normPfs;
|
||||
if (set.empty()) {
|
||||
normPfs.push_back (new Parfactor (*g));
|
||||
} else {
|
||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||
for (size_t i = 0; i < normCts.size(); i++) {
|
||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||
}
|
||||
}
|
||||
return normPfs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactor
|
||||
LiftedVe::calcGroundMultiplication (Parfactor pf)
|
||||
{
|
||||
LogVarSet lvs = pf.constr()->logVarSet();
|
||||
lvs -= pf.constr()->singletons();
|
||||
Parfactors newPfs = {new Parfactor (pf)};
|
||||
for (size_t i = 0; i < lvs.size(); i++) {
|
||||
Parfactors pfs = newPfs;
|
||||
newPfs.clear();
|
||||
for (size_t j = 0; j < pfs.size(); j++) {
|
||||
bool countedLv = pfs[j]->countedLogVars().contains (lvs[i]);
|
||||
if (countedLv) {
|
||||
pfs[j]->fullExpand (lvs[i]);
|
||||
newPfs.push_back (pfs[j]);
|
||||
} else {
|
||||
ConstraintTrees cts = pfs[j]->constr()->ground (lvs[i]);
|
||||
for (size_t k = 0; k < cts.size(); k++) {
|
||||
newPfs.push_back (new Parfactor (pfs[j], cts[k]));
|
||||
}
|
||||
delete pfs[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
ParfactorList pfList (newPfs);
|
||||
Parfactors groundShatteredPfs (pfList.begin(),pfList.end());
|
||||
for (size_t i = 1; i < groundShatteredPfs.size(); i++) {
|
||||
groundShatteredPfs[0]->multiply (*groundShatteredPfs[i]);
|
||||
}
|
||||
return Parfactor (*groundShatteredPfs[0]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::runSolver (const Grounds& query)
|
||||
{
|
||||
largestCost_ = std::log (0);
|
||||
shatterAgainstQuery (query);
|
||||
runWeakBayesBall (query);
|
||||
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||
LiftedOperations::runWeakBayesBall (pfList_, query);
|
||||
while (true) {
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printDashedLine();
|
||||
@ -817,177 +726,3 @@ LiftedVe::getBestOperation (const Grounds& query)
|
||||
return bestOp;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::runWeakBayesBall (const Grounds& query)
|
||||
{
|
||||
queue<PrvGroup> todo; // groups to process
|
||||
set<PrvGroup> done; // processed or in queue
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
PrvGroup group = (*it)->findGroup (query[i]);
|
||||
if (group != numeric_limits<PrvGroup>::max()) {
|
||||
todo.push (group);
|
||||
done.insert (group);
|
||||
break;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
PrvGroup group = todo.front();
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
done.insert (groups[i]);
|
||||
}
|
||||
}
|
||||
requiredPfs.insert (*it);
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
todo.pop();
|
||||
}
|
||||
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
bool foundNotRequired = false;
|
||||
while (it != pfList_.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false) {
|
||||
if (Globals::verbosity > 2) {
|
||||
if (foundNotRequired == false) {
|
||||
Util::printHeader ("PARFACTORS TO DISCARD");
|
||||
foundNotRequired = true;
|
||||
}
|
||||
(*it)->print();
|
||||
}
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::shatterAgainstQuery (const Grounds& query)
|
||||
{
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
if (query[i].isAtom()) {
|
||||
continue;
|
||||
}
|
||||
bool found = false;
|
||||
Parfactors newPfs;
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
if ((*it)->containsGround (query[i])) {
|
||||
found = true;
|
||||
std::pair<ConstraintTree*, ConstraintTree*> split;
|
||||
LogVars queryLvs (
|
||||
(*it)->constr()->logVars().begin(),
|
||||
(*it)->constr()->logVars().begin() + query[i].arity());
|
||||
split = (*it)->constr()->split (query[i].args());
|
||||
ConstraintTree* commCt = split.first;
|
||||
ConstraintTree* exclCt = split.second;
|
||||
newPfs.push_back (new Parfactor (*it, commCt));
|
||||
if (exclCt->empty() == false) {
|
||||
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
it = pfList_.removeAndDelete (it);
|
||||
} else {
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
cerr << "error: could not find a parfactor with ground " ;
|
||||
cerr << "`" << query[i] << "'" << endl;
|
||||
exit (0);
|
||||
}
|
||||
pfList_.add (newPfs);
|
||||
}
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList_.print();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Parfactors
|
||||
LiftedVe::absorve (
|
||||
ObservedFormula& obsFormula,
|
||||
Parfactor* g)
|
||||
{
|
||||
Parfactors absorvedPfs;
|
||||
const ProbFormulas& formulas = g->arguments();
|
||||
for (size_t i = 0; i < formulas.size(); i++) {
|
||||
if (obsFormula.functor() == formulas[i].functor() &&
|
||||
obsFormula.arity() == formulas[i].arity()) {
|
||||
|
||||
if (obsFormula.isAtom()) {
|
||||
if (formulas.size() > 1) {
|
||||
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||
} else {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
g->constr()->moveToTop (formulas[i].logVars());
|
||||
std::pair<ConstraintTree*, ConstraintTree*> res;
|
||||
res = g->constr()->split (
|
||||
formulas[i].logVars(),
|
||||
&(obsFormula.constr()),
|
||||
obsFormula.constr().logVars());
|
||||
ConstraintTree* commCt = res.first;
|
||||
ConstraintTree* exclCt = res.second;
|
||||
|
||||
if (commCt->empty() == false) {
|
||||
if (formulas.size() > 1) {
|
||||
LogVarSet excl = g->exclusiveLogVars (i);
|
||||
Parfactor tempPf (g, commCt);
|
||||
Parfactors countNormPfs = countNormalize (&tempPf, excl);
|
||||
for (size_t j = 0; j < countNormPfs.size(); j++) {
|
||||
countNormPfs[j]->absorveEvidence (
|
||||
formulas[i], obsFormula.evidence());
|
||||
absorvedPfs.push_back (countNormPfs[j]);
|
||||
}
|
||||
} else {
|
||||
delete commCt;
|
||||
}
|
||||
if (exclCt->empty() == false) {
|
||||
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||
} else {
|
||||
delete exclCt;
|
||||
}
|
||||
if (absorvedPfs.empty()) {
|
||||
// hack to erase parfactor g
|
||||
absorvedPfs.push_back (0);
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
delete commCt;
|
||||
delete exclCt;
|
||||
}
|
||||
}
|
||||
}
|
||||
return absorvedPfs;
|
||||
}
|
||||
|
||||
|
@ -1,13 +1,15 @@
|
||||
#ifndef HORUS_LIFTEDVE_H
|
||||
#define HORUS_LIFTEDVE_H
|
||||
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
class LiftedOperator
|
||||
{
|
||||
public:
|
||||
virtual ~LiftedOperator (void) { }
|
||||
|
||||
virtual double getLogCost (void) = 0;
|
||||
|
||||
virtual void apply (void) = 0;
|
||||
@ -43,9 +45,9 @@ class ProductOperator : public LiftedOperator
|
||||
private:
|
||||
static bool validOp (Parfactor*, Parfactor*);
|
||||
|
||||
ParfactorList::iterator g1_;
|
||||
ParfactorList::iterator g2_;
|
||||
ParfactorList& pfList_;
|
||||
ParfactorList::iterator g1_;
|
||||
ParfactorList::iterator g2_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
@ -123,43 +125,30 @@ class GroundOperator : public LiftedOperator
|
||||
private:
|
||||
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
|
||||
|
||||
PrvGroup group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
PrvGroup group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedVe
|
||||
class LiftedVe : public LiftedSolver
|
||||
{
|
||||
public:
|
||||
LiftedVe (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||
LiftedVe (const ParfactorList& pfList)
|
||||
: LiftedSolver(pfList) { }
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
|
||||
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||
|
||||
private:
|
||||
void runSolver (const Grounds&);
|
||||
|
||||
LiftedOperator* getBestOperation (const Grounds&);
|
||||
|
||||
void runWeakBayesBall (const Grounds&);
|
||||
|
||||
void shatterAgainstQuery (const Grounds&);
|
||||
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
|
||||
ParfactorList pfList_;
|
||||
|
||||
double largestCost_;
|
||||
ParfactorList pfList_;
|
||||
double largestCost_;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDVE_H
|
||||
|
610
packages/CLPBN/horus/LiftedWCNF.cpp
Normal file
610
packages/CLPBN/horus/LiftedWCNF.cpp
Normal file
@ -0,0 +1,610 @@
|
||||
#include "LiftedWCNF.h"
|
||||
#include "ConstraintTree.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const
|
||||
{
|
||||
if (logVars_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
LogVarSet lvs (logVars_);
|
||||
lvs -= ipgLogVars;
|
||||
return constr.singletons().contains (lvs);
|
||||
}
|
||||
|
||||
|
||||
|
||||
size_t
|
||||
Literal::indexOfLogVar (LogVar X) const
|
||||
{
|
||||
return Util::indexOf (logVars_, X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Literal::toString (
|
||||
LogVarSet ipgLogVars,
|
||||
LogVarSet posCountedLvs,
|
||||
LogVarSet negCountedLvs) const
|
||||
{
|
||||
stringstream ss;
|
||||
negated_ ? ss << "¬" : ss << "" ;
|
||||
// if (negated_ == false) {
|
||||
// posWeight_ < 0.0 ? ss << "λ" : ss << "Θ" ;
|
||||
// } else {
|
||||
// negWeight_ < 0.0 ? ss << "λ" : ss << "Θ" ;
|
||||
// }
|
||||
ss << "λ" ;
|
||||
ss << lid_ ;
|
||||
if (logVars_.empty() == false) {
|
||||
ss << "(" ;
|
||||
for (size_t i = 0; i < logVars_.size(); i++) {
|
||||
if (i != 0) ss << ",";
|
||||
if (posCountedLvs.contains (logVars_[i])) {
|
||||
ss << "+" << logVars_[i];
|
||||
} else if (negCountedLvs.contains (logVars_[i])) {
|
||||
ss << "-" << logVars_[i];
|
||||
} else if (ipgLogVars.contains (logVars_[i])) {
|
||||
LogVar X = logVars_[i];
|
||||
const string labels[] = {
|
||||
"a", "b", "c", "d", "e", "f",
|
||||
"g", "h", "i", "j", "k", "m" };
|
||||
(X >= 12) ? ss << "x_" << X : ss << labels[X];
|
||||
} else {
|
||||
ss << logVars_[i];
|
||||
}
|
||||
}
|
||||
ss << ")" ;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream& operator<< (ostream &os, const Literal& lit)
|
||||
{
|
||||
os << lit.toString();
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Clause::addLiteralComplemented (const Literal& lit)
|
||||
{
|
||||
assert (constr_.logVarSet().contains (lit.logVars()));
|
||||
literals_.push_back (lit);
|
||||
literals_.back().complement();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Clause::containsLiteral (LiteralId lid) const
|
||||
{
|
||||
for (size_t i = 0; i < literals_.size(); i++) {
|
||||
if (literals_[i].lid() == lid) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Clause::containsPositiveLiteral (
|
||||
LiteralId lid,
|
||||
const LogVarTypes& types) const
|
||||
{
|
||||
for (size_t i = 0; i < literals_.size(); i++) {
|
||||
if (literals_[i].lid() == lid
|
||||
&& literals_[i].isPositive()
|
||||
&& logVarTypes (i) == types) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Clause::containsNegativeLiteral (
|
||||
LiteralId lid,
|
||||
const LogVarTypes& types) const
|
||||
{
|
||||
for (size_t i = 0; i < literals_.size(); i++) {
|
||||
if (literals_[i].lid() == lid
|
||||
&& literals_[i].isNegative()
|
||||
&& logVarTypes (i) == types) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Clause::removeLiterals (LiteralId lid)
|
||||
{
|
||||
size_t i = 0;
|
||||
while (i != literals_.size()) {
|
||||
if (literals_[i].lid() == lid) {
|
||||
removeLiteral (i);
|
||||
} else {
|
||||
i ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Clause::removePositiveLiterals (
|
||||
LiteralId lid,
|
||||
const LogVarTypes& types)
|
||||
{
|
||||
size_t i = 0;
|
||||
while (i != literals_.size()) {
|
||||
if (literals_[i].lid() == lid
|
||||
&& literals_[i].isPositive()
|
||||
&& logVarTypes (i) == types) {
|
||||
removeLiteral (i);
|
||||
} else {
|
||||
i ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Clause::removeNegativeLiterals (
|
||||
LiteralId lid,
|
||||
const LogVarTypes& types)
|
||||
{
|
||||
size_t i = 0;
|
||||
while (i != literals_.size()) {
|
||||
if (literals_[i].lid() == lid
|
||||
&& literals_[i].isNegative()
|
||||
&& logVarTypes (i) == types) {
|
||||
removeLiteral (i);
|
||||
} else {
|
||||
i ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Clause::isCountedLogVar (LogVar X) const
|
||||
{
|
||||
assert (constr_.logVarSet().contains (X));
|
||||
return posCountedLvs_.contains (X)
|
||||
|| negCountedLvs_.contains (X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Clause::isPositiveCountedLogVar (LogVar X) const
|
||||
{
|
||||
assert (constr_.logVarSet().contains (X));
|
||||
return posCountedLvs_.contains (X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Clause::isNegativeCountedLogVar (LogVar X) const
|
||||
{
|
||||
assert (constr_.logVarSet().contains (X));
|
||||
return negCountedLvs_.contains (X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Clause::isIpgLogVar (LogVar X) const
|
||||
{
|
||||
assert (constr_.logVarSet().contains (X));
|
||||
return ipgLvs_.contains (X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
TinySet<LiteralId>
|
||||
Clause::lidSet (void) const
|
||||
{
|
||||
TinySet<LiteralId> lidSet;
|
||||
for (size_t i = 0; i < literals_.size(); i++) {
|
||||
lidSet.insert (literals_[i].lid());
|
||||
}
|
||||
return lidSet;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LogVarSet
|
||||
Clause::ipgCandidates (void) const
|
||||
{
|
||||
LogVarSet candidates;
|
||||
LogVarSet allLvs = constr_.logVarSet();
|
||||
allLvs -= ipgLvs_;
|
||||
allLvs -= posCountedLvs_;
|
||||
allLvs -= negCountedLvs_;
|
||||
for (size_t i = 0; i < allLvs.size(); i++) {
|
||||
bool valid = true;
|
||||
for (size_t j = 0; j < literals_.size(); j++) {
|
||||
if (Util::contains (literals_[j].logVars(), allLvs[i]) == false) {
|
||||
valid = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (valid) {
|
||||
candidates.insert (allLvs[i]);
|
||||
}
|
||||
}
|
||||
return candidates;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LogVarTypes
|
||||
Clause::logVarTypes (size_t litIdx) const
|
||||
{
|
||||
LogVarTypes types;
|
||||
const LogVars lvs = literals_[litIdx].logVars();
|
||||
for (size_t i = 0; i < lvs.size(); i++) {
|
||||
if (posCountedLvs_.contains (lvs[i])) {
|
||||
types.push_back (LogVarType::POS_LV);
|
||||
} else if (negCountedLvs_.contains (lvs[i])) {
|
||||
types.push_back (LogVarType::NEG_LV);
|
||||
} else {
|
||||
types.push_back (LogVarType::FULL_LV);
|
||||
}
|
||||
}
|
||||
return types;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Clause::removeLiteral (size_t litIdx)
|
||||
{
|
||||
LogVarSet lvsToRemove = literals_[litIdx].logVarSet()
|
||||
- getLogVarSetExcluding (litIdx);
|
||||
ipgLvs_ -= lvsToRemove;
|
||||
posCountedLvs_ -= lvsToRemove;
|
||||
negCountedLvs_ -= lvsToRemove;
|
||||
constr_.remove (lvsToRemove);
|
||||
literals_.erase (literals_.begin() + litIdx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Clause::independentClauses (Clause& c1, Clause& c2)
|
||||
{
|
||||
const Literals& lits1 = c1.literals();
|
||||
const Literals& lits2 = c2.literals();
|
||||
for (size_t i = 0; i < lits1.size(); i++) {
|
||||
for (size_t j = 0; j < lits2.size(); j++) {
|
||||
if (lits1[i].lid() == lits2[j].lid()
|
||||
&& c1.logVarTypes (i) == c2.logVarTypes (j)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Clause::printClauses (const Clauses& clauses)
|
||||
{
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
cout << clauses[i] << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream& operator<< (ostream &os, const Clause& clause)
|
||||
{
|
||||
for (unsigned i = 0; i < clause.literals_.size(); i++) {
|
||||
if (i != 0) os << " v " ;
|
||||
os << clause.literals_[i].toString (clause.ipgLvs_,
|
||||
clause.posCountedLvs_, clause.negCountedLvs_);
|
||||
}
|
||||
if (clause.constr_.empty() == false) {
|
||||
ConstraintTree copy (clause.constr_);
|
||||
copy.moveToTop (copy.logVarSet().elements());
|
||||
os << " | " << copy.tupleSet();
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LogVarSet
|
||||
Clause::getLogVarSetExcluding (size_t idx) const
|
||||
{
|
||||
LogVarSet lvs;
|
||||
for (size_t i = 0; i < literals_.size(); i++) {
|
||||
if (i != idx) {
|
||||
lvs |= literals_[i].logVars();
|
||||
}
|
||||
}
|
||||
return lvs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
: freeLiteralId_(0), pfList_(pfList)
|
||||
{
|
||||
addIndicatorClauses (pfList);
|
||||
addParameterClauses (pfList);
|
||||
|
||||
/*
|
||||
// INCLUSION-EXCLUSION TEST
|
||||
vector<vector<string>> names = {
|
||||
// {"a1","b1"},{"a2","b2"},{"a1","b3"}
|
||||
{"b1","a1"},{"b2","a2"},{"b3","a1"}
|
||||
};
|
||||
Clause c1 (names);
|
||||
c1.addLiteral (Literal (0, LogVars() = {0}));
|
||||
c1.addLiteral (Literal (1, LogVars() = {1}));
|
||||
clauses_.push_back(c1);
|
||||
freeLiteralId_ ++ ;
|
||||
freeLiteralId_ ++ ;
|
||||
*/
|
||||
|
||||
/*
|
||||
// ATOM-COUNTING TEST
|
||||
vector<vector<string>> names = {
|
||||
{"p1","p1"},{"p1","p2"},{"p1","p3"},
|
||||
{"p2","p1"},{"p2","p2"},{"p2","p3"},
|
||||
{"p3","p1"},{"p3","p2"},{"p3","p3"}
|
||||
};
|
||||
Clause c1 (names);
|
||||
c1.addLiteral (Literal (0, LogVars() = {0}));
|
||||
c1.addLiteralComplemented (Literal (1, {0,1}));
|
||||
clauses_.push_back(c1);
|
||||
Clause c2 (names);
|
||||
c2.addLiteral (Literal (0, LogVars()={0}));
|
||||
c2.addLiteralComplemented (Literal (1, {1,0}));
|
||||
clauses_.push_back(c2);
|
||||
addWeight (0, LogAware::log(3.0), LogAware::log(4.0));
|
||||
addWeight (1, LogAware::log(2.0), LogAware::log(5.0));
|
||||
freeLiteralId_ = 2;
|
||||
*/
|
||||
|
||||
cout << "FORMULA INDICATORS:" << endl;
|
||||
printFormulaIndicators();
|
||||
cout << endl;
|
||||
|
||||
cout << "WEIGHTS:" << endl;
|
||||
printWeights();
|
||||
cout << endl;
|
||||
|
||||
cout << "CLAUSES:" << endl;
|
||||
printClauses();
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedWCNF::~LiftedWCNF (void)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||
{
|
||||
weights_[lid] = make_pair (posW, negW);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
LiftedWCNF::posWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.first : LogAware::one();
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
LiftedWCNF::negWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.second : LogAware::one();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<LiteralId>
|
||||
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
|
||||
{
|
||||
assert (Util::contains (map_, prvGroup));
|
||||
return map_[prvGroup];
|
||||
}
|
||||
|
||||
|
||||
|
||||
Clause
|
||||
LiftedWCNF::createClause (LiteralId lid) const
|
||||
{
|
||||
for (size_t i = 0; i < clauses_.size(); i++) {
|
||||
const Literals& literals = clauses_[i].literals();
|
||||
for (size_t j = 0; j < literals.size(); j++) {
|
||||
if (literals[j].lid() == lid) {
|
||||
ConstraintTree ct = clauses_[i].constr();
|
||||
ct.project (literals[j].logVars());
|
||||
Clause clause (ct);
|
||||
clause.addLiteral (literals[j]);
|
||||
return clause;
|
||||
}
|
||||
}
|
||||
}
|
||||
abort(); // we should not reach this point
|
||||
return Clause (ConstraintTree({}));
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiteralId
|
||||
LiftedWCNF::getLiteralId (PrvGroup prvGroup, unsigned range)
|
||||
{
|
||||
assert (Util::contains (map_, prvGroup));
|
||||
return map_[prvGroup][range];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
set<PrvGroup> allGroups;
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
for (size_t i = 0; i < formulas.size(); i++) {
|
||||
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
||||
allGroups.insert (formulas[i].group());
|
||||
ConstraintTree tempConstr = *(*it)->constr();
|
||||
tempConstr.project (formulas[i].logVars());
|
||||
Clause clause (tempConstr);
|
||||
vector<LiteralId> lids;
|
||||
for (size_t j = 0; j < formulas[i].range(); j++) {
|
||||
clause.addLiteral (Literal (freeLiteralId_, formulas[i].logVars()));
|
||||
lids.push_back (freeLiteralId_);
|
||||
freeLiteralId_ ++;
|
||||
}
|
||||
clauses_.push_back (clause);
|
||||
for (size_t j = 0; j < formulas[i].range() - 1; j++) {
|
||||
for (size_t k = j + 1; k < formulas[i].range(); k++) {
|
||||
ConstraintTree tempConstr2 = *(*it)->constr();
|
||||
tempConstr2.project (formulas[i].logVars());
|
||||
Clause clause2 (tempConstr2);
|
||||
clause2.addLiteralComplemented (Literal (clause.literals()[j]));
|
||||
clause2.addLiteralComplemented (Literal (clause.literals()[k]));
|
||||
clauses_.push_back (clause2);
|
||||
}
|
||||
}
|
||||
map_[formulas[i].group()] = lids;
|
||||
}
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
Indexer indexer ((*it)->ranges());
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
while (indexer.valid()) {
|
||||
LiteralId paramVarLid = freeLiteralId_;
|
||||
// λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un
|
||||
//
|
||||
// ¬λu1 ... ¬λun v θxi|u1,...,un -> clause1
|
||||
// ¬θxi|u1,...,un v λu1 -> tempClause
|
||||
// ¬θxi|u1,...,un v λu2 -> tempClause
|
||||
double posWeight = (**it)[indexer];
|
||||
addWeight (paramVarLid, posWeight, 1.0);
|
||||
|
||||
Clause clause1 (*(*it)->constr());
|
||||
|
||||
for (unsigned i = 0; i < groups.size(); i++) {
|
||||
LiteralId lid = getLiteralId (groups[i], indexer[i]);
|
||||
|
||||
clause1.addLiteralComplemented (
|
||||
Literal (lid, (*it)->argument(i).logVars()));
|
||||
|
||||
ConstraintTree ct = *(*it)->constr();
|
||||
Clause tempClause (ct);
|
||||
tempClause.addLiteralComplemented (Literal (
|
||||
paramVarLid, (*it)->constr()->logVars()));
|
||||
tempClause.addLiteral (Literal (lid, (*it)->argument(i).logVars()));
|
||||
clauses_.push_back (tempClause);
|
||||
}
|
||||
clause1.addLiteral (Literal (paramVarLid, (*it)->constr()->logVars()));
|
||||
clauses_.push_back (clause1);
|
||||
freeLiteralId_ ++;
|
||||
++ indexer;
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printFormulaIndicators (void) const
|
||||
{
|
||||
set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
for (size_t i = 0; i < formulas.size(); i++) {
|
||||
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
||||
allGroups.insert (formulas[i].group());
|
||||
cout << formulas[i] << " | " ;
|
||||
ConstraintTree tempCt = *(*it)->constr();
|
||||
tempCt.project (formulas[i].logVars());
|
||||
cout << tempCt.tupleSet();
|
||||
cout << " indicators => " ;
|
||||
vector<LiteralId> indicators =
|
||||
(map_.find (formulas[i].group()))->second;
|
||||
cout << indicators << endl;
|
||||
}
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printWeights (void) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.begin();
|
||||
while (it != weights_.end()) {
|
||||
cout << "λ" << it->first << " weights: " ;
|
||||
cout << it->second.first << " " << it->second.second;
|
||||
cout << endl;
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printClauses (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < clauses_.size(); i++) {
|
||||
cout << clauses_[i] << endl;
|
||||
}
|
||||
}
|
||||
|
226
packages/CLPBN/horus/LiftedWCNF.h
Normal file
226
packages/CLPBN/horus/LiftedWCNF.h
Normal file
@ -0,0 +1,226 @@
|
||||
#ifndef HORUS_LIFTEDWCNF_H
|
||||
#define HORUS_LIFTEDWCNF_H
|
||||
|
||||
#include "ParfactorList.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
typedef long LiteralId;
|
||||
|
||||
class ConstraintTree;
|
||||
|
||||
|
||||
enum LogVarType
|
||||
{
|
||||
FULL_LV,
|
||||
POS_LV,
|
||||
NEG_LV
|
||||
};
|
||||
|
||||
typedef vector<LogVarType> LogVarTypes;
|
||||
|
||||
|
||||
|
||||
class Literal
|
||||
{
|
||||
public:
|
||||
Literal (LiteralId lid, const LogVars& lvs) :
|
||||
lid_(lid), logVars_(lvs), negated_(false) { }
|
||||
|
||||
Literal (const Literal& lit, bool negated) :
|
||||
lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
|
||||
|
||||
LiteralId lid (void) const { return lid_; }
|
||||
|
||||
LogVars logVars (void) const { return logVars_; }
|
||||
|
||||
size_t nrLogVars (void) const { return logVars_.size(); }
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
|
||||
void complement (void) { negated_ = !negated_; }
|
||||
|
||||
bool isPositive (void) const { return negated_ == false; }
|
||||
|
||||
bool isNegative (void) const { return negated_; }
|
||||
|
||||
bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const;
|
||||
|
||||
size_t indexOfLogVar (LogVar X) const;
|
||||
|
||||
string toString (LogVarSet ipgLogVars = LogVarSet(),
|
||||
LogVarSet posCountedLvs = LogVarSet(),
|
||||
LogVarSet negCountedLvs = LogVarSet()) const;
|
||||
|
||||
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
|
||||
|
||||
private:
|
||||
LiteralId lid_;
|
||||
LogVars logVars_;
|
||||
bool negated_;
|
||||
};
|
||||
|
||||
typedef vector<Literal> Literals;
|
||||
|
||||
|
||||
|
||||
class Clause
|
||||
{
|
||||
public:
|
||||
Clause (const ConstraintTree& ct = ConstraintTree({})) : constr_(ct) { }
|
||||
|
||||
Clause (vector<vector<string>> names) : constr_(ConstraintTree (names)) { }
|
||||
|
||||
void addLiteral (const Literal& l) { literals_.push_back (l); }
|
||||
|
||||
const Literals& literals (void) const { return literals_; }
|
||||
|
||||
const ConstraintTree& constr (void) const { return constr_; }
|
||||
|
||||
ConstraintTree constr (void) { return constr_; }
|
||||
|
||||
bool isUnit (void) const { return literals_.size() == 1; }
|
||||
|
||||
LogVarSet ipgLogVars (void) const { return ipgLvs_; }
|
||||
|
||||
void addIpgLogVar (LogVar X) { ipgLvs_.insert (X); }
|
||||
|
||||
void addPosCountedLogVar (LogVar X) { posCountedLvs_.insert (X); }
|
||||
|
||||
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
|
||||
|
||||
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; }
|
||||
|
||||
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; }
|
||||
|
||||
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); }
|
||||
|
||||
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); }
|
||||
|
||||
void addLiteralComplemented (const Literal& lit);
|
||||
|
||||
bool containsLiteral (LiteralId lid) const;
|
||||
|
||||
bool containsPositiveLiteral (LiteralId lid, const LogVarTypes&) const;
|
||||
|
||||
bool containsNegativeLiteral (LiteralId lid, const LogVarTypes&) const;
|
||||
|
||||
void removeLiterals (LiteralId lid);
|
||||
|
||||
void removePositiveLiterals (LiteralId lid, const LogVarTypes&);
|
||||
|
||||
void removeNegativeLiterals (LiteralId lid, const LogVarTypes&);
|
||||
|
||||
bool isCountedLogVar (LogVar X) const;
|
||||
|
||||
bool isPositiveCountedLogVar (LogVar X) const;
|
||||
|
||||
bool isNegativeCountedLogVar (LogVar X) const;
|
||||
|
||||
bool isIpgLogVar (LogVar X) const;
|
||||
|
||||
TinySet<LiteralId> lidSet (void) const;
|
||||
|
||||
LogVarSet ipgCandidates (void) const;
|
||||
|
||||
LogVarTypes logVarTypes (size_t litIdx) const;
|
||||
|
||||
void removeLiteral (size_t litIdx);
|
||||
|
||||
static bool independentClauses (Clause& c1, Clause& c2);
|
||||
|
||||
static void printClauses (const vector<Clause>& clauses);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const Clause& clause);
|
||||
|
||||
private:
|
||||
LogVarSet getLogVarSetExcluding (size_t idx) const;
|
||||
|
||||
Literals literals_;
|
||||
LogVarSet ipgLvs_;
|
||||
LogVarSet posCountedLvs_;
|
||||
LogVarSet negCountedLvs_;
|
||||
ConstraintTree constr_;
|
||||
};
|
||||
|
||||
typedef vector<Clause> Clauses;
|
||||
|
||||
|
||||
|
||||
class LitLvTypes
|
||||
{
|
||||
public:
|
||||
struct CompareLitLvTypes
|
||||
{
|
||||
bool operator() (
|
||||
const LitLvTypes& types1,
|
||||
const LitLvTypes& types2) const
|
||||
{
|
||||
if (types1.lid_ < types2.lid_) {
|
||||
return true;
|
||||
}
|
||||
return types1.lvTypes_ < types2.lvTypes_;
|
||||
}
|
||||
};
|
||||
|
||||
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
|
||||
lid_(lid), lvTypes_(lvTypes) { }
|
||||
|
||||
LiteralId lid (void) const { return lid_; }
|
||||
|
||||
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
|
||||
|
||||
void setAllFullLogVars (void) {
|
||||
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
|
||||
|
||||
private:
|
||||
LiteralId lid_;
|
||||
LogVarTypes lvTypes_;
|
||||
};
|
||||
|
||||
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
|
||||
|
||||
|
||||
|
||||
class LiftedWCNF
|
||||
{
|
||||
public:
|
||||
LiftedWCNF (const ParfactorList& pfList);
|
||||
|
||||
~LiftedWCNF (void);
|
||||
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
|
||||
void addWeight (LiteralId lid, double posW, double negW);
|
||||
|
||||
double posWeight (LiteralId lid) const;
|
||||
|
||||
double negWeight (LiteralId lid) const;
|
||||
|
||||
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
|
||||
|
||||
Clause createClause (LiteralId lid) const;
|
||||
|
||||
void printFormulaIndicators (void) const;
|
||||
|
||||
void printWeights (void) const;
|
||||
|
||||
void printClauses (void) const;
|
||||
|
||||
private:
|
||||
|
||||
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
|
||||
|
||||
void addIndicatorClauses (const ParfactorList& pfList);
|
||||
|
||||
void addParameterClauses (const ParfactorList& pfList);
|
||||
|
||||
Clauses clauses_;
|
||||
LiteralId freeLiteralId_;
|
||||
const ParfactorList& pfList_;
|
||||
unordered_map<PrvGroup, vector<LiteralId>> map_;
|
||||
unordered_map<LiteralId, std::pair<double,double>> weights_;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDWCNF_H
|
||||
|
@ -23,10 +23,10 @@ CC=@CC@
|
||||
CXX=@CXX@
|
||||
|
||||
# normal
|
||||
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
|
||||
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -DNDEBUG
|
||||
|
||||
# debug
|
||||
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
|
||||
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../.. -I$(srcdir)/../../../include @CPPFLAGS@ -g -O0 -Wextra
|
||||
|
||||
|
||||
#
|
||||
@ -57,12 +57,17 @@ HEADERS = \
|
||||
$(srcdir)/Horus.h \
|
||||
$(srcdir)/Indexer.h \
|
||||
$(srcdir)/LiftedBp.h \
|
||||
$(srcdir)/LiftedCircuit.h \
|
||||
$(srcdir)/LiftedKc.h \
|
||||
$(srcdir)/LiftedOperations.h \
|
||||
$(srcdir)/LiftedSolver.h \
|
||||
$(srcdir)/LiftedUtils.h \
|
||||
$(srcdir)/LiftedVe.h \
|
||||
$(srcdir)/LiftedWCNF.h \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ParfactorList.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
$(srcdir)/Solver.h \
|
||||
$(srcdir)/GroundSolver.h \
|
||||
$(srcdir)/TinySet.h \
|
||||
$(srcdir)/Util.h \
|
||||
$(srcdir)/Var.h \
|
||||
@ -82,16 +87,20 @@ CPP_SOURCES = \
|
||||
$(srcdir)/HorusCli.cpp \
|
||||
$(srcdir)/HorusYap.cpp \
|
||||
$(srcdir)/LiftedBp.cpp \
|
||||
$(srcdir)/LiftedCircuit.cpp \
|
||||
$(srcdir)/LiftedKc.cpp \
|
||||
$(srcdir)/LiftedOperations.cpp \
|
||||
$(srcdir)/LiftedUtils.cpp \
|
||||
$(srcdir)/LiftedVe.cpp \
|
||||
$(srcdir)/LiftedWCNF.cpp \
|
||||
$(srcdir)/Parfactor.cpp \
|
||||
$(srcdir)/ParfactorList.cpp \
|
||||
$(srcdir)/ProbFormula.cpp \
|
||||
$(srcdir)/Solver.cpp \
|
||||
$(srcdir)/GroundSolver.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/Var.cpp \
|
||||
$(srcdir)/VarElim.cpp \
|
||||
$(srcdir)/WeightedBp.cpp \
|
||||
$(srcdir)/WeightedBp.cpp
|
||||
|
||||
OBJS = \
|
||||
BayesBall.o \
|
||||
@ -105,12 +114,16 @@ OBJS = \
|
||||
Histogram.o \
|
||||
HorusYap.o \
|
||||
LiftedBp.o \
|
||||
LiftedCircuit.o \
|
||||
LiftedKc.o \
|
||||
LiftedOperations.o \
|
||||
LiftedUtils.o \
|
||||
LiftedVe.o \
|
||||
LiftedWCNF.o \
|
||||
ProbFormula.o \
|
||||
Parfactor.o \
|
||||
ParfactorList.o \
|
||||
Solver.o \
|
||||
GroundSolver.o \
|
||||
Util.o \
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
@ -125,7 +138,7 @@ HCLI_OBJS = \
|
||||
Factor.o \
|
||||
FactorGraph.o \
|
||||
HorusCli.o \
|
||||
Solver.o \
|
||||
GroundSolver.o \
|
||||
Util.o \
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
|
@ -9,8 +9,7 @@ ParfactorList::ParfactorList (const ParfactorList& pfList)
|
||||
while (it != pfList.end()) {
|
||||
addShattered (new Parfactor (**it));
|
||||
++ it;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -126,6 +125,27 @@ ParfactorList::print (void) const
|
||||
|
||||
|
||||
|
||||
ParfactorList&
|
||||
ParfactorList::operator= (const ParfactorList& pfList)
|
||||
{
|
||||
if (this != &pfList) {
|
||||
ParfactorList::const_iterator it0 = pfList_.begin();
|
||||
while (it0 != pfList_.end()) {
|
||||
delete *it0;
|
||||
++ it0;
|
||||
}
|
||||
pfList_.clear();
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
addShattered (new Parfactor (**it));
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::isShattered (const Parfactor* g) const
|
||||
{
|
||||
|
@ -56,6 +56,8 @@ class ParfactorList
|
||||
bool isAllShattered (void) const;
|
||||
|
||||
void print (void) const;
|
||||
|
||||
ParfactorList& operator= (const ParfactorList& pfList);
|
||||
|
||||
private:
|
||||
|
||||
|
@ -1,3 +1,2 @@
|
||||
- Find a way to decrease the time required to find an
|
||||
elimination order for variable elimination
|
||||
- Handle formulas like f(X,X)
|
||||
|
||||
|
@ -153,6 +153,11 @@ class TinySet
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
|
||||
T& operator[] (typename vector<T>::size_type i)
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
|
||||
T front (void) const
|
||||
{
|
||||
|
@ -13,9 +13,9 @@ bool logDomain = false;
|
||||
|
||||
unsigned verbosity = 0;
|
||||
|
||||
LiftedSolver liftedSolver = LiftedSolver::FOVE;
|
||||
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
|
||||
|
||||
GroundSolver groundSolver = GroundSolver::VE;
|
||||
GroundSolverType groundSolver = GroundSolverType::VE;
|
||||
|
||||
};
|
||||
|
||||
@ -210,10 +210,12 @@ setHorusFlag (string key, string value)
|
||||
ss << value;
|
||||
ss >> Globals::verbosity;
|
||||
} else if (key == "lifted_solver") {
|
||||
if ( value == "fove") {
|
||||
Globals::liftedSolver = LiftedSolver::FOVE;
|
||||
if ( value == "lve") {
|
||||
Globals::liftedSolver = LiftedSolverType::LVE;
|
||||
} else if (value == "lbp") {
|
||||
Globals::liftedSolver = LiftedSolver::LBP;
|
||||
Globals::liftedSolver = LiftedSolverType::LBP;
|
||||
} else if (value == "lkc") {
|
||||
Globals::liftedSolver = LiftedSolverType::LKC;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
@ -221,11 +223,11 @@ setHorusFlag (string key, string value)
|
||||
}
|
||||
} else if (key == "ground_solver") {
|
||||
if ( value == "ve") {
|
||||
Globals::groundSolver = GroundSolver::VE;
|
||||
Globals::groundSolver = GroundSolverType::VE;
|
||||
} else if (value == "bp") {
|
||||
Globals::groundSolver = GroundSolver::BP;
|
||||
Globals::groundSolver = GroundSolverType::BP;
|
||||
} else if (value == "cbp") {
|
||||
Globals::groundSolver = GroundSolver::CBP;
|
||||
Globals::groundSolver = GroundSolverType::CBP;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
#include "unordered_map"
|
||||
|
||||
#include "Solver.h"
|
||||
#include "GroundSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
@ -11,10 +11,10 @@
|
||||
using namespace std;
|
||||
|
||||
|
||||
class VarElim : public Solver
|
||||
class VarElim : public GroundSolver
|
||||
{
|
||||
public:
|
||||
VarElim (const FactorGraph& fg) : Solver (fg) { }
|
||||
VarElim (const FactorGraph& fg) : GroundSolver (fg) { }
|
||||
|
||||
~VarElim (void);
|
||||
|
||||
|
Reference in New Issue
Block a user