Merge branch 'master' of git.dcc.fc.up.pt:yap-6.3

This commit is contained in:
Vitor Santos Costa 2013-04-16 13:32:24 -05:00
commit 8e33cebd4d
118 changed files with 6455 additions and 4140 deletions

105
C/absmi.c
View File

@ -9376,7 +9376,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_plus(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_plus(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9421,7 +9421,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_plus(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_plus(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9462,7 +9462,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_plus(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_plus(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9510,7 +9510,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_plus(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_plus(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9554,7 +9554,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_minus(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_minus(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9599,7 +9599,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_minus(MkIntegerTerm(d1),Yap_Eval(d0));
d0 = p_minus(MkIntegerTerm(d1),Yap_Eval(d0) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9640,7 +9640,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_minus(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_minus(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9688,7 +9688,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_minus(MkIntegerTerm(d1), Yap_Eval(d0));
d0 = p_minus(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9728,11 +9728,11 @@ Yap_absmi(int inp)
times_vv_nvar_nvar:
/* d0 and d1 are where I want them */
if (IsIntTerm(d0) && IsIntTerm(d1)) {
d0 = times_int(IntOfTerm(d0), IntOfTerm(d1));
d0 = times_int(IntOfTerm(d0), IntOfTerm(d1) PASS_REGS);
}
else {
saveregs();
d0 = p_times(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_times(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9773,11 +9773,11 @@ Yap_absmi(int inp)
{
Int d1 = PREG->u.xxn.c;
if (IsIntTerm(d0)) {
d0 = times_int(IntOfTerm(d0), d1);
d0 = times_int(IntOfTerm(d0), d1 PASS_REGS);
}
else {
saveregs();
d0 = p_times(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_times(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9814,11 +9814,11 @@ Yap_absmi(int inp)
times_y_vv_nvar_nvar:
/* d0 and d1 are where I want them */
if (IsIntTerm(d0) && IsIntTerm(d1)) {
d0 = times_int(IntOfTerm(d0), IntOfTerm(d1));
d0 = times_int(IntOfTerm(d0), IntOfTerm(d1) PASS_REGS);
}
else {
saveregs();
d0 = p_times(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_times(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9862,11 +9862,11 @@ Yap_absmi(int inp)
{
Int d1 = PREG->u.yxn.c;
if (IsIntTerm(d0)) {
d0 = times_int(IntOfTerm(d0), d1);
d0 = times_int(IntOfTerm(d0), d1 PASS_REGS);
}
else {
saveregs();
d0 = p_times(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_times(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9917,7 +9917,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_div(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_div(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -9962,7 +9962,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_div(Yap_Eval(d0),MkIntegerTerm(d1));
d0 = p_div(Yap_Eval(d0),MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10006,7 +10006,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_div(MkIntegerTerm(d1),Yap_Eval(d0));
d0 = p_div(MkIntegerTerm(d1),Yap_Eval(d0) PASS_REGS);
if (d0 == 0L) {
saveregs();
Yap_Error(LOCAL_Error_TYPE, LOCAL_Error_Term, LOCAL_ErrorMessage);
@ -10053,7 +10053,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_div(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_div(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10101,7 +10101,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_div(Yap_Eval(d0),MkIntegerTerm(d1));
d0 = p_div(Yap_Eval(d0),MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10148,7 +10148,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_div(MkIntegerTerm(d1), Yap_Eval(d0));
d0 = p_div(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10193,7 +10193,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_and(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_and(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10238,7 +10238,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_and(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_and(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10279,7 +10279,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_and(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_and(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10327,7 +10327,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_and(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_and(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10372,7 +10372,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_or(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_or(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10417,7 +10417,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_or(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_or(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
if (d0 == 0L) {
saveregs();
Yap_Error(LOCAL_Error_TYPE, LOCAL_Error_Term, LOCAL_ErrorMessage);
@ -10457,7 +10457,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_or(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_or(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10505,7 +10505,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_or(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_or(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10549,11 +10549,11 @@ Yap_absmi(int inp)
if (i2 < 0)
d0 = MkIntegerTerm(SLR(IntOfTerm(d0), -i2));
else
d0 = do_sll(IntOfTerm(d0),i2);
d0 = do_sll(IntOfTerm(d0),i2 PASS_REGS);
}
else {
saveregs();
d0 = p_sll(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_sll(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
}
if (d0 == 0L) {
@ -10594,11 +10594,11 @@ Yap_absmi(int inp)
{
Int d1 = PREG->u.xxn.c;
if (IsIntTerm(d0)) {
d0 = do_sll(IntOfTerm(d0), (Int)d1);
d0 = do_sll(IntOfTerm(d0), (Int)d1 PASS_REGS);
}
else {
saveregs();
d0 = p_sll(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_sll(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
}
}
@ -10635,11 +10635,11 @@ Yap_absmi(int inp)
if (i2 < 0)
d0 = MkIntegerTerm(SLR(d1, -i2));
else
d0 = do_sll(d1,i2);
d0 = do_sll(d1,i2 PASS_REGS);
}
else {
saveregs();
d0 = p_sll(MkIntegerTerm(d1), Yap_Eval(d0));
d0 = p_sll(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
setregs();
}
}
@ -10680,11 +10680,11 @@ Yap_absmi(int inp)
if (i2 < 0)
d0 = MkIntegerTerm(SLR(IntOfTerm(d0), -i2));
else
d0 = do_sll(IntOfTerm(d0),i2);
d0 = do_sll(IntOfTerm(d0),i2 PASS_REGS);
}
else {
saveregs();
d0 = p_sll(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_sll(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
}
if (d0 == 0L) {
@ -10728,11 +10728,11 @@ Yap_absmi(int inp)
{
Int d1 = PREG->u.yxn.c;
if (IsIntTerm(d0)) {
d0 = do_sll(IntOfTerm(d0), Yap_Eval(d1));
d0 = do_sll(IntOfTerm(d0), Yap_Eval(d1) PASS_REGS);
}
else {
saveregs();
d0 = p_sll(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_sll(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
}
}
@ -10773,11 +10773,11 @@ Yap_absmi(int inp)
if (i2 < 0)
d0 = MkIntegerTerm(SLR(d1, -i2));
else
d0 = do_sll(d1,i2);
d0 = do_sll(d1,i2 PASS_REGS);
}
else {
saveregs();
d0 = p_sll(MkIntegerTerm(d1), Yap_Eval(0));
d0 = p_sll(MkIntegerTerm(d1), Yap_Eval(0) PASS_REGS);
setregs();
}
}
@ -10819,13 +10819,13 @@ Yap_absmi(int inp)
if (IsIntTerm(d0) && IsIntTerm(d1)) {
Int i2 = IntOfTerm(d1);
if (i2 < 0)
d0 = do_sll(IntOfTerm(d0), -i2);
d0 = do_sll(IntOfTerm(d0), -i2 PASS_REGS);
else
d0 = MkIntTerm(SLR(IntOfTerm(d0), i2));
}
else {
saveregs();
d0 = p_slr(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_slr(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
}
if (d0 == 0L) {
@ -10870,7 +10870,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_slr(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_slr(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -10905,13 +10905,13 @@ Yap_absmi(int inp)
if (IsIntTerm(d0)) {
Int i2 = IntOfTerm(d0);
if (i2 < 0)
d0 = do_sll(d1, -i2);
d0 = do_sll(d1, -i2 PASS_REGS);
else
d0 = MkIntegerTerm(SLR(d1, i2));
}
else {
saveregs();
d0 = p_slr(MkIntegerTerm(d1), Yap_Eval(d0));
d0 = p_slr(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
setregs();
}
}
@ -10950,13 +10950,13 @@ Yap_absmi(int inp)
if (IsIntTerm(d0) && IsIntTerm(d1)) {
Int i2 = IntOfTerm(d1);
if (i2 < 0)
d0 = do_sll(IntOfTerm(d0), -i2);
d0 = do_sll(IntOfTerm(d0), -i2 PASS_REGS);
else
d0 = MkIntTerm(SLR(IntOfTerm(d0), i2));
}
else {
saveregs();
d0 = p_slr(Yap_Eval(d0), Yap_Eval(d1));
d0 = p_slr(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
setregs();
}
BEGP(pt0);
@ -11004,7 +11004,7 @@ Yap_absmi(int inp)
}
else {
saveregs();
d0 = p_slr(Yap_Eval(d0), MkIntegerTerm(d1));
d0 = p_slr(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
setregs();
if (d0 == 0L) {
saveregs();
@ -11041,13 +11041,13 @@ Yap_absmi(int inp)
if (IsIntTerm(d0)) {
Int i2 = IntOfTerm(d0);
if (i2 < 0)
d0 = do_sll(d1, -i2);
d0 = do_sll(d1, -i2 PASS_REGS);
else
d0 = MkIntegerTerm(SLR(d1, i2));
}
else {
saveregs();
d0 = p_slr(MkIntegerTerm(d1), Yap_Eval(d0));
d0 = p_slr(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
setregs();
}
}
@ -13432,7 +13432,6 @@ Yap_absmi(int inp)
}
PP = NULL;
SREG = (CELL *) pen;
fprintf(stderr,"Here I was\n");
ASP = ENV_YREG;
if (ASP > (CELL *)PROTECT_FROZEN_B(B))
ASP = (CELL *)PROTECT_FROZEN_B(B);

View File

@ -852,7 +852,6 @@ Yap_NewPredPropByFunctor(FunctorEntry *fe, Term cur_mod)
p->FunctorOfPred = fe;
WRITE_UNLOCK(fe->FRWLock);
{
CACHE_REGS
Yap_inform_profiler_of_clause(&(p->OpcodeOfPred), &(p->OpcodeOfPred)+1, p, GPROF_NEW_PRED_FUNC);
if (!(p->PredFlags & (CPredFlag|AsmPredFlag))) {
Yap_inform_profiler_of_clause(&(p->cs.p_code.ExpandCode), &(p->cs.p_code.ExpandCode)+1, p, GPROF_NEW_PRED_FUNC);
@ -966,7 +965,6 @@ Yap_NewPredPropByAtom(AtomEntry *ae, Term cur_mod)
p->FunctorOfPred = (Functor)AbsAtom(ae);
WRITE_UNLOCK(ae->ARWLock);
{
CACHE_REGS
Yap_inform_profiler_of_clause(&(p->OpcodeOfPred), &(p->OpcodeOfPred)+1, p, GPROF_NEW_PRED_ATOM);
if (!(p->PredFlags & (CPredFlag|AsmPredFlag))) {
Yap_inform_profiler_of_clause(&(p->cs.p_code.ExpandCode), &(p->cs.p_code.ExpandCode)+1, p, GPROF_NEW_PRED_ATOM);
@ -1057,8 +1055,10 @@ Yap_GetValue(Atom a)
if (IsApplTerm(out)) {
Functor f = FunctorOfTerm(out);
if (f == FunctorDouble) {
CACHE_REGS
out = MkFloatTerm(FloatOfTerm(out));
} else if (f == FunctorLongInt) {
CACHE_REGS
out = MkLongIntTerm(LongIntOfTerm(out));
}
#ifdef USE_GMP

View File

@ -2056,10 +2056,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
save_machine_regs();
siglongjmp(cip->CompilerBotch,2);
}
{
CACHE_REGS
Yap_inform_profiler_of_clause(newcp, (char *)(newcp)+size, ap, GPROF_INDEX);
}
Yap_inform_profiler_of_clause(newcp, (char *)(newcp)+size, ap, GPROF_INDEX);
Yap_LUIndexSpace_CP += size;
#ifdef DEBUG
Yap_NewCps++;

View File

@ -29,7 +29,7 @@ static char SccsId[] = "%W% %G%";
#include "eval.h"
static Term
float_to_int(Float v)
float_to_int(Float v USES_REGS)
{
#if USE_GMP
Int i = (Int)v;
@ -44,7 +44,7 @@ float_to_int(Float v)
#endif
}
#define RBIG_FL(v) return(float_to_int(v))
#define RBIG_FL(v) return(float_to_int(v PASS_REGS))
typedef struct init_un_eval {
char *OpName;
@ -118,7 +118,7 @@ double my_rint(double x)
#endif
static Int
msb(Int inp) /* calculate the most significant bit for an integer */
msb(Int inp USES_REGS) /* calculate the most significant bit for an integer */
{
/* the obvious solution: do it by using binary search */
Int out = 0;
@ -141,7 +141,7 @@ msb(Int inp) /* calculate the most significant bit for an integer */
}
static Int
lsb(Int inp) /* calculate the least significant bit for an integer */
lsb(Int inp USES_REGS) /* calculate the least significant bit for an integer */
{
/* the obvious solution: do it by using binary search */
Int out = 0;
@ -165,7 +165,7 @@ lsb(Int inp) /* calculate the least significant bit for an integer */
}
static Int
popcount(Int inp) /* calculate the least significant bit for an integer */
popcount(Int inp USES_REGS) /* calculate the least significant bit for an integer */
{
/* the obvious solution: do it by using binary search */
Int c = 0, j = 0, m = ((CELL)1);
@ -185,7 +185,7 @@ popcount(Int inp) /* calculate the least significant bit for an integer */
}
static Term
eval1(Int fi, Term t) {
eval1(Int fi, Term t USES_REGS) {
arith1_op f = fi;
switch (f) {
case op_uplus:
@ -586,7 +586,7 @@ eval1(Int fi, Term t) {
case op_msb:
switch (ETypeOfTerm(t)) {
case long_int_e:
RINT(msb(IntegerOfTerm(t)));
RINT(msb(IntegerOfTerm(t) PASS_REGS));
case double_e:
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "msb(%f)", FloatOfTerm(t));
case big_int_e:
@ -599,7 +599,7 @@ eval1(Int fi, Term t) {
case op_lsb:
switch (ETypeOfTerm(t)) {
case long_int_e:
RINT(lsb(IntegerOfTerm(t)));
RINT(lsb(IntegerOfTerm(t) PASS_REGS));
case double_e:
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "lsb(%f)", FloatOfTerm(t));
case big_int_e:
@ -612,7 +612,7 @@ eval1(Int fi, Term t) {
case op_popcount:
switch (ETypeOfTerm(t)) {
case long_int_e:
RINT(popcount(IntegerOfTerm(t)));
RINT(popcount(IntegerOfTerm(t) PASS_REGS));
case double_e:
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "popcount(%f)", FloatOfTerm(t));
case big_int_e:
@ -699,7 +699,8 @@ eval1(Int fi, Term t) {
Term Yap_eval_unary(Int f, Term t)
{
return eval1(f,t);
CACHE_REGS
return eval1(f,t PASS_REGS);
}
static InitUnEntry InitUnTab[] = {
@ -758,7 +759,7 @@ p_unary_is( USES_REGS1 )
return FALSE;
}
if (IsIntTerm(t)) {
Term tout = Yap_FoundArithError(eval1(IntegerOfTerm(t), top), Deref(ARG3));
Term tout = Yap_FoundArithError(eval1(IntegerOfTerm(t), top PASS_REGS), Deref(ARG3));
if (!tout)
return FALSE;
return Yap_unify_constant(ARG1,tout);
@ -781,7 +782,7 @@ p_unary_is( USES_REGS1 )
P = FAILCODE;
return(FALSE);
}
if (!(out=Yap_FoundArithError(eval1(p->FOfEE, top),Deref(ARG3))))
if (!(out=Yap_FoundArithError(eval1(p->FOfEE, top PASS_REGS),Deref(ARG3))))
return FALSE;
return Yap_unify_constant(ARG1,out);
}

View File

@ -37,7 +37,7 @@ typedef struct init_un_eval {
static Term
p_mod(Term t1, Term t2) {
p_mod(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case (CELL)long_int_e:
switch (ETypeOfTerm(t2)) {
@ -97,7 +97,7 @@ p_mod(Term t1, Term t2) {
}
static Term
p_div2(Term t1, Term t2) {
p_div2(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case (CELL)long_int_e:
switch (ETypeOfTerm(t2)) {
@ -163,7 +163,7 @@ p_div2(Term t1, Term t2) {
}
static Term
p_rem(Term t1, Term t2) {
p_rem(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case (CELL)long_int_e:
switch (ETypeOfTerm(t2)) {
@ -215,7 +215,7 @@ p_rem(Term t1, Term t2) {
static Term
p_rdiv(Term t1, Term t2) {
p_rdiv(Term t1, Term t2 USES_REGS) {
#ifdef USE_GMP
switch (ETypeOfTerm(t1)) {
case (CELL)double_e:
@ -266,7 +266,7 @@ p_rdiv(Term t1, Term t2) {
Floating point division: /
*/
static Term
p_fdiv(Term t1, Term t2)
p_fdiv(Term t1, Term t2 USES_REGS)
{
switch (ETypeOfTerm(t1)) {
case long_int_e:
@ -338,7 +338,7 @@ p_fdiv(Term t1, Term t2)
xor #
*/
static Term
p_xor(Term t1, Term t2)
p_xor(Term t1, Term t2 USES_REGS)
{
switch (ETypeOfTerm(t1)) {
case long_int_e:
@ -382,7 +382,7 @@ p_xor(Term t1, Term t2)
atan2: arc tangent x/y
*/
static Term
p_atan2(Term t1, Term t2)
p_atan2(Term t1, Term t2 USES_REGS)
{
switch (ETypeOfTerm(t1)) {
case long_int_e:
@ -461,7 +461,7 @@ p_atan2(Term t1, Term t2)
power: x^y
*/
static Term
p_power(Term t1, Term t2)
p_power(Term t1, Term t2 USES_REGS)
{
switch (ETypeOfTerm(t1)) {
case long_int_e:
@ -577,7 +577,7 @@ ipow(Int x, Int p)
power: x^y
*/
static Term
p_exp(Term t1, Term t2)
p_exp(Term t1, Term t2 USES_REGS)
{
switch (ETypeOfTerm(t1)) {
case long_int_e:
@ -669,7 +669,7 @@ p_exp(Term t1, Term t2)
}
static Int
gcd(Int m11,Int m21)
gcd(Int m11,Int m21 USES_REGS)
{
/* Blankinship algorithm, provided by Miguel Filgueiras */
Int m12=1, m22=0, k;
@ -719,7 +719,7 @@ Int gcdmult(Int m11,Int m21,Int *pm11) /* *pm11 gets multiplier of m11 */
module gcd
*/
static Term
p_gcd(Term t1, Term t2)
p_gcd(Term t1, Term t2 USES_REGS)
{
switch (ETypeOfTerm(t1)) {
case long_int_e:
@ -731,7 +731,7 @@ p_gcd(Term t1, Term t2)
i1 = (i1 >= 0 ? i1 : -i1);
i2 = (i2 >= 0 ? i2 : -i2);
RINT(gcd(i1,i2));
RINT(gcd(i1,i2 PASS_REGS));
}
case double_e:
return Yap_ArithError(TYPE_ERROR_INTEGER, t2, "gcd/2");
@ -957,56 +957,57 @@ p_max(Term t1, Term t2)
}
static Term
eval2(Int fi, Term t1, Term t2) {
eval2(Int fi, Term t1, Term t2 USES_REGS) {
arith2_op f = fi;
switch (f) {
case op_plus:
return p_plus(t1, t2);
return p_plus(t1, t2 PASS_REGS);
case op_minus:
return p_minus(t1, t2);
return p_minus(t1, t2 PASS_REGS);
case op_times:
return p_times(t1, t2);
return p_times(t1, t2 PASS_REGS);
case op_div:
return p_div(t1, t2);
return p_div(t1, t2 PASS_REGS);
case op_idiv:
return p_div2(t1, t2);
return p_div2(t1, t2 PASS_REGS);
case op_and:
return p_and(t1, t2);
return p_and(t1, t2 PASS_REGS);
case op_or:
return p_or(t1, t2);
return p_or(t1, t2 PASS_REGS);
case op_sll:
return p_sll(t1, t2);
return p_sll(t1, t2 PASS_REGS);
case op_slr:
return p_slr(t1, t2);
return p_slr(t1, t2 PASS_REGS);
case op_mod:
return p_mod(t1, t2);
return p_mod(t1, t2 PASS_REGS);
case op_rem:
return p_rem(t1, t2);
return p_rem(t1, t2 PASS_REGS);
case op_fdiv:
return p_fdiv(t1, t2);
return p_fdiv(t1, t2 PASS_REGS);
case op_xor:
return p_xor(t1, t2);
return p_xor(t1, t2 PASS_REGS);
case op_atan2:
return p_atan2(t1, t2);
return p_atan2(t1, t2 PASS_REGS);
case op_power:
return p_exp(t1, t2);
return p_exp(t1, t2 PASS_REGS);
case op_power2:
return p_power(t1, t2);
return p_power(t1, t2 PASS_REGS);
case op_gcd:
return p_gcd(t1, t2);
return p_gcd(t1, t2 PASS_REGS);
case op_min:
return p_min(t1, t2);
case op_max:
return p_max(t1, t2);
case op_rdiv:
return p_rdiv(t1, t2);
return p_rdiv(t1, t2 PASS_REGS);
}
RERROR();
}
Term Yap_eval_binary(Int f, Term t1, Term t2)
{
return eval2(f,t1,t2);
CACHE_REGS
return eval2(f,t1,t2 PASS_REGS);
}
static InitBinEntry InitBinTab[] = {
@ -1058,7 +1059,7 @@ p_binary_is( USES_REGS1 )
return FALSE;
}
if (IsIntTerm(t)) {
Term tout = Yap_FoundArithError(eval2(IntOfTerm(t), t1, t2), 0L);
Term tout = Yap_FoundArithError(eval2(IntOfTerm(t), t1, t2 PASS_REGS), 0L);
if (!tout)
return FALSE;
return Yap_unify_constant(ARG1,tout);
@ -1081,7 +1082,7 @@ p_binary_is( USES_REGS1 )
P = FAILCODE;
return(FALSE);
}
if (!(out=Yap_FoundArithError(eval2(p->FOfEE, t1, t2), 0L)))
if (!(out=Yap_FoundArithError(eval2(p->FOfEE, t1, t2 PASS_REGS), 0L)))
return FALSE;
return Yap_unify_constant(ARG1,out);
}
@ -1105,7 +1106,7 @@ do_arith23(arith2_op op USES_REGS)
t2 = Yap_Eval(Deref(ARG2));
if (t2 == 0L)
return FALSE;
if (!(out=Yap_FoundArithError(eval2(op, t1, t2), 0L)))
if (!(out=Yap_FoundArithError(eval2(op, t1, t2 PASS_REGS), 0L)))
return FALSE;
return Yap_unify_constant(ARG3,out);
}

View File

@ -320,6 +320,7 @@ Yap_MkULLIntTerm(YAP_ULONG_LONG n)
/* try to scan it as a bignum */
mpz_init_set_str (new, tmp, 10);
if (mpz_fits_slong_p(new)) {
CACHE_REGS
return MkIntegerTerm(mpz_get_si(new));
}
t = Yap_MkBigIntTerm(new);
@ -346,6 +347,38 @@ p_is_bignum( USES_REGS1 )
#endif
}
static Int
p_nb_set_bit( USES_REGS1 )
{
#ifdef USE_GMP
Term t = Deref(ARG1);
Term ti = Deref(ARG2);
Int i;
if (!(
IsNonVarTerm(t) &&
IsApplTerm(t) &&
FunctorOfTerm(t) == FunctorBigInt &&
RepAppl(t)[1] == BIG_INT
))
return FALSE;
if (!IsIntegerTerm(ti)) {
return FALSE;
}
if (!IsIntegerTerm(ti)) {
return FALSE;
}
i = IntegerOfTerm(ti);
if (i < 0) {
return FALSE;
}
Yap_gmp_set_bit(i, t);
return TRUE;
#else
return FALSE;
#endif
}
static Int
p_has_bignums( USES_REGS1 )
{
@ -560,4 +593,5 @@ Yap_InitBigNums(void)
Yap_InitCPred("$bignum", 1, p_is_bignum, SafePredFlag);
Yap_InitCPred("rational", 3, p_rational, 0);
Yap_InitCPred("rational", 1, p_is_rational, SafePredFlag);
Yap_InitCPred("nb_set_bit", 2, p_nb_set_bit, SafePredFlag);
}

View File

@ -738,6 +738,7 @@ YAP_IsCompoundTerm(Term t)
X_API Term
YAP_MkIntTerm(Int n)
{
CACHE_REGS
Term I;
BACKUP_H();
@ -854,6 +855,7 @@ YAP_BlobOfTerm(Term t)
X_API Term
YAP_MkFloatTerm(double n)
{
CACHE_REGS
Term t;
BACKUP_H();
@ -3734,6 +3736,7 @@ YAP_CloseList(Term t0, Term tail)
X_API int
YAP_IsAttVar(Term t)
{
CACHE_REGS
t = Deref(t);
if (!IsVarTerm(t))
return FALSE;
@ -3743,6 +3746,7 @@ YAP_IsAttVar(Term t)
X_API Term
YAP_AttsOfVar(Term t)
{
CACHE_REGS
attvar_record *attv;
t = Deref(t);
@ -4023,6 +4027,7 @@ YAP_TagOfTerm(Term t)
if (IsVarTerm(t)) {
CELL *pt = VarOfTerm(t);
if (IsUnboundVar(pt)) {
CACHE_REGS
if (IsAttVar(pt))
return YAP_TAG_ATT;
return YAP_TAG_UNBOUND;

View File

@ -4910,6 +4910,7 @@ replace_integer(Term orig, UInt new)
return MkIntTerm(new);
/* should create an old integer */
if (!IsApplTerm(orig)) {
CACHE_REGS
Yap_Error(SYSTEM_ERROR,orig,"%uld-->%uld where it should increase",(unsigned long int)IntegerOfTerm(orig),(unsigned long int)new);
return MkIntegerTerm(new);
}

View File

@ -471,6 +471,7 @@ ShowOp (char *f, struct PSEUDO *cpc)
case 'b':
/* write a variable bitmap for a call */
{
CACHE_REGS
int max = arg/(8*sizeof(CELL)), i;
CELL *ptr = cptr;
for (i = 0; i <= max; i++) {
@ -490,7 +491,10 @@ ShowOp (char *f, struct PSEUDO *cpc)
}
break;
case 'd':
Yap_DebugPlWrite (MkIntegerTerm (arg));
{
CACHE_REGS
Yap_DebugPlWrite (MkIntegerTerm (arg));
}
break;
case 'z':
Yap_DebugPlWrite (MkIntTerm (cpc->rnd3));

View File

@ -2381,6 +2381,7 @@ GetDBLUKey(PredEntry *ap)
{
PELOCK(63,ap);
if (ap->PredFlags & NumberDBPredFlag) {
CACHE_REGS
Int id = ap->src.IndxId;
UNLOCK(ap->PELock);
return MkIntegerTerm(id);
@ -2430,6 +2431,7 @@ UnifyDBKey(DBRef DBSP, PropFlags flags, Term t)
static int
UnifyDBNumber(DBRef DBSP, Term t)
{
CACHE_REGS
DBProp p = DBSP->Parent;
DBRef ref;
Int i = 1;

211
C/exo.c
View File

@ -40,50 +40,67 @@
#define MAX_ARITY 256
#define FNV32_PRIME 16777619
#define FNV64_PRIME ((UInt)1099511628211)
#define FNV32_OFFSET 2166136261
#define FNV64_OFFSET ((UInt)14695981039346656037)
/* Simple hash function:
first component is the base key.
hash0 spreads extensions coming from different elements.
spread over j quadrants.
*/
static UInt
HASH(UInt hash0, UInt j, CELL *cl, struct index_t *it)
static BITS32
HASH(UInt arity, CELL *cl, UInt bnds[], UInt sz)
{
Term t = cl[j];
UInt sz = it->hsize;
if (IsIntTerm(t))
return (17*(IntOfTerm(t) + (hash0+1)*j ) ) % sz;
return (17*(((UInt)AtomOfTerm(t)>>5) + (hash0+1)*j ) ) % sz;
UInt hash;
UInt j=0;
hash = FNV32_OFFSET;
while (j < arity) {
if (bnds[j]) {
unsigned char *i=(unsigned char*)(cl+j);
unsigned char *m=(unsigned char*)(cl+(j+1));
while (i < m) {
hash = hash ^ i[0];
hash = hash * FNV32_PRIME;
i++;
}
}
j++;
}
return hash;
}
static UInt
NEXT(UInt hash, Term t, UInt j, struct index_t *it)
static BITS32
NEXT(UInt hash)
{
return (hash+(j+1)*997) % (it->hsize);
return (hash*997);
}
/* search for matching elements */
static int
MATCH(CELL *clp, CELL *kvp, UInt j, struct index_t *it, UInt bnds[])
MATCH(CELL *clp, CELL *kvp, UInt arity, UInt bnds[])
{
if ((kvp - it->cls)%it->arity != j)
return FALSE;
do {
if ( bnds[j] && *clp != *kvp)
UInt j = 0;
while (j< arity) {
if ( bnds[j] && clp[j] != kvp[j])
return FALSE;
clp--;
kvp--;
} while (j-- != 0);
j++;
}
return TRUE;
}
static void
ADD_TO_TRY_CHAIN(CELL *kvp, CELL *cl, struct index_t *it)
{
UInt old = (kvp-it->cls)/it->arity;
UInt new = (cl-it->cls)/it->arity;
UInt *links = it->links;
UInt tmp = links[old]; /* points to the end of the chain */
BITS32 old = (kvp-it->cls)/it->arity;
BITS32 new = (cl-it->cls)/it->arity;
BITS32 *links = it->links;
BITS32 tmp = links[old]; /* points to the end of the chain */
if (!tmp) {
links[old] = links[new] = new;
@ -111,50 +128,33 @@ ADD_TO_TRY_CHAIN(CELL *kvp, CELL *cl, struct index_t *it)
* match ci..j ck..j -> find j = minarg(cij \= c2j)
* else
*/
static void
INSERT(CELL *cl, struct index_t *it, UInt arity, UInt base, UInt hash0, UInt bnds[])
static int
INSERT(CELL *cl, struct index_t *it, UInt arity, UInt base, UInt bnds[])
{
UInt j = base;
CELL *kvp;
UInt hash;
BITS32 hash;
int coll_count = 0;
/* skip over argument */
while (!bnds[j]) {
j++;
}
/* j is the firs bound element */
/* check if we match */
hash = hash0 = HASH(hash0, j, cl, it);
//if (exo_write) printf("h=%ld j=%ld %lx\n", hash, j, cl[j]);
hash = HASH(arity, cl, bnds, it->hsize);
next:
/* loop to insert element */
kvp = it->key[hash];
kvp = EXO_OFFSET_TO_ADDRESS(it, it->key [hash % it->hsize]);
if (kvp == NULL) {
/* simple case, new entry */
it->nentries++;
it->key[hash] = cl+j;
return;
} else if (MATCH(cl+j, kvp, j, it, bnds)) {
/* collision */
UInt k;
CELL *target;
for (k =j+1, target = kvp+1; k < arity; k++,target++ ) {
if (bnds[k]) {
if (*target != cl[k]) {
/* found a new forking point */
// printf("j=%ld hash0=%ld cl[j]=%lx\n", j, hash0, cl[j]);
INSERT(cl, it, arity, k, hash0, bnds);
return;
}
}
}
it->key[hash % it->hsize ] = EXO_ADDRESS_TO_OFFSET(it, cl);
return TRUE;
} else if (MATCH(kvp, cl, arity, bnds)) {
it->ntrys++;
ADD_TO_TRY_CHAIN(kvp, cl, it);
return;
return TRUE;
} else {
coll_count++;
if (coll_count == 32)
return FALSE;
it->ncollisions++;
hash = NEXT(hash, cl[j], j, it);
// printf("#");
hash = NEXT(hash);
//if (exo_write) printf("N=%ld\n", hash);
goto next;
}
@ -165,45 +165,31 @@ LOOKUP(struct index_t *it, UInt arity, UInt j, UInt bnds[])
{
CACHE_REGS
CELL *kvp;
UInt hash, hash0 = 0;
BITS32 hash;
/* j is the firs bound element */
/* check if we match */
hash:
hash = hash0 = HASH(hash0, j, XREGS+1, it);
hash = HASH(arity, XREGS+1, bnds, it->hsize);
next:
/* loop to insert element */
kvp = it->key[hash];
kvp = EXO_OFFSET_TO_ADDRESS(it, it->key[hash % it->hsize]);
if (kvp == NULL) {
/* simple case, no element */
return FAILCODE;
} else if (MATCH(XREGS+(j+1), kvp, j, it, bnds)) {
/* found element */
UInt k;
CELL *target;
for (k =j+1, target = kvp+1; k < arity; k++ ) {
if (bnds[k]) {
if (*target != XREGS[k+1]) {
j = k;
goto hash;
}
}
target++;
}
S = target-arity;
} else if (MATCH(kvp, XREGS+1, arity, bnds)) {
S = kvp;
if (!it->is_key && it->links[(S-it->cls)/arity])
return it->code;
else
return NEXTOP(NEXTOP(it->code,lp),lp);
} else {
/* collision */
hash = NEXT(hash, XREGS[j+1], j, it);
hash = NEXT(hash);
goto next;
}
}
static void
static int
fill_hash(UInt bmap, struct index_t *it, UInt bnds[])
{
UInt i;
@ -211,12 +197,13 @@ fill_hash(UInt bmap, struct index_t *it, UInt bnds[])
CELL *cl = it->cls;
for (i=0; i < it->nels; i++) {
INSERT(cl, it, arity, 0, 0, bnds);
if (!INSERT(cl, it, arity, 0, bnds))
return FALSE;
cl += arity;
}
for (i=0; i < it->hsize; i++) {
if (it->key[i]) {
UInt offset = (it->key[i]-it->cls)/arity;
UInt offset = it->key[i]/arity;
UInt last = it->links[offset];
if (last) {
/* the chain used to point straight to the last, and the last back to the origibal first */
@ -225,6 +212,7 @@ fill_hash(UInt bmap, struct index_t *it, UInt bnds[])
}
}
}
return TRUE;
}
static struct index_t *
@ -246,6 +234,7 @@ add_index(struct index_t **ip, UInt bmap, PredEntry *ap, UInt count, UInt bnds[]
Yap_Error(OUT_OF_HEAP_ERROR, TermNil, LOCAL_ErrorMessage);
return NULL;
}
i->is_key = FALSE;
i->next = *ip;
i->prev = NULL;
i->nels = ncls;
@ -255,7 +244,7 @@ add_index(struct index_t **ip, UInt bmap, PredEntry *ap, UInt count, UInt bnds[]
i->is_key = FALSE;
i->hsize = 2*ncls;
if (count) {
if (!(base = (CELL *)Yap_AllocCodeSpace(sizeof(CELL)*(ncls+i->hsize)))) {
if (!(base = (CELL *)Yap_AllocCodeSpace(sizeof(BITS32)*(ncls+i->hsize)))) {
CACHE_REGS
save_machine_regs();
LOCAL_Error_Size = sizeof(CELL)*(ncls+i->hsize);
@ -267,18 +256,51 @@ add_index(struct index_t **ip, UInt bmap, PredEntry *ap, UInt count, UInt bnds[]
bzero(base, sizeof(CELL)*(ncls+i->hsize));
}
i->size = sizeof(CELL)*(ncls+i->hsize)+sz+sizeof(struct index_t);
i->key = (CELL **)base;
i->key = (CELL *)base;
i->links = (CELL *)(base+i->hsize);
i->ncollisions = i->nentries = i->ntrys = 0;
i->cls = (CELL *)((ADDR)ap->cs.p_code.FirstClause+2*sizeof(struct index_t *));
*ip = i;
if (count) {
fill_hash(bmap, i, bnds);
printf("entries=%ld collisions=%ld trys=%ld\n", i->nentries, i->ncollisions, i->ntrys);
if (!i->ntrys) {
i->is_key = TRUE;
if (base != realloc(base, i->hsize*sizeof(CELL)))
while (count) {
if (!fill_hash(bmap, i, bnds)) {
size_t sz;
i->hsize += ncls;
if (i->is_key) {
sz = i->hsize*sizeof(BITS32);
} else {
sz = (ncls+i->hsize)*sizeof(BITS32);
}
if (base != realloc(base, sz))
return FALSE;
bzero(base, sz);
i->key = (CELL *)base;
i->links = (CELL *)(base+i->hsize);
i->ncollisions = i->nentries = i->ntrys = 0;
continue;
}
fprintf(stderr, "entries=%ld collisions=%ld trys=%ld\n", i->nentries, i->ncollisions, i->ntrys);
if (!i->ntrys && !i->is_key) {
i->is_key = TRUE;
if (base != realloc(base, i->hsize*sizeof(BITS32)))
return FALSE;
}
/* our hash table is just too large */
if (( i->nentries+i->ncollisions )*10 < i->hsize) {
size_t sz;
i->hsize = ( i->nentries+i->ncollisions )*10;
if (i->is_key) {
sz = i->hsize*sizeof(BITS32);
} else {
sz = (ncls+i->hsize)*sizeof(BITS32);
}
if (base != realloc(base, sz))
return FALSE;
bzero(base, sz);
i->key = (CELL *)base;
i->links = (CELL *)(base+i->hsize);
i->ncollisions = i->nentries = i->ntrys = 0;
} else {
break;
}
}
ptr = (yamop *)(i+1);
@ -337,14 +359,11 @@ Yap_ExoLookup(PredEntry *ap USES_REGS)
}
while (i) {
if (i->is_key) {
if ((i->bmap & bmap) == i->bmap) {
break;
}
} else {
if (i->bmap == bmap) {
break;
}
// if (i->is_key && (i->bmap & bmap) == i->bmap) {
// break;
// }
if (i->bmap == bmap) {
break;
}
ip = &i->next;
i = i->next;
@ -362,9 +381,9 @@ CELL
Yap_NextExo(choiceptr cptr, struct index_t *it)
{
CACHE_REGS
CELL offset = EXO_ADDRESS_TO_OFFSET(it,(CELL *)((CELL *)(B+1))[it->arity]);
CELL offset = ADDRESS_TO_LINK(it,(CELL *)((CELL *)(B+1))[it->arity]);
CELL next = it->links[offset];
((CELL *)(B+1))[it->arity] = (CELL)EXO_OFFSET_TO_ADDRESS(it, next);
((CELL *)(B+1))[it->arity] = (CELL)LINK_TO_ADDRESS(it, next);
S = it->cls+it->arity*offset;
return next;
}

View File

@ -2121,6 +2121,7 @@ p_nb_beam_close( USES_REGS1 )
static void
PushBeam(CELL *pt, CELL *npt, UInt hsize, Term key, Term to)
{
CACHE_REGS
UInt off = hsize, off2 = hsize;
Term toff, toff2;
@ -2166,6 +2167,7 @@ PushBeam(CELL *pt, CELL *npt, UInt hsize, Term key, Term to)
static void
DelBeamMax(CELL *pt, CELL *pt2, UInt sz)
{
CACHE_REGS
UInt off = IntegerOfTerm(pt2[1]);
UInt indx = 0;
Term tk, ti, tv;
@ -2240,6 +2242,7 @@ DelBeamMax(CELL *pt, CELL *pt2, UInt sz)
static Term
DelBeamMin(CELL *pt, CELL *pt2, UInt sz)
{
CACHE_REGS
UInt off2 = IntegerOfTerm(pt[1]);
Term ov = pt2[3*off2+2]; /* return value */
UInt indx = 0;

View File

@ -132,6 +132,14 @@ Yap_gmp_add_int_big(Int i, Term t)
}
}
/* add i + b using temporary bigint new */
void
Yap_gmp_set_bit(Int i, Term t)
{
MP_INT *b = Yap_BigIntOfTerm(t);
mpz_setbit(b, i);
}
/* sub i - b using temporary bigint new */
Term
Yap_gmp_sub_int_big(Int i, Term t)
@ -384,6 +392,7 @@ Yap_gmp_sll_big_int(Term t, Int i)
} else {
mpz_init(&new);
if (i == Int_MIN) {
CACHE_REGS
return Yap_ArithError(RESOURCE_ERROR_HUGE_INT, MkIntegerTerm(i), "<</2");
}
mpz_fdiv_q_2exp(&new, b, -i);
@ -706,6 +715,7 @@ Yap_gmp_mod_big_int(Term t, Int i2)
Term
Yap_gmp_mod_int_big(Int i1, Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] != BIG_INT) {
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "mod/2");
@ -782,6 +792,7 @@ Yap_gmp_rem_big_int(Term t, Int i2)
Term
Yap_gmp_rem_int_big(Int i1, Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] != BIG_INT) {
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "mod/2");
@ -815,6 +826,7 @@ Yap_gmp_gcd_big_big(Term t1, Term t2)
Term
Yap_gmp_gcd_int_big(Int i, Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] != BIG_INT) {
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "mod/2");
@ -855,6 +867,7 @@ Yap_gmp_to_float(Term t)
Term
Yap_gmp_add_float_big(Float d, Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
MP_INT *b = Yap_BigIntOfTerm(t);
@ -868,6 +881,7 @@ Yap_gmp_add_float_big(Float d, Term t)
Term
Yap_gmp_sub_float_big(Float d, Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
MP_INT *b = Yap_BigIntOfTerm(t);
@ -881,6 +895,7 @@ Yap_gmp_sub_float_big(Float d, Term t)
Term
Yap_gmp_sub_big_float(Term t, Float d)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
MP_INT *b = Yap_BigIntOfTerm(t);
@ -894,6 +909,7 @@ Yap_gmp_sub_big_float(Term t, Float d)
Term
Yap_gmp_mul_float_big(Float d, Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
MP_INT *b = Yap_BigIntOfTerm(t);
@ -907,6 +923,7 @@ Yap_gmp_mul_float_big(Float d, Term t)
Term
Yap_gmp_fdiv_float_big(Float d, Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
MP_INT *b = Yap_BigIntOfTerm(t);
@ -920,6 +937,7 @@ Yap_gmp_fdiv_float_big(Float d, Term t)
Term
Yap_gmp_fdiv_big_float(Term t, Float d)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
MP_INT *b = Yap_BigIntOfTerm(t);
@ -943,6 +961,7 @@ Yap_gmp_exp_int_int(Int i1, Int i2)
Term
Yap_gmp_exp_big_int(Term t, Int i)
{
CACHE_REGS
MP_INT new;
CELL *pt = RepAppl(t);
@ -969,6 +988,7 @@ Yap_gmp_exp_big_int(Term t, Int i)
Term
Yap_gmp_exp_int_big(Int i, Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
return Yap_ArithError(RESOURCE_ERROR_HUGE_INT, t, "^/2");
@ -982,6 +1002,7 @@ Yap_gmp_exp_int_big(Int i, Term t)
Term
Yap_gmp_exp_big_big(Term t1, Term t2)
{
CACHE_REGS
CELL *pt1 = RepAppl(t1);
CELL *pt2 = RepAppl(t2);
Float dbl1, dbl2;
@ -1116,6 +1137,7 @@ Yap_gmq_rdiv_big_big(Term t1, Term t2)
Term
Yap_gmp_fdiv_int_big(Int i1, Term t2)
{
CACHE_REGS
MP_RAT new;
MP_RAT *b1, *b2;
MP_RAT bb1, bb2;
@ -1142,6 +1164,7 @@ Yap_gmp_fdiv_int_big(Int i1, Term t2)
Term
Yap_gmp_fdiv_big_int(Term t2, Int i1)
{
CACHE_REGS
MP_RAT new;
MP_RAT *b1, *b2;
MP_RAT bb1, bb2;
@ -1168,6 +1191,7 @@ Yap_gmp_fdiv_big_int(Term t2, Int i1)
Term
Yap_gmp_fdiv_big_big(Term t1, Term t2)
{
CACHE_REGS
CELL *pt1 = RepAppl(t1);
CELL *pt2 = RepAppl(t2);
MP_RAT new;
@ -1602,6 +1626,7 @@ Yap_gmp_float_integer_part(Term t)
Term
Yap_gmp_sign(Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
return MkIntegerTerm(mpz_sgn(Yap_BigIntOfTerm(t)));
@ -1613,6 +1638,7 @@ Yap_gmp_sign(Term t)
Term
Yap_gmp_lsb(Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
MP_INT *big = Yap_BigIntOfTerm(t);
@ -1629,6 +1655,7 @@ Yap_gmp_lsb(Term t)
Term
Yap_gmp_msb(Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
MP_INT *big = Yap_BigIntOfTerm(t);
@ -1645,6 +1672,7 @@ Yap_gmp_msb(Term t)
Term
Yap_gmp_popcount(Term t)
{
CACHE_REGS
CELL *pt = RepAppl(t);
if (pt[1] == BIG_INT) {
MP_INT *big = Yap_BigIntOfTerm(t);

View File

@ -1923,10 +1923,7 @@ suspend_indexing(ClauseDef *min, ClauseDef *max, PredEntry *ap, struct intermedi
} else {
Yap_IndexSpace_EXT += sz;
}
{
CACHE_REGS
Yap_inform_profiler_of_clause(ncode, (CODEADDR)ncode+sz, ap, GPROF_NEW_EXPAND_BLOCK);
}
Yap_inform_profiler_of_clause(ncode, (CODEADDR)ncode+sz, ap, GPROF_NEW_EXPAND_BLOCK);
/* create an expand_block */
ncode->opc = Yap_opcode(_expand_clauses);
ncode->u.sssllp.p = ap;

View File

@ -882,6 +882,8 @@ Yap_InitCPredBack(char *Name, unsigned long int Arity,
static void
InitStdPreds(void)
{
void initIO(void);
Yap_InitCPreds();
Yap_InitBackCPreds();
BACKUP_MACHINE_REGS();
@ -1288,17 +1290,21 @@ Yap_InitWorkspace(UInt Heap, UInt Stack, UInt Trail, UInt Atts, UInt max_table_s
if (Heap < MinHeapSpace)
Heap = MinHeapSpace;
Heap = AdjustPageSize(Heap * K);
Heap /= (K);
/* sanity checking for data areas */
if (Trail < MinTrailSpace)
Trail = MinTrailSpace;
Trail = AdjustPageSize(Trail * K);
Trail /= (K);
if (Stack < MinStackSpace)
Stack = MinStackSpace;
Stack = AdjustPageSize(Stack * K);
Stack /= (K);
if (!Atts)
Atts = 2048*sizeof(CELL);
else
Atts = AdjustPageSize(Atts * K);
Atts /= (K);
#if defined(YAPOR) || defined(THREADS)
worker_id = 0;
#endif /* YAPOR || THREADS */

View File

@ -261,7 +261,7 @@ open_file(char *my_file, int flag)
#endif /* O_BINARY */
#endif /* M_WILLIAMS */
{
splfild = 0; /* We do not have an open file */
splfild = -1; /* We do not have an open file */
return -1;
}
#ifdef undf0
@ -1466,7 +1466,7 @@ OpenRestore(char *inpf, char *YapLibDir, CELL *Astate, CELL *ATrail, CELL *AStac
} else {
strncat(LOCAL_FileNameBuf, inpf, YAP_FILENAME_MAX-1);
}
if (inpf != NULL && (splfild = open_file(inpf, O_RDONLY)) > 0) {
if (inpf != NULL && !((splfild = open_file(inpf, O_RDONLY)) < 0)) {
if ((mode = try_open(inpf,Astate,ATrail,AStack,AHeap,save_buffer,streamp)) != FAIL_RESTORE) {
return mode;
}
@ -1499,7 +1499,7 @@ OpenRestore(char *inpf, char *YapLibDir, CELL *Astate, CELL *ATrail, CELL *AStac
#endif
if (YAP_LIBDIR != NULL) {
cat_file_name(LOCAL_FileNameBuf, YAP_LIBDIR, inpf, YAP_FILENAME_MAX);
if ((splfild = open_file(LOCAL_FileNameBuf, O_RDONLY)) > 0) {
if (!((splfild = open_file(LOCAL_FileNameBuf, O_RDONLY)) < 0)) {
if ((mode = try_open(LOCAL_FileNameBuf,Astate,ATrail,AStack,AHeap,save_buffer,streamp)) != FAIL_RESTORE) {
return mode;
}
@ -1508,7 +1508,7 @@ OpenRestore(char *inpf, char *YapLibDir, CELL *Astate, CELL *ATrail, CELL *AStac
}
#if _MSC_VER || defined(__MINGW32__)
if ((inpf = Yap_RegistryGetString("startup"))) {
if ((splfild = open_file(inpf, O_RDONLY)) > 0) {
if (!((splfild = open_file(inpf, O_RDONLY)) < 0)) {
if ((mode = try_open(inpf,Astate,ATrail,AStack,AHeap,save_buffer,streamp)) != FAIL_RESTORE) {
return mode;
}

View File

@ -230,6 +230,7 @@ extern double atof(const char *);
static Term
float_send(char *s, int sign)
{
CACHE_REGS
Float f = (Float)atof(s);
#if HAVE_FINITE
if (yap_flags[LANGUAGE_MODE_FLAG] == 1) { /* iso */
@ -512,6 +513,7 @@ num_send_error_message(char s[])
static Term
get_num(int *chp, int *chbuffp, IOSTREAM *inp_stream, char *s, UInt max_size, int sign)
{
CACHE_REGS
char *sp = s;
int ch = *chp;
Int val = 0L, base = ch - '0';

View File

@ -74,9 +74,6 @@ p_creep( USES_REGS1 )
static Int
p_stop_creeping( USES_REGS1 )
{
Atom at;
PredEntry *pred;
LOCK(LOCAL_SignalLock);
LOCAL_ActiveSignals &= ~(YAP_CREEP_SIGNAL|YAP_DELAY_CREEP_SIGNAL);
if (!LOCAL_ActiveSignals) {

View File

@ -292,7 +292,7 @@ STD_PROTO(static Int p_values, ( USES_REGS1 ));
STD_PROTO(static CODEADDR *FindAtom, (CODEADDR, int *));
#endif /* undefined */
STD_PROTO(static Int p_opdec, ( USES_REGS1 ));
STD_PROTO(static Term get_num, (char *));
STD_PROTO(static Term get_num, (char * USES_REGS));
STD_PROTO(static Int p_name, ( USES_REGS1 ));
STD_PROTO(static Int p_atom_chars, ( USES_REGS1 ));
STD_PROTO(static Int p_atom_codes, ( USES_REGS1 ));
@ -537,7 +537,7 @@ strtod(s, pe)
#endif
static Term
get_num(char *t)
get_num(char *t USES_REGS)
{
Term out;
IOSTREAM *smem = Sopenmem(&t, NULL, "r");
@ -832,7 +832,7 @@ p_name( USES_REGS1 )
return(FALSE);
}
if (IsAtomTerm(t) && AtomOfTerm(t) == AtomNil) {
if ((NewT = get_num(String)) == TermNil) {
if ((NewT = get_num(String PASS_REGS)) == TermNil) {
Atom at;
while ((at = Yap_LookupAtom(String)) == NIL) {
if (!Yap_growheap(FALSE, 0, NULL)) {
@ -1375,7 +1375,7 @@ p_atom_concat( USES_REGS1 )
if (wide_mode) {
wchar_t *cptr = (wchar_t *)(((AtomEntry *)Yap_PreAllocCodeSpace())->StrOfAE), *cpt0;
wchar_t *top = (wchar_t *)AuxSp;
unsigned char *atom_str;
unsigned char *atom_str = NULL;
Atom ahead;
UInt sz;
@ -2227,7 +2227,7 @@ p_number_chars( USES_REGS1 )
}
}
*s++ = '\0';
if ((NewT = get_num(String)) == TermNil) {
if ((NewT = get_num(String PASS_REGS)) == TermNil) {
Yap_Error(SYNTAX_ERROR, gen_syntax_error(Yap_LookupAtom(String), "number_chars"), "while scanning %s", String);
return (FALSE);
}
@ -2294,7 +2294,7 @@ p_number_atom( USES_REGS1 )
return(FALSE);
}
s = RepAtom(AtomOfTerm(t))->StrOfAE;
if ((NewT = get_num(s)) == TermNil) {
if ((NewT = get_num(s PASS_REGS)) == TermNil) {
Yap_Error(SYNTAX_ERROR, gen_syntax_error(Yap_LookupAtom(String), "number_atom"), "while scanning %s", s);
return (FALSE);
}
@ -2387,7 +2387,7 @@ p_number_codes( USES_REGS1 )
}
}
*s++ = '\0';
if ((NewT = get_num(String)) == TermNil) {
if ((NewT = get_num(String PASS_REGS)) == TermNil) {
Yap_Error(SYNTAX_ERROR, gen_syntax_error(Yap_LookupAtom(String), "number_codes"), "while scanning %s", String);
return (FALSE);
}
@ -2452,7 +2452,7 @@ p_atom_number( USES_REGS1 )
return FALSE;
}
s = RepAtom(at)->StrOfAE; /* alloc temp space on Trail */
if ((NewT = get_num(s)) == TermNil) {
if ((NewT = get_num(s PASS_REGS)) == TermNil) {
Yap_Error(SYNTAX_ERROR, gen_syntax_error(at, "atom_number"), "while scanning %s", s);
return FALSE;
}

View File

@ -1520,9 +1520,18 @@ ExportTerm(Term inp, char * buf, size_t len, UInt arity, int newattvs USES_REGS)
Term t = Deref(inp);
tr_fr_ptr TR0 = TR;
size_t res = 0;
CELL *Hi;
CELL *Hi = H;
do {
if (IsVarTerm(t) || IsIntTerm(t)) {
return export_term_to_buffer(t, buf, buf+ 3*sizeof(CELL), &inp, &inp, len);
}
if (IsAtomTerm(t)) {
Atom at = AtomOfTerm(t);
char *b = buf+3*sizeof(CELL);
export_atom(at, &b, b, len-3*sizeof(CELL));
return export_term_to_buffer(t, buf, b, &inp, &inp, len);
}
if ((Int)res < 0) {
H = Hi;
TR = TR0;
@ -1634,16 +1643,14 @@ Yap_ImportTerm(char * buf) {
CELL *bc = (CELL *)buf;
size_t sz = bc[1];
Term tinp, tret;
tinp = bc[2];
if (IsVarTerm(tinp))
return MkVarTerm();
if (IsAtomOrIntTerm(tinp)) {
if (IsAtomTerm(tinp)) {
char *pt = (char *)AdjustSize(bc+3, buf);
return MkAtomTerm(Yap_LookupAtom(pt));
} else
return tinp;
else if (IsIntTerm(tinp))
return tinp;
else if (IsAtomTerm(tinp)) {
tret = MkAtomTerm(AddAtom(NULL,(char *)(bc+3)));
return tret;
}
if (H + sz > ASP)
return (Term)0;
@ -1654,7 +1661,7 @@ Yap_ImportTerm(char * buf) {
} else {
tret = AbsPair(H);
import_pair(H, (char *)H, buf, H);
}
}
H += sz;
return tret;
}
@ -4921,7 +4928,7 @@ numbervar_singleton(USES_REGS1)
}
static void
renumbervar(Term t, Int id)
renumbervar(Term t, Int id USES_REGS)
{
Term *ts = RepAppl(t);
ts[1] = MkIntegerTerm(id);
@ -4975,7 +4982,7 @@ static Int numbervars_in_complex_term(register CELL *pt0, register CELL *pt0_end
continue;
}
if (singles && ap2 >= InitialH && ap2 < H) {
renumbervar(d0, numbv++);
renumbervar(d0, numbv++ PASS_REGS);
continue;
}
/* store the terms to visit */

View File

@ -60,13 +60,14 @@ blob_type;
#include "inline-only.h"
INLINE_ONLY inline EXTERN int IsAttVar (CELL *pt);
#define IsAttVar(pt) __IsAttVar((pt) PASS_REGS)
INLINE_ONLY inline EXTERN int __IsAttVar (CELL *pt USES_REGS);
INLINE_ONLY inline EXTERN int
IsAttVar (CELL *pt)
__IsAttVar (CELL *pt USES_REGS)
{
#ifdef YAP_H
CACHE_REGS
return (pt)[-1] == (CELL)attvar_e
&& pt < H;
#else
@ -182,13 +183,13 @@ INLINE_ONLY inline EXTERN Float CpFloatUnaligned(CELL *ptr);
#if SIZEOF_DOUBLE == SIZEOF_LONG_INT
#define MkFloatTerm(fl) __MkFloatTerm((fl) PASS_REGS)
INLINE_ONLY inline EXTERN Term MkFloatTerm (Float);
INLINE_ONLY inline EXTERN Term __MkFloatTerm (Float USES_REGS);
INLINE_ONLY inline EXTERN Term
MkFloatTerm (Float dbl)
__MkFloatTerm (Float dbl USES_REGS)
{
CACHE_REGS
return (Term) ((H[0] = (CELL) FunctorDouble, *(Float *) (H + 1) =
dbl, H[2] = EndSpecials, H +=
3, AbsAppl (H - 3)));
@ -303,12 +304,13 @@ IsFloatTerm (Term t)
/* extern Functor FunctorLongInt; */
INLINE_ONLY inline EXTERN Term MkLongIntTerm (Int);
#define MkLongIntTerm(i) __MkLongIntTerm((i) PASS_REGS)
INLINE_ONLY inline EXTERN Term __MkLongIntTerm (Int USES_REGS);
INLINE_ONLY inline EXTERN Term
MkLongIntTerm (Int i)
__MkLongIntTerm (Int i USES_REGS)
{
CACHE_REGS
H[0] = (CELL) FunctorLongInt;
H[1] = (CELL) (i);
H[2] = EndSpecials;
@ -546,11 +548,12 @@ IsAttachFunc (Functor f)
#define IsAttachedTerm(t) __IsAttachedTerm(t PASS_REGS)
INLINE_ONLY inline EXTERN Int IsAttachedTerm (Term);
INLINE_ONLY inline EXTERN Int __IsAttachedTerm (Term USES_REGS);
INLINE_ONLY inline EXTERN Int
IsAttachedTerm (Term t)
__IsAttachedTerm (Term t USES_REGS)
{
return (Int) ((IsVarTerm (t) && IsAttVar(VarOfTerm(t))));
}
@ -563,17 +566,16 @@ GlobalIsAttachedTerm (Term t)
return (Int) ((IsVarTerm (t) && GlobalIsAttVar(VarOfTerm(t))));
}
INLINE_ONLY inline EXTERN Int SafeIsAttachedTerm (Term);
#define SafeIsAttachedTerm(t) __SafeIsAttachedTerm((t) PASS_REGS)
INLINE_ONLY inline EXTERN Int __SafeIsAttachedTerm (Term USES_REGS);
INLINE_ONLY inline EXTERN Int
SafeIsAttachedTerm (Term t)
__SafeIsAttachedTerm (Term t USES_REGS)
{
return (Int) (IsVarTerm (t) && IsAttVar(VarOfTerm(t)));
}
INLINE_ONLY inline EXTERN exts ExtFromCell (CELL *);
INLINE_ONLY inline EXTERN exts

View File

@ -364,10 +364,13 @@ MkPairTerm__ (Term head, Term tail USES_REGS)
#define IsAccessFunc(func) ((func) == FunctorAccess)
#ifdef YAP_H
INLINE_ONLY inline EXTERN Term MkIntegerTerm (Int);
#define MkIntegerTerm(i) __MkIntegerTerm(i PASS_REGS)
INLINE_ONLY inline EXTERN Term __MkIntegerTerm (Int USES_REGS);
INLINE_ONLY inline EXTERN Term
MkIntegerTerm (Int n)
__MkIntegerTerm (Int n USES_REGS)
{
return (Term) (IntInBnd (n) ? MkIntTerm (n) : MkLongIntTerm (n));
}

View File

@ -341,7 +341,7 @@ void STD_PROTO(Yap_InitSavePreds,(void));
/* signals.c */
void STD_PROTO(Yap_signal,(yap_signals));
void STD_PROTO(Yap_undo_signal,(yap_signals));
void STD_PROTO(Yap_InitSignalPreds,(void));
void STD_PROTO(Yap_InitSignalCPreds,(void));
/* sort.c */
void STD_PROTO(Yap_InitSortPreds,(void));

View File

@ -26,7 +26,7 @@ add_overflow(Int x, Int i, Int j)
}
inline static Term
add_int(Int i, Int j)
add_int(Int i, Int j USES_REGS)
{
Int x = i+j;
#if USE_GMP
@ -51,7 +51,7 @@ sub_overflow(Int x, Int i, Int j)
}
inline static Term
sub_int(Int i, Int j)
sub_int(Int i, Int j USES_REGS)
{
Int x = i-j;
#if USE_GMP
@ -105,7 +105,7 @@ mul_overflow(Int z, Int i1, Int i2)
#endif
inline static Term
times_int(Int i1, Int i2) {
times_int(Int i1, Int i2 USES_REGS) {
#ifdef USE_GMP
Int z;
DO_MULTI();
@ -151,7 +151,7 @@ clrsb(Int i)
#endif
inline static Term
do_sll(Int i, Int j) /* j > 0 */
do_sll(Int i, Int j USES_REGS) /* j > 0 */
{
#ifdef USE_GMP
if (
@ -174,13 +174,13 @@ do_sll(Int i, Int j) /* j > 0 */
static inline Term
p_plus(Term t1, Term t2) {
p_plus(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case long_int_e:
switch (ETypeOfTerm(t2)) {
case long_int_e:
/* two integers */
return add_int(IntegerOfTerm(t1),IntegerOfTerm(t2));
return add_int(IntegerOfTerm(t1),IntegerOfTerm(t2) PASS_REGS);
case double_e:
{
/* integer, double */
@ -230,13 +230,13 @@ p_plus(Term t1, Term t2) {
}
static Term
p_minus(Term t1, Term t2) {
p_minus(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case long_int_e:
switch (ETypeOfTerm(t2)) {
case long_int_e:
/* two integers */
return sub_int(IntegerOfTerm(t1), IntegerOfTerm(t2));
return sub_int(IntegerOfTerm(t1), IntegerOfTerm(t2) PASS_REGS);
case double_e:
{
/* integer, double */
@ -290,13 +290,13 @@ p_minus(Term t1, Term t2) {
static Term
p_times(Term t1, Term t2) {
p_times(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case long_int_e:
switch (ETypeOfTerm(t2)) {
case long_int_e:
/* two integers */
return(times_int(IntegerOfTerm(t1),IntegerOfTerm(t2)));
return(times_int(IntegerOfTerm(t1),IntegerOfTerm(t2) PASS_REGS));
case double_e:
{
/* integer, double */
@ -348,7 +348,7 @@ p_times(Term t1, Term t2) {
}
static Term
p_div(Term t1, Term t2) {
p_div(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case long_int_e:
switch (ETypeOfTerm(t2)) {
@ -405,7 +405,7 @@ p_div(Term t1, Term t2) {
}
static Term
p_and(Term t1, Term t2) {
p_and(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case long_int_e:
switch (ETypeOfTerm(t2)) {
@ -446,7 +446,7 @@ p_and(Term t1, Term t2) {
}
static Term
p_or(Term t1, Term t2) {
p_or(Term t1, Term t2 USES_REGS) {
switch(ETypeOfTerm(t1)) {
case long_int_e:
switch (ETypeOfTerm(t2)) {
@ -487,7 +487,7 @@ p_or(Term t1, Term t2) {
}
static Term
p_sll(Term t1, Term t2) {
p_sll(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case long_int_e:
switch (ETypeOfTerm(t2)) {
@ -501,7 +501,7 @@ p_sll(Term t1, Term t2) {
}
RINT(SLR(IntegerOfTerm(t1), -i2));
}
return do_sll(IntegerOfTerm(t1),i2);
return do_sll(IntegerOfTerm(t1),i2 PASS_REGS);
}
case double_e:
return Yap_ArithError(TYPE_ERROR_INTEGER, t2, "<</2");
@ -535,7 +535,7 @@ p_sll(Term t1, Term t2) {
}
static Term
p_slr(Term t1, Term t2) {
p_slr(Term t1, Term t2 USES_REGS) {
switch (ETypeOfTerm(t1)) {
case long_int_e:
switch (ETypeOfTerm(t2)) {
@ -547,7 +547,7 @@ p_slr(Term t1, Term t2) {
if (i2 == Int_MIN) {
return Yap_ArithError(RESOURCE_ERROR_HUGE_INT, t2, ">>/2");
}
return do_sll(IntegerOfTerm(t1), -i2);
return do_sll(IntegerOfTerm(t1), -i2 PASS_REGS);
}
RINT(SLR(IntegerOfTerm(t1), i2));
}

View File

@ -170,25 +170,43 @@ typedef struct index_t {
UInt ntrys;
UInt nentries;
UInt hsize;
CELL **key;
BITS32 *key;
CELL *cls;
CELL *links;
BITS32 *links;
size_t size;
yamop *code;
} Index_t;
INLINE_ONLY EXTERN inline UInt EXO_ADDRESS_TO_OFFSET(struct index_t *it, CELL *ptr);
INLINE_ONLY EXTERN inline BITS32 EXO_ADDRESS_TO_OFFSET(struct index_t *it, CELL *ptr);
INLINE_ONLY EXTERN inline UInt
INLINE_ONLY EXTERN inline BITS32
EXO_ADDRESS_TO_OFFSET(struct index_t *it, CELL* ptr)
{
return ptr-it->links;
return 1+(ptr-it->cls);
}
INLINE_ONLY EXTERN inline CELL *EXO_OFFSET_TO_ADDRESS(struct index_t *it, UInt off);
INLINE_ONLY EXTERN inline CELL *
EXO_OFFSET_TO_ADDRESS(struct index_t *it, UInt off)
EXO_OFFSET_TO_ADDRESS(struct index_t *it, BITS32 off)
{
if (off == 0L)
return NULL;
return (it->cls-1)+off;
}
INLINE_ONLY EXTERN inline BITS32 ADDRESS_TO_LINK(struct index_t *it, CELL *ptr);
INLINE_ONLY EXTERN inline BITS32
ADDRESS_TO_LINK(struct index_t *it, CELL* ptr)
{
return ptr-it->links;
}
INLINE_ONLY EXTERN inline CELL *LINK_TO_ADDRESS(struct index_t *it, BITS32 off);
INLINE_ONLY EXTERN inline CELL *
LINK_TO_ADDRESS(struct index_t *it, BITS32 off)
{
return it->links+off;
}
@ -323,8 +341,10 @@ same_lu_block(yamop **paddr, yamop *p)
}
#endif
#define Yap_MkStaticRefTerm(cp) __Yap_MkStaticRefTerm((cp) PASS_REGS)
static inline Term
Yap_MkStaticRefTerm(StaticClause *cp)
__Yap_MkStaticRefTerm(StaticClause *cp USES_REGS)
{
Term t[1];
t[0] = MkIntegerTerm((Int)cp);
@ -337,8 +357,10 @@ Yap_ClauseFromTerm(Term t)
return (StaticClause *)IntegerOfTerm(ArgOfTerm(1,t));
}
#define Yap_MkMegaRefTerm(ap, ipc) __Yap_MkMegaRefTerm((ap), (ipc) PASS_REGS)
static inline Term
Yap_MkMegaRefTerm(PredEntry *ap,yamop *ipc)
__Yap_MkMegaRefTerm(PredEntry *ap,yamop *ipc USES_REGS)
{
Term t[2];
t[0] = MkIntegerTerm((Int)ap);

View File

@ -314,12 +314,16 @@ size_t STD_PROTO(Yap_gmp_to_size,(Term, int));
int STD_PROTO(Yap_term_to_existing_big,(Term, MP_INT *));
int STD_PROTO(Yap_term_to_existing_rat,(Term, MP_RAT *));
void Yap_gmp_set_bit(Int i, Term t);
#endif
INLINE_ONLY inline EXTERN Term Yap_Mk64IntegerTerm(YAP_LONG_LONG);
#define Yap_Mk64IntegerTerm(i) __Yap_Mk64IntegerTerm((i) PASS_REGS)
INLINE_ONLY inline EXTERN Term __Yap_Mk64IntegerTerm(YAP_LONG_LONG USES_REGS);
INLINE_ONLY inline EXTERN Term
Yap_Mk64IntegerTerm(YAP_LONG_LONG i)
__Yap_Mk64IntegerTerm(YAP_LONG_LONG i USES_REGS)
{
if (i <= Int_MAX && i >= Int_MIN) {
return MkIntegerTerm((Int)i);

View File

@ -10,6 +10,7 @@ INIT_SEQ_STRING(size_t n)
static inline Word
EXTEND_SEQ_CODES(Word ptr, int c) {
CACHE_REGS
ptr[0] = MkIntegerTerm(c);
ptr[1] = AbsPair(ptr+2);

View File

@ -13,7 +13,7 @@ static void readswap8(double *buf);
static byte get_hostbyteorder(void);
static byte get_inbyteorder(void);
static uint32 get_wkbType(void);
static Term get_point(char *functor);
static Term get_point(char *functor USES_REGS);
static Term get_linestring(char *functor);
static Term get_polygon(char *functor);
static Term get_geometry(uint32 type);
@ -150,7 +150,7 @@ static void readswap8(double *buf) {
cursor += 8;
}
static Term get_point(char *func){
static Term get_point(char *func USES_REGS){
Term args[2];
Functor functor;
double d;
@ -188,7 +188,7 @@ static Term get_linestring(char *func){
c_list = (Term *) calloc(sizeof(Term),n);
for ( i = 0; i < n; i++) {
c_list[i] = get_point(NULL);
c_list[i] = get_point(NULL PASS_REGS);
}
list = MkAtomTerm(Yap_LookupAtom("[]"));
@ -241,15 +241,14 @@ static Term get_geometry(uint32 type){
switch(type) {
case WKBPOINT:
return get_point("point");
return get_point("point" PASS_REGS);
case WKBLINESTRING:
return get_linestring("linestring");
case WKBPOLYGON:
return get_polygon("polygon");
case WKBMULTIPOINT:
{
byte b;
uint32 n, u;
uint32 n;
int i;
Functor functor;
Term *c_list;
@ -264,10 +263,10 @@ static Term get_geometry(uint32 type){
for ( i = 0; i < n; i++ ) {
/* read (and ignore) the byteorder and type */
b = get_inbyteorder();
u = get_wkbType();
get_inbyteorder();
get_wkbType();
c_list[i] = get_point(NULL);
c_list[i] = get_point(NULL PASS_REGS);
}
list = MkAtomTerm(Yap_LookupAtom("[]"));
@ -282,8 +281,7 @@ static Term get_geometry(uint32 type){
}
case WKBMULTILINESTRING:
{
byte b;
uint32 n, u;
uint32 n;
int i;
Functor functor;
Term *c_list;
@ -298,8 +296,8 @@ static Term get_geometry(uint32 type){
for ( i = 0; i < n; i++ ) {
/* read (and ignore) the byteorder and type */
b = get_inbyteorder();
u = get_wkbType();
get_inbyteorder();
get_wkbType();
c_list[i] = get_linestring(NULL);
}
@ -316,8 +314,7 @@ static Term get_geometry(uint32 type){
}
case WKBMULTIPOLYGON:
{
byte b;
uint32 n, u;
uint32 n;
int i;
Functor functor;
Term *c_list;
@ -332,8 +329,8 @@ static Term get_geometry(uint32 type){
for ( i = 0; i < n; i++ ) {
/* read (and ignore) the byteorder and type */
b = get_inbyteorder();
u = get_wkbType();
get_inbyteorder();
get_wkbType();
c_list[i] = get_polygon(NULL);
}
@ -350,7 +347,6 @@ static Term get_geometry(uint32 type){
}
case WKBGEOMETRYCOLLECTION:
{
byte b;
uint32 n;
int i;
Functor functor;
@ -365,7 +361,7 @@ static Term get_geometry(uint32 type){
for ( i = 0; i < n; i++ ) {
b = get_inbyteorder();
get_inbyteorder();
c_list[i] = get_geometry(get_wkbType());
}

View File

@ -22,39 +22,39 @@
#include "opt.mavar.h"
#ifdef THREADS
static inline void **get_insert_thread_bucket(void **, lockvar *);
static inline void **get_thread_bucket(void **);
static inline void **__get_insert_thread_bucket(void **, lockvar * USES_REGS);
static inline void **__get_thread_bucket(void ** USES_REGS);
static inline void abolish_thread_buckets(void **);
#endif /* THREADS */
static inline sg_node_ptr get_insert_subgoal_trie(tab_ent_ptr USES_REGS);
static inline sg_node_ptr get_subgoal_trie(tab_ent_ptr);
static inline sg_node_ptr __get_subgoal_trie(tab_ent_ptr USES_REGS);
static inline sg_node_ptr get_subgoal_trie_for_abolish(tab_ent_ptr USES_REGS);
static inline sg_fr_ptr *get_insert_subgoal_frame_addr(sg_node_ptr USES_REGS);
static inline sg_fr_ptr get_subgoal_frame(sg_node_ptr);
static inline sg_fr_ptr get_subgoal_frame_for_abolish(sg_node_ptr USES_REGS);
#ifdef THREADS_FULL_SHARING
static inline void SgFr_batched_cached_answers_check_insert(sg_fr_ptr, ans_node_ptr);
static inline void __SgFr_batched_cached_answers_check_insert(sg_fr_ptr, ans_node_ptr USES_REGS);
static inline int SgFr_batched_cached_answers_check_remove(sg_fr_ptr, ans_node_ptr);
#endif /* THREADS_FULL_SHARING */
#ifdef THREADS_CONSUMER_SHARING
static inline void add_to_tdv(int, int);
static inline void check_for_deadlock(sg_fr_ptr);
static inline sg_fr_ptr deadlock_detection(sg_fr_ptr);
static inline void __add_to_tdv(int, int USES_REGS);
static inline void __check_for_deadlock(sg_fr_ptr USES_REGS);
static inline sg_fr_ptr __deadlock_detection(sg_fr_ptr USES_REGS);
#endif /* THREADS_CONSUMER_SHARING */
static inline Int freeze_current_cp(void);
static inline void wake_frozen_cp(Int);
static inline void abolish_frozen_cps_until(Int);
static inline void abolish_frozen_cps_all(void);
static inline void adjust_freeze_registers(void);
static inline void mark_as_completed(sg_fr_ptr);
static inline void unbind_variables(tr_fr_ptr, tr_fr_ptr);
static inline void rebind_variables(tr_fr_ptr, tr_fr_ptr);
static inline void restore_bindings(tr_fr_ptr, tr_fr_ptr);
static inline CELL *expand_auxiliary_stack(CELL *);
static inline void abolish_incomplete_subgoals(choiceptr);
static inline Int __freeze_current_cp( USES_REGS1 );
static inline void __wake_frozen_cp(Int USES_REGS);
static inline void __abolish_frozen_cps_until(Int USES_REGS);
static inline void __abolish_frozen_cps_all( USES_REGS1 );
static inline void __adjust_freeze_registers( USES_REGS1 );
static inline void __mark_as_completed(sg_fr_ptr USES_REGS);
static inline void __unbind_variables(tr_fr_ptr, tr_fr_ptr USES_REGS);
static inline void __rebind_variables(tr_fr_ptr, tr_fr_ptr USES_REGS);
static inline void __restore_bindings(tr_fr_ptr, tr_fr_ptr USES_REGS);
static inline CELL *__expand_auxiliary_stack(CELL * USES_REGS);
static inline void __abolish_incomplete_subgoals(choiceptr USES_REGS);
#ifdef YAPOR
static inline void pruning_over_tabling_data_structures(void);
static inline void collect_suspension_frames(or_fr_ptr);
static inline void __collect_suspension_frames(or_fr_ptr USES_REGS);
#ifdef TIMESTAMP_CHECK
static inline susp_fr_ptr suspension_frame_to_resume(or_fr_ptr, long);
#else
@ -658,8 +658,9 @@ static inline tg_sol_fr_ptr CUT_prune_tg_solution_frames(tg_sol_fr_ptr, int);
******************************/
#ifdef THREADS
static inline void **get_insert_thread_bucket(void **buckets, lockvar *buckets_lock) {
CACHE_REGS
#define get_insert_thread_bucket(b, bl) __get_insert_thread_bucket((b), (bl) PASS_REGS)
static inline void **__get_insert_thread_bucket(void **buckets, lockvar *buckets_lock USES_REGS) {
/* direct bucket */
if (worker_id < THREADS_DIRECT_BUCKETS)
@ -678,9 +679,9 @@ static inline void **get_insert_thread_bucket(void **buckets, lockvar *buckets_l
return *buckets + (worker_id - THREADS_DIRECT_BUCKETS) % THREADS_DIRECT_BUCKETS;
}
#define get_thread_bucket(b) __get_thread_bucket((b) PASS_REGS)
static inline void **get_thread_bucket(void **buckets) {
CACHE_REGS
static inline void **__get_thread_bucket(void **buckets USES_REGS) {
/* direct bucket */
if (worker_id < THREADS_DIRECT_BUCKETS)
@ -729,8 +730,9 @@ static inline sg_node_ptr get_insert_subgoal_trie(tab_ent_ptr tab_ent USES_REGS)
#endif /* THREADS_NO_SHARING */
}
#define get_subgoal_trie(te) __get_subgoal_trie((te) PASS_REGS)
static inline sg_node_ptr get_subgoal_trie(tab_ent_ptr tab_ent) {
static inline sg_node_ptr __get_subgoal_trie(tab_ent_ptr tab_ent USES_REGS) {
#ifdef THREADS_NO_SHARING
sg_node_ptr *sg_node_addr = (sg_node_ptr *) get_thread_bucket((void **) &TabEnt_subgoal_trie(tab_ent));
return *sg_node_addr;
@ -825,8 +827,8 @@ static inline sg_fr_ptr get_subgoal_frame_for_abolish(sg_node_ptr sg_node USES_R
#ifdef THREADS_FULL_SHARING
static inline void SgFr_batched_cached_answers_check_insert(sg_fr_ptr sg_fr, ans_node_ptr ans_node) {
CACHE_REGS
#define SgFr_batched_cached_answers_check_insert(s, a) __SgFr_batched_cached_answers_check_insert((s), (a) PASS_REGS)
static inline void SgFr_batched_cached_answers_check_insert(sg_fr_ptr sg_fr, ans_node_ptr ans_node USES_REGS) {
if (SgFr_batched_last_answer(sg_fr) == NULL)
SgFr_batched_last_answer(sg_fr) = SgFr_first_answer(sg_fr);
@ -854,8 +856,9 @@ static inline void SgFr_batched_cached_answers_check_insert(sg_fr_ptr sg_fr, ans
return;
}
static inline int SgFr_batched_cached_answers_check_remove(sg_fr_ptr sg_fr, ans_node_ptr ans_node) {
CACHE_REGS
#define SgFr_batched_cached_answers_check_remove(s, a) __SgFr_batched_cached_answers_check_remove((s), (a) PASS_REGS)
static inline int __SgFr_batched_cached_answers_check_remove(sg_fr_ptr sg_fr, ans_node_ptr ans_node USES_REgS) {
struct answer_ref_node *local_uncons_ans;
local_uncons_ans = SgFr_batched_cached_answers(sg_fr) ;
@ -884,10 +887,10 @@ static inline int SgFr_batched_cached_answers_check_remove(sg_fr_ptr sg_fr, ans_
#ifdef THREADS_CONSUMER_SHARING
static inline void add_to_tdv(int wid, int wid_dep) {
#ifdef OUTPUT_THREADS_TABLING
CACHE_REGS
#endif /* OUTPUT_THREADS_TABLING */
#define add_to_tdv(w, wd) __add_to_tdv((w), (wd) PASS_REGS)
static inline void __add_to_tdv(int wid, int wid_dep USES_REGS) {
// thread wid next of thread wid_dep
/* check before insert */
int c_wid = ThDepFr_next(GLOBAL_th_dep_fr(wid));
@ -927,9 +930,9 @@ static inline void add_to_tdv(int wid, int wid_dep) {
return;
}
#define check_for_deadlock(s) __check_for_deadlock((s) PASS_REGS)
static inline void check_for_deadlock(sg_fr_ptr sg_fr) {
CACHE_REGS
static inline void __check_for_deadlock(sg_fr_ptr sg_fr USES_REGS) {
sg_fr_ptr local_sg_fr = deadlock_detection(sg_fr);
if (local_sg_fr){
@ -942,9 +945,9 @@ static inline void check_for_deadlock(sg_fr_ptr sg_fr) {
return;
}
#define deadlock_detection(s) __deadlock_detection((s) PASS_REGS)
static inline sg_fr_ptr deadlock_detection(sg_fr_ptr sg_fr) {
CACHE_REGS
static inline sg_fr_ptr __deadlock_detection(sg_fr_ptr sg_fr USES_REGS) {
sg_fr_ptr remote_sg_fr = REMOTE_top_sg_fr(SgFr_gen_worker(sg_fr));
while( SgFr_sg_ent(remote_sg_fr) != SgFr_sg_ent(sg_fr)){
@ -977,9 +980,9 @@ static inline sg_fr_ptr deadlock_detection(sg_fr_ptr sg_fr) {
}
#endif /* THREADS_CONSUMER_SHARING */
#define freeze_current_cp() __freeze_current_cp( PASS_REGS1 )
static inline Int freeze_current_cp(void) {
CACHE_REGS
static inline Int __freeze_current_cp(USES_REGS1) {
choiceptr freeze_cp = B;
B_FZ = freeze_cp;
@ -991,8 +994,11 @@ static inline Int freeze_current_cp(void) {
}
static inline void wake_frozen_cp(Int frozen_offset) {
CACHE_REGS
#define wake_frozen_cp(f) __wake_frozen_cp((f) PASS_REGS)
#define restore_bindings(u, r) __restore_bindings((u), (r) PASS_REGS)
static inline void __wake_frozen_cp(Int frozen_offset USES_REGS) {
choiceptr frozen_cp = (choiceptr)(LOCAL_LocalBase - frozen_offset);
restore_bindings(TR, frozen_cp->cp_tr);
@ -1003,8 +1009,9 @@ static inline void wake_frozen_cp(Int frozen_offset) {
}
static inline void abolish_frozen_cps_until(Int frozen_offset) {
CACHE_REGS
#define abolish_frozen_cps_until(f) __abolish_frozen_cps_until((f) PASS_REGS )
static inline void __abolish_frozen_cps_until(Int frozen_offset USES_REGS) {
choiceptr frozen_cp = (choiceptr)(LOCAL_LocalBase - frozen_offset);
B_FZ = frozen_cp;
@ -1013,28 +1020,28 @@ static inline void abolish_frozen_cps_until(Int frozen_offset) {
return;
}
#define abolish_frozen_cps_all() __abolish_frozen_cps_all( PASS_REGS1 )
static inline void abolish_frozen_cps_all(void) {
CACHE_REGS
static inline void __abolish_frozen_cps_all( USES_REGS1 ) {
B_FZ = (choiceptr) LOCAL_LocalBase;
H_FZ = (CELL *) LOCAL_GlobalBase;
TR_FZ = (tr_fr_ptr) LOCAL_TrailBase;
return;
}
#define adjust_freeze_registers() __adjust_freeze_registers( PASS_REGS1 )
static inline void adjust_freeze_registers(void) {
CACHE_REGS
static inline void __adjust_freeze_registers( USES_REGS1 ) {
B_FZ = DepFr_cons_cp(LOCAL_top_dep_fr);
H_FZ = B_FZ->cp_h;
TR_FZ = B_FZ->cp_tr;
return;
}
#define mark_as_completed(sg) __mark_as_completed((sg) PASS_REGS )
static inline void mark_as_completed(sg_fr_ptr sg_fr) {
static inline void __mark_as_completed(sg_fr_ptr sg_fr USES_REGS) {
#if defined(MODE_DIRECTED_TABLING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING)
CACHE_REGS
#endif /* MODE_DIRECTED_TABLING && !THREADS_FULL_SHARING && !THREADS_CONSUMER_SHARING */
LOCK_SG_FR(sg_fr);
@ -1079,9 +1086,9 @@ static inline void mark_as_completed(sg_fr_ptr sg_fr) {
return;
}
#define unbind_variables(u, e) __unbind_variables((u), (e) PASS_REGS)
static inline void unbind_variables(tr_fr_ptr unbind_tr, tr_fr_ptr end_tr) {
CACHE_REGS
static inline void __unbind_variables(tr_fr_ptr unbind_tr, tr_fr_ptr end_tr USES_REGS) {
TABLING_ERROR_CHECKING(unbind_variables, unbind_tr < end_tr);
/* unbind loop */
while (unbind_tr != end_tr) {
@ -1111,8 +1118,9 @@ static inline void unbind_variables(tr_fr_ptr unbind_tr, tr_fr_ptr end_tr) {
}
static inline void rebind_variables(tr_fr_ptr rebind_tr, tr_fr_ptr end_tr) {
CACHE_REGS
#define rebind_variables(u, e) __rebind_variables(u, e PASS_REGS)
static inline void __rebind_variables(tr_fr_ptr rebind_tr, tr_fr_ptr end_tr USES_REGS) {
TABLING_ERROR_CHECKING(rebind_variables, rebind_tr < end_tr);
/* rebind loop */
Yap_NEW_MAHASH((ma_h_inner_struct *)H PASS_REGS);
@ -1144,9 +1152,7 @@ static inline void rebind_variables(tr_fr_ptr rebind_tr, tr_fr_ptr end_tr) {
return;
}
static inline void restore_bindings(tr_fr_ptr unbind_tr, tr_fr_ptr rebind_tr) {
CACHE_REGS
static inline void __restore_bindings(tr_fr_ptr unbind_tr, tr_fr_ptr rebind_tr USES_REGS) {
CELL ref;
tr_fr_ptr end_tr;
@ -1218,9 +1224,9 @@ static inline void restore_bindings(tr_fr_ptr unbind_tr, tr_fr_ptr rebind_tr) {
return;
}
#define expand_auxiliary_stack(s) __expand_auxiliary_stack((s) PASS_REGS)
static inline CELL *expand_auxiliary_stack(CELL *stack) {
CACHE_REGS
static inline CELL *__expand_auxiliary_stack(CELL *stack USES_REGS) {
void *old_top = LOCAL_TrailTop;
INFORMATION_MESSAGE("Expanding trail in 64 Kbytes");
if (! Yap_growtrail(K64, TRUE)) { /* TRUE means 'contiguous_only' */
@ -1234,9 +1240,10 @@ static inline CELL *expand_auxiliary_stack(CELL *stack) {
}
}
#define abolish_incomplete_subgoals(p) __abolish_incomplete_subgoals((p) PASS_REGS)
static inline void abolish_incomplete_subgoals(choiceptr prune_cp) {
CACHE_REGS
static inline void __abolish_incomplete_subgoals(choiceptr prune_cp USES_REGS) {
#ifdef YAPOR
if (EQUAL_OR_YOUNGER_CP(GetOrFr_node(LOCAL_top_susp_or_fr), prune_cp))
@ -1389,8 +1396,9 @@ static inline void pruning_over_tabling_data_structures(void) {
}
static inline void collect_suspension_frames(or_fr_ptr or_fr) {
CACHE_REGS
#define collect_suspension_frames(o) __collect_suspension_frames((o) PASS_REGS)
static inline void __collect_suspension_frames(or_fr_ptr or_fr USES_REGS) {
int depth;
or_fr_ptr *susp_ptr;

10
configure vendored
View File

@ -1539,7 +1539,7 @@ Optional Packages:
--with-java=JAVA_HOME use Java instalation in JAVA_HOME
--with-readline=DIR use GNU Readline Library in DIR
--with-matlab=DIR use MATLAB package in DIR
--with-mpi=DIR use MPI library in DIR
--with-mpi=DIR use LAM/MPI library in DIR
--with-mpe=DIR use MPE library in DIR
--with-lam=DIR use LAM MPI library in DIR
--with-heap-space=space default heap size in Kbytes
@ -4860,16 +4860,16 @@ fi
# Check whether --with-mpi was given.
if test "${with_mpi+set}" = set; then :
withval=$with_mpi; if test "$withval" = yes; then
yap_cv_mpi=yes
yap_cv_lam=yes
elif test "$withval" = no; then
yap_cv_mpi=no
yap_cv_lam=no
else
yap_cv_mpi=$with_mpi
yap_cv_lam=$with_mpi
LDFLAGS="$LDFLAGS -L${yap_cv_mpi}/lib"
CPPFLAGS="$CPPFLAGS -I${yap_cv_mpi}/include"
fi
else
yap_cv_mpi=no
yap_cv_lam=no
fi

View File

@ -360,18 +360,18 @@ AC_ARG_WITH(matlab,
[yap_cv_matlab=no])
AC_ARG_WITH(mpi,
[ --with-mpi[=DIR] use MPI library in DIR],
[ --with-mpi[=DIR] use LAM/MPI library in DIR],
if test "$withval" = yes; then
dnl handle UBUNTU systems
yap_cv_mpi=yes
yap_cv_lam=yes
elif test "$withval" = no; then
yap_cv_mpi=no
yap_cv_lam=no
else
yap_cv_mpi=$with_mpi
yap_cv_lam=$with_mpi
LDFLAGS="$LDFLAGS -L${yap_cv_mpi}/lib"
CPPFLAGS="$CPPFLAGS -I${yap_cv_mpi}/include"
fi,
[yap_cv_mpi=no])
[yap_cv_lam=no])
AC_ARG_WITH(mpe,

View File

@ -12840,6 +12840,13 @@ vertices in @var{Vs} map to vertices in @var{NewVs}.
The path @var{Path} is a path starting at vertex @var{Vertex} in graph
@var{Graph}.
@item dgraph_path(+@var{Vertex}, +@var{Vertex1}, +@var{Graph}, ?@var{Path})
@findex dgraph_path/3
@snindex dgraph_path/3
@cnindex dgraph_path/3
The path @var{Path} is a path starting at vertex @var{Vertex} in graph
@var{Graph} and ending at path @var{Vertex2}.
@item dgraph_reachable(+@var{Vertex}, +@var{Graph}, ?@var{Edges})
@findex dgraph_path/3
@snindex dgraph_path/3

View File

@ -32,6 +32,7 @@
dgraph_min_paths/3,
dgraph_isomorphic/4,
dgraph_path/3,
dgraph_path/4,
dgraph_leaves/2,
dgraph_reachable/3
]).
@ -40,7 +41,8 @@
[rb_new/1 as dgraph_new]).
:- use_module(library(rbtrees),
[rb_empty/1,
[rb_new/1,
rb_empty/1,
rb_lookup/3,
rb_apply/4,
rb_insert/4,
@ -361,10 +363,21 @@ dgraph_min_paths(V1, Graph, Paths) :-
dgraph_to_wdgraph(Graph, WGraph),
wdgraph_min_paths(V1, WGraph, Paths).
dgraph_path(V, G, [V|P]) :-
rb_lookup(V, Children, G),
ord_del_element(Children, V, Ch),
do_path(Ch, G, [V], P).
dgraph_path(V1, V2, Graph, Path) :-
rb_new(E0),
rb_lookup(V1, Children, Graph),
dgraph_path_children(Children, V2, E0, Graph, Path).
dgraph_path_children([V1|_], V2, _E1, _Graph, []) :- V1 == V2.
dgraph_path_children([V1|_], V2, E1, Graph, [V1|Path]) :-
V2 \== V1,
\+ rb_lookup(V1, _, E0),
rb_insert(E0, V2, [], E1),
rb_lookup(V1, Children, Graph),
dgraph_path_children(Children, V2, E1, Graph, Path).
dgraph_path_children([_|Children], V2, E1, Graph, Path) :-
dgraph_path_children(Children, V2, E1, Graph, Path).
do_path([], _, _, []).
do_path([C|Children], G, SoFar, Path) :-
@ -378,6 +391,11 @@ do_children([V|_], G, SoFar, [V|Path]) :-
do_children([_|Children], G, SoFar, Path) :-
do_children(Children, G, SoFar, Path).
dgraph_path(V, G, [V|P]) :-
rb_lookup(V, Children, G),
ord_del_element(Children, V, Ch),
do_path(Ch, G, [V], P).
dgraph_isomorphic(Vs, Vs2, G1, G2) :-
rb_new(Map0),

View File

@ -37,8 +37,8 @@ RANLIB=@RANLIB@
srcdir=@srcdir@
SO=@SO@
CWD=$(PWD)
MPILDF=`$(MPI_CC) -showme|sed "s/[^ ]*//"|sed "s/-pt/-lpt/"`
MPICF=`$(MPI_CC) -showme| cut -d " " -f 2`
MPILDF=`$(MPI_CC) --showme:link`
MPICF=`$(MPI_CC) --showme:compile`
#
OBJS=yap_mpi.o hash.o prologterms2c.o

View File

@ -32,23 +32,23 @@ static char *rcsid = "$Header: /Users/vitor/Yap/yap-cvsbackup/library/mpi/mpi.c,
void STD_PROTO(YAP_Write, (Term, void (*)(int), int));
STATIC_PROTO (Int p_mpi_open, (void));
STATIC_PROTO (Int p_mpi_close, (void));
STATIC_PROTO (Int p_mpi_send, (void));
STATIC_PROTO (Int p_mpi_receive, (void));
STATIC_PROTO (Int p_mpi_bcast3, (void));
STATIC_PROTO (Int p_mpi_bcast2, (void));
STATIC_PROTO (Int p_mpi_barrier, (void));
STATIC_PROTO (Int p_mpi_open, ( USES_REGS1 ));
STATIC_PROTO (Int p_mpi_close, ( USES_REGS1 ));
STATIC_PROTO (Int p_mpi_send, ( USES_REGS1 ));
STATIC_PROTO (Int p_mpi_receive, ( USES_REGS1 ));
STATIC_PROTO (Int p_mpi_bcast3, ( USES_REGS1 ));
STATIC_PROTO (Int p_mpi_bcast2, ( USES_REGS1 ));
STATIC_PROTO (Int p_mpi_barrier, ( USES_REGS1 ));
/*
* Auxiliary Data
*/
static Int rank, numprocs, namelen;
static int rank, numprocs, namelen;
static char processor_name[MPI_MAX_PROCESSOR_NAME];
static Int mpi_argc;
static int mpi_argc;
static char **mpi_argv;
/* this should eventually be moved to config.h */
@ -111,7 +111,7 @@ mpi_putc(Int ch)
static Int
p_mpi_open(void) /* mpi_open(?rank, ?num_procs, ?proc_name) */
p_mpi_open( USES_REGS1 ) /* mpi_open(?rank, ?num_procs, ?proc_name) */
{
Term t_rank = Deref(ARG1), t_numprocs = Deref(ARG2), t_procname = Deref(ARG3);
Int retv;
@ -156,7 +156,7 @@ Yap exit(FAILURE), whereas in Yap/LAM mpi_open/3 simply fails.
static Int /* mpi_close */
p_mpi_close()
p_mpi_close( USES_REGS1 )
{
MPI_Finalize();
return TRUE;
@ -164,7 +164,7 @@ p_mpi_close()
static Int
p_mpi_send() /* mpi_send(+data, +destination, +tag) */
p_mpi_send( USES_REGS1 ) /* mpi_send(+data, +destination, +tag) */
{
Term t_data = Deref(ARG1), t_dest = Deref(ARG2), t_tag = Deref(ARG3);
int tag, dest, retv;
@ -216,7 +216,7 @@ p_mpi_send() /* mpi_send(+data, +destination, +tag) */
static Int
p_mpi_receive() /* mpi_receive(-data, ?orig, ?tag) */
p_mpi_receive( USES_REGS1 ) /* mpi_receive(-data, ?orig, ?tag) */
{
Term t, t_data = Deref(ARG1), t_orig = Deref(ARG2), t_tag = Deref(ARG3);
int tag, orig, retv;
@ -305,7 +305,7 @@ p_mpi_receive() /* mpi_receive(-data, ?orig, ?tag) */
static Int
p_mpi_bcast3() /* mpi_bcast( ?data, +root, +max_size ) */
p_mpi_bcast3( USES_REGS1 ) /* mpi_bcast( ?data, +root, +max_size ) */
{
Term t_data = Deref(ARG1), t_root = Deref(ARG2), t_max_size = Deref(ARG3);
int root, retv, max_size;
@ -386,7 +386,7 @@ p_mpi_bcast3() /* mpi_bcast( ?data, +root, +max_size ) */
*/
static Int
p_mpi_bcast2() /* mpi_bcast( ?data, +root ) */
p_mpi_bcast2( USES_REGS1 ) /* mpi_bcast( ?data, +root ) */
{
Term t_data = Deref(ARG1), t_root = Deref(ARG2);
int root, retv;
@ -460,7 +460,7 @@ p_mpi_bcast2() /* mpi_bcast( ?data, +root ) */
static Int
p_mpi_barrier() /* mpi_barrier/0 */
p_mpi_barrier( USES_REGS1 ) /* mpi_barrier/0 */
{
int retv;

View File

@ -473,7 +473,7 @@ factor_to_dist(Hash, f(bayes, Id, Ks)) :-
maplist(key_to_var(Hash), Ks, [V|Parents]),
Ks =[Key|_],
pfl:skolem(Key, Domain),
pfl:get_pfl_parameters(Id, CPT),
pfl:get_pfl_parameters(Id, Ks, CPT),
dist(p(Domain,CPT,Parents), DistInfo, Key, Parents),
put_atts(V,[dist(DistInfo,Parents)]).

View File

@ -122,7 +122,7 @@ evtotree(K=V,Ev0,Ev) :-
rb_insert(Ev0, K, V, Ev).
ftotree(F, Fs0, Fs) :-
F = f([K|_Parents],_,_,_),
F = fn([K|_Parents],_,_,_,_),
rb_insert(Fs0, K, F, Fs).
bdd([[]],_,_) :- !.
@ -160,7 +160,7 @@ sort_keys(AllFs, AllVars, Leaves) :-
dgraph_leaves(Graph, Leaves),
dgraph_top_sort(Graph, AllVars).
add_node(f([K|Parents],_,_,_), Graph0, Graph) :-
add_node(fn([K|Parents],_,_,_,_), Graph0, Graph) :-
dgraph_add_vertex(Graph0, K, Graph1),
foldl(add_edge(K), Parents, Graph1, Graph).
@ -190,7 +190,7 @@ add_parents([V0|Parents], V, Graph0, GraphF) :-
get_keys_info([], _, _, _, Vs, Vs, Ps, Ps, _, _) --> [].
get_keys_info([V|MoreVs], Evs, Fs, OrderVs, Vs, VsF, Ps, PsF, Lvs, Outs) -->
{ rb_lookup(V, F, Fs) }, !,
{ F = f([V|Parents], _, _, DistId) },
{ F = fn([V|Parents], _, _, DistId, _) },
%{writeln(v:DistId:Parents)},
[DIST],
{ get_key_info(V, F, Fs, Evs, OrderVs, DistId, Parents, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) },
@ -200,7 +200,7 @@ get_key_info(V, F, Fs, Evs, OrderVs, DistId, Parents0, Vs, Vs2, Ps, Ps1, Lvs, Ou
reorder_keys(Parents0, OrderVs, Parents, Map),
check_key_p(DistId, F, Map, Parms, _ParmVars, Ps, Ps1),
unbound_parms(Parms, ParmVars),
F = f(_,[Size|_],_,_),
F = fn(_,[Size|_],_,_,_),
check_key(V, Size, DIST, Vs, Vs1),
DIST = info(V, Tree, Ev, Values, Formula, ParmVars, Parms),
% get a list of form [[P00,P01], [P10,P11], [P20,P21]]
@ -599,7 +599,7 @@ to_disj2([V,V1|Vs], V0, Out) :-
%
check_key_p(DistId, _, Map, Parms, ParmVars, Ps, Ps) :-
rb_lookup(DistId-Map, theta(Parms, ParmVars), Ps), !.
check_key_p(DistId, f(_, Sizes, Parms0, DistId), Map, Parms, ParmVars, Ps, PsF) :-
check_key_p(DistId, fn(_, Sizes, Parms0, DistId, _), Map, Parms, ParmVars, Ps, PsF) :-
swap_parms(Parms0, Sizes, [0|Map], Parms1),
length(Parms1, L0),
Sizes = [Size|_],
@ -693,7 +693,7 @@ get_parents(V.Parents, Values.PVars, Vs0, Vs) :-
get_key_parent(Fs, V, Values, Vs0, Vs) :-
INFO = info(V, _Parent, _Ev, Values, _, _, _),
rb_lookup(V, f(_, [Size|_], _, _), Fs),
rb_lookup(V, fn(_, [Size|_], _, _, _), Fs),
check_key(V, Size, INFO, Vs0, Vs).
check_key(V, _, INFO, Vs, Vs) :-

View File

@ -42,7 +42,7 @@ init_influences(Vs, G, RG) :-
to_dgraph(Vs, G0, G),
dgraph_transpose(G, RG).
factor_to_dgraph(f([V|Parents],_,_,_), G0, G) :-
factor_to_dgraph(fn([V|Parents],_,_,_,_), G0, G) :-
dgraph_add_vertex(G0, V, G00),
build_edges(Parents, V, Edges),
dgraph_add_edges(G00, Edges, G).

View File

@ -238,7 +238,7 @@ get_dist_all_sizes(Id, DSizes) :-
get_dist_domain_size(DistId, DSize) :-
use_parfactors(on), !,
pfl:get_pfl_parameters(DistId, Dist),
pfl:get_pfl_parameters(DistId, _, Dist),
length(Dist, DSize).
get_dist_domain_size(avg(D,_), DSize) :- !,
length(D, DSize).
@ -297,7 +297,7 @@ empty_dist(Dist, TAB) :-
dist_new_table(DistId, NewMat) :-
use_parfactors(on), !,
matrix_to_list(NewMat, List),
pfl:new_pfl_parameters(DistId, List).
pfl:new_pfl_parameters(DistId, _, List).
dist_new_table(Id, NewMat) :-
matrix_to_list(NewMat, List),
recorded(clpbn_dist_db, db(Id, Key, _, A, B, C, D), R),

View File

@ -144,7 +144,6 @@ create_new_variable(K, V, Vf0, Vff, C0, Cf) :-
Id =.. [Na,Dom],
Dist =.. [Na,Dom,NTVs],
{ V = K with Dist },
writeln(done),
add_stored_evidence(K, V),
add_variables(TVs, NTVs, Vf0, Vff, C0, Cf).

View File

@ -28,6 +28,8 @@
sum_list/2
]).
:- use_module(library(maplist)).
:- use_module(library(ordsets),
[ord_subtract/3]).
@ -87,7 +89,7 @@ run_gibbs_solver(LVs, LPs, Vs) :-
initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
init_keys(Keys0),
gen_keys(LVs, 0, VLen, Keys0, Keys),
foldl2(gen_key, LVs, 0, VLen, Keys0, Keys),
functor(Graph,graph,VLen),
graph_representation(LVs, Graph, 0, Keys, TGraph),
compile_graph(Graph),
@ -99,21 +101,18 @@ initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
init_keys(Keys0) :-
rb_new(Keys0).
gen_keys([], I, I, Keys, Keys).
gen_keys([V|Vs], I0, If, Keys0, Keys) :-
clpbn:get_atts(V,[evidence(_)]), !,
gen_keys(Vs, I0, If, Keys0, Keys).
gen_keys([V|Vs], I0, If, Keys0, Keys) :-
gen_key(V, I0, I0, Keys0, Keys0) :-
clpbn:get_atts(V,[evidence(_)]), !.
gen_key(V, I0, I, Keys0, Keys) :-
I is I0+1,
rb_insert(Keys0,V,I,KeysI),
gen_keys(Vs, I, If, KeysI, Keys).
rb_insert(Keys0,V,I,Keys).
graph_representation([],_,_,_,[]).
graph_representation([V|Vs], Graph, I0, Keys, TGraph) :-
clpbn:get_atts(V,[evidence(_)]), !,
clpbn:get_atts(V, [dist(Id,Parents)]),
get_possibly_deterministic_dist_matrix(Id, Parents, _, Vals, Table),
get_sizes(Parents, Szs),
maplist(get_size, Parents, Szs),
length(Vals,Sz),
project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable),
% all variables are parents
@ -123,7 +122,7 @@ graph_representation([V|Vs], Graph, I0, Keys, [I-IParents|TGraph]) :-
I is I0+1,
clpbn:get_atts(V, [dist(Id,Parents)]),
get_possibly_deterministic_dist_matrix(Id, Parents, _, Vals, Table),
get_sizes(Parents, Szs),
maplist( get_size, Parents, Szs),
length(Vals,Sz),
project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable),
Variables = [V|NewParents],
@ -131,7 +130,7 @@ graph_representation([V|Vs], Graph, I0, Keys, [I-IParents|TGraph]) :-
reorder_CPT(Variables,NewTable,[V|SortedNVs],NewTable2,_),
add2graph(V, Vals, NewTable2, SortedIndices, Graph, Keys),
propagate2parents(NewParents, NewTable, Variables, Graph,Keys),
parent_indices(NewParents, Keys, IVariables0),
maplist(parent_index(Keys), NewParents, IVariables0),
sort(IVariables0, IParents),
arg(I, Graph, var(_,_,_,_,_,_,_,NewTable2,SortedIndices)),
graph_representation(Vs, Graph, I, Keys, TGraph).
@ -141,18 +140,12 @@ write_pars([V|Parents]) :-
clpbn:get_atts(V, [key(K),dist(I,_)]),write(K:I),nl,
write_pars(Parents).
get_sizes([], []).
get_sizes([V|Parents], [Sz|Szs]) :-
get_size(V, Sz) :-
clpbn:get_atts(V, [dist(Id,_)]),
get_dist_domain_size(Id, Sz),
get_sizes(Parents, Szs).
parent_indices([], _, []).
parent_indices([V|Parents], Keys, [I|IParents]) :-
rb_lookup(V, I, Keys),
parent_indices(Parents, Keys, IParents).
get_dist_domain_size(Id, Sz).
parent_index(Keys, V, I) :-
rb_lookup(V, I, Keys).
%
% first, remove nodes that have evidence from tables.
@ -180,52 +173,36 @@ add2graph(V, Vals, Table, IParents, Graph, Keys) :-
member(tabular(Table,Index,IParents), VarSlot), !.
sort_according_to_indices(NVs,Keys,SortedNVs,SortedIndices) :-
vars2indices(NVs,Keys,ToSort),
maplist(var2index(Keys), NVs, ToSort),
keysort(ToSort, Sorted),
split_parents(Sorted, SortedNVs,SortedIndices).
maplist(split_parent, Sorted, SortedNVs,SortedIndices).
split_parents([], [], []).
split_parents([I-V|Sorted], [V|SortedNVs],[I|SortedIndices]) :-
split_parents(Sorted, SortedNVs, SortedIndices).
split_parent(I-V, V, I).
vars2indices([],_,[]).
vars2indices([V|Parents],Keys,[I-V|IParents]) :-
rb_lookup(V, I, Keys),
vars2indices(Parents,Keys,IParents).
var2index(Keys, V, I-V) :-
rb_lookup(V, I, Keys).
%
% This is the really cool bit.
%
compile_graph(Graph) :-
Graph =.. [_|VarsInfo],
compile_vars(VarsInfo,Graph).
maplist( compile_var(Graph), VarsInfo).
compile_vars([],_).
compile_vars([var(_,I,_,Vals,Sz,VarSlot,Parents,_,_)|VarsInfo],Graph)
:-
compile_var(I,Vals,Sz,VarSlot,Parents,Graph),
compile_vars(VarsInfo,Graph).
compile_var(I,Vals,Sz,VarSlot,Parents,Graph) :-
fetch_all_parents(VarSlot,Graph,[],Parents,[],Sizes),
mult_list(Sizes,1,TotSize),
compile_var(Graph, var(_,I,_,Vals,Sz,VarSlot,Parents,_,_)) :-
foldl2( fetch_parent(Graph), VarSlot, [], Parents, [], Sizes),
foldl( mult_list, Sizes,1,TotSize),
compile_var(TotSize,I,Vals,Sz,VarSlot,Parents,Sizes,Graph).
fetch_all_parents([],_,Parents,Parents,Sizes,Sizes) :- !.
fetch_all_parents([tabular(_,_,Ps)|CPTs],Graph,Parents0,ParentsF,Sizes0,SizesF) :-
merge_these_parents(Ps,Graph,Parents0,ParentsI,Sizes0,SizesI),
fetch_all_parents(CPTs,Graph,ParentsI,ParentsF,SizesI,SizesF).
fetch_parent(Graph, tabular(_,_,Ps), Parents0, ParentsF, Sizes0, SizesF) :-
foldl2( merge_these_parents(Graph), Ps, Parents0, ParentsF, Sizes0, SizesF).
merge_these_parents([],_,Parents,Parents,Sizes,Sizes).
merge_these_parents([I|Ps],Graph,Parents0,ParentsF,Sizes0,SizesF) :-
member(I,Parents0), !,
merge_these_parents(Ps,Graph,Parents0,ParentsF,Sizes0,SizesF).
merge_these_parents([I|Ps],Graph,Parents0,ParentsF,Sizes0,SizesF) :-
merge_these_parents(_Graph, I,Parents0,Parents0,Sizes0,Sizes0) :-
member(I,Parents0), !.
merge_these_parents(Graph, I, Parents0,ParentsF,Sizes0,SizesF) :-
arg(I,Graph,var(_,I,_,Vals,_,_,_,_,_)),
length(Vals, Sz),
add_parent(Parents0,I,ParentsI,Sizes0,Sz,SizesI),
merge_these_parents(Ps,Graph,ParentsI,ParentsF,SizesI,SizesF).
add_parent(Parents0,I,ParentsF,Sizes0,Sz,SizesF).
add_parent([],I,[I],[],Sz,[Sz]).
add_parent([P|Parents0],I,[I,P|Parents0],Sizes0,Sz,[Sz|Sizes0]) :-
@ -234,10 +211,8 @@ add_parent([P|Parents0],I,[P|ParentsI],[S|Sizes0],Sz,[S|SizesI]) :-
add_parent(Parents0,I,ParentsI,Sizes0,Sz,SizesI).
mult_list([],Mult,Mult).
mult_list([Sz|Sizes],Mult0,Mult) :-
MultI is Sz*Mult0,
mult_list(Sizes,MultI,Mult).
mult_list(Sz,Mult0,Mult) :-
Mult is Sz*Mult0.
% compile node as set of facts, faster execution
compile_var(TotSize,I,_Vals,Sz,CPTs,Parents,_Sizes,Graph) :-
@ -247,7 +222,7 @@ compile_var(TotSize,I,_Vals,Sz,CPTs,Parents,_Sizes,Graph) :-
compile_var(_,_,_,_,_,_,_,_).
multiply_all(I,Parents,CPTs,Sz,Graph) :-
markov_blanket_instance(Parents,Graph,Values),
maplist( markov_blanket_instance(Graph), Parents, Values),
(
multiply_all(CPTs,Graph,Probs)
->
@ -261,11 +236,9 @@ multiply_all(I,_,_,_,_) :-
% note: what matters is how this predicate instantiates the temp
% slot in the graph!
markov_blanket_instance([],_,[]).
markov_blanket_instance([I|Parents],Graph,[Pos|Values]) :-
arg(I,Graph,var(_,I,Pos,Vals,_,_,_,_,_)),
fetch_val(Vals,0,Pos),
markov_blanket_instance(Parents,Graph,Values).
markov_blanket_instance(Graph, I, Pos) :-
arg(I, Graph, var(_,I,Pos,Vals,_,_,_,_,_)),
fetch_val(Vals, 0, Pos).
% backtrack through every value in domain
%
@ -275,21 +248,19 @@ fetch_val([_|Vals],I0,Pos) :-
fetch_val(Vals,I,Pos).
multiply_all([tabular(Table,_,Parents)|CPTs],Graph,Probs) :-
fetch_parents(Parents, Graph, Vals),
maplist( fetch_parent(Graph), Parents, Vals),
column_from_possibly_deterministic_CPT(Table,Vals,Probs0),
multiply_more(CPTs,Graph,Probs0,Probs).
fetch_parents([], _, []).
fetch_parents([P|Parents], Graph, [Val|Vals]) :-
arg(P,Graph,var(_,_,Val,_,_,_,_,_,_)),
fetch_parents(Parents, Graph, Vals).
fetch_parent(Graph, P, Val) :-
arg(P,Graph,var(_,_,Val,_,_,_,_,_,_)).
multiply_more([],_,Probs0,LProbs) :-
normalise_possibly_deterministic_CPT(Probs0, Probs),
list_from_CPT(Probs, LProbs0),
accumulate_up_list(LProbs0, 0.0, LProbs).
multiply_more([tabular(Table,_,Parents)|CPTs],Graph,Probs0,Probs) :-
fetch_parents(Parents, Graph, Vals),
maplist( fetch_parent(Graph), Parents, Vals),
column_from_possibly_deterministic_CPT(Table, Vals, P0),
multiply_possibly_deterministic_factors(Probs0, P0, ProbsI),
multiply_more(CPTs,Graph,ProbsI,Probs).
@ -378,7 +349,7 @@ process_chains(0,_,F,F,_,_,Est,Est) :- !.
process_chains(ToDo,VarOrder,End,Start,Graph,Len,Est0,Estf) :-
%format('ToDo = ~d~n',[ToDo]),
process_chains(Start,VarOrder,Int,Graph,Len,Est0,Esti),
% (ToDo mod 100 =:= 1 -> statistics,cvt2problist(Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true),
% (ToDo mod 100 =:= 1 -> statistics,maplist(cvt2prob, Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true),
ToDo1 is ToDo-1,
process_chains(ToDo1,VarOrder,End,Int,Graph,Len,Esti,Estf).
@ -388,7 +359,7 @@ process_chains([Sample0|Samples0], VarOrder, [Sample|Samples], Graph, SampLen,[E
functor(Sample,sample,SampLen),
do_sample(VarOrder,Sample,Sample0,Graph),
% format('Sample = ~w~n',[Sample]),
update_estimates(E0,Sample,Ef),
maplist(update_estimate(Sample), E0, Ef),
process_chains(Samples0, VarOrder, Samples, Graph, SampLen,E0s,Efs).
do_sample([],_,_,_).
@ -439,15 +410,10 @@ pick_new_value([V|Vals],X,I0,Val) :-
pick_new_value(Vals,X,I,Val)
).
update_estimates([],_,[]).
update_estimates([Est|E0],Sample,[NEst|Ef]) :-
update_estimate(Est,Sample,NEst),
update_estimates(E0,Sample,Ef).
update_estimate([I|E],Sample,[I|NE]) :-
update_estimate(Sample, [I|E],[I|NE]) :-
arg(I,Sample,V),
update_estimate_for_var(V,E,NE).
update_estimate(me(Is,Mult,E),Sample,me(Is,Mult,NE)) :-
update_estimate(Sample,me(Is,Mult,E),me(Is,Mult,NE)) :-
get_estimate_pos(Is, Sample, Mult, 0, V),
update_estimate_for_var(V,E,NE).
@ -481,21 +447,15 @@ clean_up.
gibbs_params(5,100,1000).
cvt2problist([], []).
cvt2problist([[[_|E]]|Est0], [Ps|Probs]) :-
sum_all(E,0,Sum),
do_probs(E,Sum,Ps),
cvt2problist(Est0, Probs) .
cvt2prob([[_|E]], Ps) :-
foldl(sum_all, E, 0, Sum),
maplist( do_prob(Sum), E, Ps).
sum_all([],Sum,Sum).
sum_all([E|Es],S0,Sum) :-
SI is S0+E,
sum_all(Es,SI,Sum).
sum_all(E, S0, Sum) :-
Sum is S0+E.
do_probs([],_,[]).
do_probs([E|Es],Sum,[P|Ps]) :-
P is E/Sum,
do_probs(Es,Sum,Ps).
do_prob(Sum, E, P) :-
P is E/Sum.
show_sorted([], _) :- nl.
show_sorted([I|VarOrder], Graph) :-
@ -506,13 +466,11 @@ show_sorted([I|VarOrder], Graph) :-
sum_up_all([[]|_], []).
sum_up_all([[C|MoreC]|Chains], [Dist|Dists]) :-
extract_sums(Chains, CurrentChains, LeftChains),
maplist( extract_sum, Chains, CurrentChains, LeftChains),
sum_up([C|CurrentChains], Dist),
sum_up_all([MoreC|LeftChains], Dists).
extract_sums([], [], []).
extract_sums([[C|Chains]|MoreChains], [C|CurrentChains], [Chains|LeftChains]) :-
extract_sums(MoreChains, CurrentChains, LeftChains).
extract_sum([C|Chains], C, Chains).
sum_up([[_|Counts]|Chains], Dist) :-
add_up(Counts,Chains, Add),
@ -523,25 +481,21 @@ sum_up([me(_,_,Counts)|Chains], Dist) :-
add_up(Counts,[],Counts).
add_up(Counts,[[_|Cs]|Chains], Add) :-
sum_lists(Counts, Cs, NCounts),
maplist(sum, Counts, Cs, NCounts),
add_up(NCounts, Chains, Add).
add_up_mes(Counts,[],Counts).
add_up_mes(Counts,[me(_,_,Cs)|Chains], Add) :-
sum_lists(Counts, Cs, NCounts),
maplist( sum_list, Counts, Cs, NCounts),
add_up_mes(NCounts, Chains, Add).
sum_lists([],[],[]).
sum_lists([Count|Counts], [C|Cs], [NC|NCounts]) :-
NC is Count+C,
sum_lists(Counts, Cs, NCounts).
sum(Count, C, NC) :-
NC is Count+C.
normalise(Add, Dist) :-
sum_list(Add, Sum),
divide_list(Add, Sum, Dist).
maplist(divide(Sum), Add, Dist).
divide_list([], _, []).
divide_list([C|Add], Sum, [P|Dist]) :-
P is C/Sum,
divide_list(Add, Sum, Dist).
divide(Sum, C, P) :-
P is C/Sum.

View File

@ -32,7 +32,7 @@
[clpbn_bind_vals/3]).
:- use_module(library(pfl),
[get_pfl_parameters/2,
[get_pfl_parameters/3,
skolem/2
]).
@ -50,7 +50,7 @@ call_horus_ground_solver(QueryVars, QueryKeys, AllKeys, Factors, Evidence,
end_horus_ground_solver(State).
init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence,
init_horus_ground_solver(_QueryKeys, AllKeys, Factors, Evidence,
state(Network,Hash,Id,DistIds)) :-
factors_type(Factors, Type),
keys_to_numbers(AllKeys, Factors, Evidence, Hash, Id, FacIds, EvIds),
@ -64,10 +64,10 @@ init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence,
run_horus_ground_solver(QueryKeys, Solutions,
state(Network,Hash,Id, DistIds)) :-
state(Network,Hash,Id, _DistIds)) :-
lists_of_keys_to_ids(QueryKeys, QueryIds, Hash, _, Id, _),
%maplist(get_pfl_parameters, DistIds, DistParams),
%cpp_set_factors_params(Network, DistIds, DistParams),
%maplist(get_pfl_parameters, _DistIds, _, DistParams),
%cpp_set_factors_params(Network, _DistIds, DistParams),
cpp_run_ground_solver(Network, QueryIds, Solutions).
@ -79,7 +79,7 @@ factors_type([f(bayes, _, _)|_], bayes) :- ! .
factors_type([f(markov, _, _)|_], markov) :- ! .
get_dist_id(f(_, _, _, DistId), DistId).
get_dist_id(fn(_, _, _, DistId, _), DistId).
get_domain(_:Key, Domain) :- !,

View File

@ -28,7 +28,7 @@
:- use_module(library(pfl),
[factor/6,
skolem/2,
get_pfl_parameters/2
get_pfl_parameters/3
]).
:- use_module(library(maplist)).
@ -50,9 +50,9 @@ init_horus_lifted_solver(_, AllVars, _, state(Network, DistIds)) :-
sort(DistIds0, DistIds).
run_horus_lifted_solver(QueryVars, Solutions, state(Network, DistIds)) :-
run_horus_lifted_solver(QueryVars, Solutions, state(Network, _DistIds)) :-
maplist(get_query_keys, QueryVars, QueryKeys),
%maplist(get_pfl_parameters, DistIds,DistsParams),
%maplist(get_pfl_parameters, DistIds, _, DistsParams),
%cpp_set_parfactors_params(Network, DistIds, DistsParams),
cpp_run_lifted_solver(Network, QueryKeys, Solutions).

View File

@ -41,7 +41,7 @@ key_to_id(Key, I0, Hash0, Hash, I0, I) :-
b_hash_insert(Hash0, Key, I0, Hash),
I is I0+1.
factor_to_id(Ev, f(_, DistId, Keys), f(Ids, Ranges, CPT, DistId), Hash0, Hash, I0, I) :-
factor_to_id(Ev, f(_, DistId, Keys), fn(Ids, Ranges, CPT, DistId, Keys), Hash0, Hash, I0, I) :-
get_pfl_cpt(DistId, Keys, Ev, NKeys, CPT),
foldl2(key_to_id, NKeys, Ids, Hash0, Hash, I0, I),
maplist(get_range, Keys, Ranges).

View File

@ -131,9 +131,9 @@ init_ve(FactorIds, EvidenceIds, Hash, Id, ve(FactorIds, Hash, Id, Ev)) :-
evtotree(K=V,Ev0,Ev) :-
rb_insert(Ev0, K, V, Ev).
factor_to_graph( f(Nodes, Sizes, _Pars0, Id), Factors0, Factors, Edges0, Edges, I0, I) :-
factor_to_graph( fn(Nodes, Sizes, _Pars0, Id, Keys), Factors0, Factors, Edges0, Edges, I0, I) :-
I is I0+1,
pfl:get_pfl_parameters(Id, Pars0),
pfl:get_pfl_parameters(Id, Keys, Pars0),
init_CPT(Pars0, Sizes, CPT0),
reorder_CPT(Nodes, CPT0, FIPs, CPT, _),
F = f(I0, FIPs, CPT),

View File

@ -0,0 +1,137 @@
10
2
0 1
2 2
4
0 1.02
1 0.87
2 0.88
3 0.45
4
1 2 3 4
2 2 3 3
36
0 0.11
1 1.11
2 0.41
3 0.12
4 0.1
5 0.17
6 1.21
7 1.1
8 0.11
9 0.41
10 0.8
11 0.71
12 0.14
13 0.24
14 0.54
15 1.4
16 0.23
17 0.24
18 0.65
19 0.05
20 0.32
21 0.12
22 0.99
23 0.69
24 0.29
25 1.29
26 0.15
27 1.24
28 0.42
29 0.124
30 0.67
31 0.078
32 0.14
33 0.55
34 0.45
35 0.1
3
2 5 6
2 2 3
12
0 0.15
1 0.55
2 2.21
3 5.71
4 0.44
5 0.14
6 0.5
7 1.75
8 1.29
9 3.29
10 0.36
11 1.56
2
7 2
4 2
8
0 0.11
1 0.59
2 0.15
3 0.124
4 0.41
5 2.11
6 1.06
7 0.929
1
3
3
3
0 0.1
1 0.58
2 0.74
1
4
3
3
0 3.2
1 0.28
2 1.24
2
8 4
2 3
6
0 0.19
1 3.1
2 0.49
3 1.5
4 2.1
5 2.8
1
5
2
2
0 0.74
1 0.14
1
6
3
3
0 0.032
1 0.028
2 0.24
2
9 7
2 4
8
0 0.61
1 0.61
2 1.4
3 0.24
4 0.09
5 0.19
6 1.4
7 0.6

View File

@ -0,0 +1,79 @@
:- use_module(library(pfl)).
%:- set_solver(ve).
%:- set_solver(hve).
%:- set_solver(jt).
%:- set_solver(bdd).
%:- set_solver(bp).
%:- set_solver(cbp).
%:- set_solver(gibbs).
%:- set_solver(lve).
%:- set_solver(lkc).
%:- set_solver(lbp).
/*
v01 v02
\ /
\ /
\ /
v03 v04 v05
/ \ | / \
/ \ | / \
/ \ | / \
v06 v07 v08
| |
| |
| |
v09 v10
*/
markov v01::[a,b] ; table1 ; [].
markov v02::[a,b,c] ; table2 ; [].
markov v03::[a,b], v01, v02 ; table3 ; [].
markov v04::[a,b,c] ; table4 ; [].
markov v05::[a,b,c] ; table5 ; [].
markov v06::[a,b,c,d], v03 ; table6 ; [].
markov v07::[a,b], v03, v04, v05 ; table7 ; [].
markov v08::[a,b], v05 ; table8 ; [].
markov v09::[a,b], v06 ; table9 ; [].
markov v10::[a,b], v07 ; table10 ; [].
table1([ 0.74, 0.14 ]).
table2([ 0.032, 0.028, 0.24 ]).
table3([
0.15, 0.44, 1.29, 2.21, 0.5, 0.36,
0.55, 0.14, 3.29, 5.71, 1.75, 1.56
]).
table4([ 0.1, 0.58, 0.74 ]).
table5([ 3.2, 0.28, 1.24 ]).
table6([ 0.11, 0.41, 0.59, 2.11, 0.15, 1.06, 0.124, 0.929 ]).
table7([
0.11, 0.14, 0.29, 0.1, 0.23, 0.42, 0.11, 0.32, 0.14,
0.41, 0.54, 0.15, 1.21, 0.65, 0.67, 0.8, 0.99, 0.45,
1.11, 0.24, 1.29, 0.17, 0.24, 0.124, 0.41, 0.12, 0.55,
0.12, 1.4, 1.24, 1.1, 0.05, 0.078, 0.71, 0.69, 0.1
]).
table8([ 0.19, 0.49, 2.1, 3.1, 1.5, 2.8 ]).
table9([ 0.61, 1.4, 0.09, 1.4, 0.61, 0.24, 0.19, 0.6 ]).
table10([ 1.02, 0.88, 0.87, 0.45 ]).

View File

@ -3,6 +3,25 @@
#include "BayesBall.h"
namespace Horus {
BayesBall::BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
FactorGraph*
BayesBall::getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{
@ -19,22 +38,22 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
BBNode* n = sch.node;
n->setAsVisited();
if (n->hasEvidence() == false && sch.visitedFromChild) {
if (n->isMarkedOnTop() == false) {
n->markOnTop();
if (n->isMarkedAbove() == false) {
n->markAbove();
scheduleParents (n, scheduling);
}
if (n->isMarkedOnBottom() == false) {
n->markOnBottom();
if (n->isMarkedBelow() == false) {
n->markBelow();
scheduleChilds (n, scheduling);
}
}
if (sch.visitedFromParent) {
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
n->markOnTop();
if (n->hasEvidence() && n->isMarkedAbove() == false) {
n->markAbove();
scheduleParents (n, scheduling);
}
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
n->markOnBottom();
if (n->hasEvidence() == false && n->isMarkedBelow() == false) {
n->markBelow();
scheduleChilds (n, scheduling);
}
}
@ -55,7 +74,7 @@ BayesBall::constructGraph (FactorGraph* fg) const
for (size_t i = 0; i < facNodes.size(); i++) {
const BBNode* n = dag_.getNode (
facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) {
if (n->isMarkedAbove()) {
fg->addFactor (facNodes[i]->factor());
} else if (n->hasEvidence() && n->isVisited()) {
VarIds varIds = { facNodes[i]->factor().argument (0) };
@ -76,3 +95,5 @@ BayesBall::constructGraph (FactorGraph* fg) const
}
}
} // namespace Horus

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BAYESBALL_H
#define HORUS_BAYESBALL_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
#define YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
#include <vector>
#include <queue>
@ -9,41 +9,28 @@
#include "BayesBallGraph.h"
#include "Horus.h"
using namespace std;
namespace Horus {
struct ScheduleInfo
{
ScheduleInfo (BBNode* n, bool vfp, bool vfc)
: node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
BBNode* node;
bool visitedFromParent;
bool visitedFromChild;
};
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
class BayesBall
{
class BayesBall {
public:
BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
BayesBall (FactorGraph& fg);
FactorGraph* getMinimalFactorGraph (const VarIds&);
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids);
private:
struct ScheduleInfo {
ScheduleInfo (BBNode* n, bool vfp, bool vfc)
: node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
BBNode* node;
bool visitedFromParent;
bool visitedFromChild;
};
typedef std::queue<ScheduleInfo, std::list<ScheduleInfo>> Scheduling;
void constructGraph (FactorGraph* fg) const;
@ -51,9 +38,8 @@ class BayesBall
void scheduleChilds (const BBNode* n, Scheduling& sch) const;
FactorGraph& fg_;
BayesBallGraph& dag_;
FactorGraph& fg_;
BayesBallGraph& dag_;
};
@ -61,8 +47,8 @@ class BayesBall
inline void
BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
{
const vector<BBNode*>& ps = n->parents();
for (vector<BBNode*>::const_iterator it = ps.begin();
const std::vector<BBNode*>& ps = n->parents();
for (std::vector<BBNode*>::const_iterator it = ps.begin();
it != ps.end(); ++it) {
sch.push (ScheduleInfo (*it, false, true));
}
@ -73,12 +59,14 @@ BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
inline void
BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const
{
const vector<BBNode*>& cs = n->childs();
for (vector<BBNode*>::const_iterator it = cs.begin();
const std::vector<BBNode*>& cs = n->childs();
for (std::vector<BBNode*>::const_iterator it = cs.begin();
it != cs.end(); ++it) {
sch.push (ScheduleInfo (*it, true, false));
}
}
#endif // HORUS_BAYESBALL_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_

View File

@ -1,14 +1,15 @@
#include <cstdlib>
#include <cassert>
#include <iostream>
#include <sstream>
#include <fstream>
#include <sstream>
#include "BayesBallGraph.h"
#include "Util.h"
namespace Horus {
void
BayesBallGraph::addNode (BBNode* n)
{
@ -22,8 +23,8 @@ BayesBallGraph::addNode (BBNode* n)
void
BayesBallGraph::addEdge (VarId vid1, VarId vid2)
{
unordered_map<VarId, BBNode*>::iterator it1;
unordered_map<VarId, BBNode*>::iterator it2;
std::unordered_map<VarId, BBNode*>::iterator it1;
std::unordered_map<VarId, BBNode*>::iterator it2;
it1 = varMap_.find (vid1);
it2 = varMap_.find (vid2);
assert (it1 != varMap_.end());
@ -37,7 +38,7 @@ BayesBallGraph::addEdge (VarId vid1, VarId vid2)
const BBNode*
BayesBallGraph::getNode (VarId vid) const
{
unordered_map<VarId, BBNode*>::const_iterator it;
std::unordered_map<VarId, BBNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
@ -47,7 +48,7 @@ BayesBallGraph::getNode (VarId vid) const
BBNode*
BayesBallGraph::getNode (VarId vid)
{
unordered_map<VarId, BBNode*>::const_iterator it;
std::unordered_map<VarId, BBNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
@ -55,7 +56,7 @@ BayesBallGraph::getNode (VarId vid)
void
BayesBallGraph::setIndexes (void)
BayesBallGraph::setIndexes()
{
for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
@ -65,7 +66,7 @@ BayesBallGraph::setIndexes (void)
void
BayesBallGraph::clear (void)
BayesBallGraph::clear()
{
for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->clear();
@ -77,13 +78,14 @@ BayesBallGraph::clear (void)
void
BayesBallGraph::exportToGraphViz (const char* fileName)
{
ofstream out (fileName);
std::ofstream out (fileName);
if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
}
out << "digraph {" << endl;
out << "ranksep=1" << endl;
out << "digraph {" << std::endl;
out << "ranksep=1" << std::endl;
for (size_t i = 0; i < nodes_.size(); i++) {
out << nodes_[i]->varId() ;
out << " [" ;
@ -91,16 +93,18 @@ BayesBallGraph::exportToGraphViz (const char* fileName)
if (nodes_[i]->hasEvidence()) {
out << ",style=filled, fillcolor=yellow" ;
}
out << "]" << endl;
out << "]" << std::endl;
}
for (size_t i = 0; i < nodes_.size(); i++) {
const vector<BBNode*>& childs = nodes_[i]->childs();
const std::vector<BBNode*>& childs = nodes_[i]->childs();
for (size_t j = 0; j < childs.size(); j++) {
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
out << " [style=bold]" << endl ;
out << " [style=bold]" << std::endl;
}
}
out << "}" << endl;
out << "}" << std::endl;
out.close();
}
} // namespace Horus

View File

@ -1,5 +1,5 @@
#ifndef HORUS_BAYESBALLGRAPH_H
#define HORUS_BAYESBALLGRAPH_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
#define YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
#include <vector>
#include <unordered_map>
@ -7,54 +7,55 @@
#include "Var.h"
#include "Horus.h"
using namespace std;
class BBNode : public Var
{
namespace Horus {
class BBNode : public Var {
public:
BBNode (Var* v) : Var (v), visited_(false),
markedOnTop_(false), markedOnBottom_(false) { }
markedAbove_(false), markedBelow_(false) { }
const vector<BBNode*>& childs (void) const { return childs_; }
const std::vector<BBNode*>& childs() const { return childs_; }
vector<BBNode*>& childs (void) { return childs_; }
std::vector<BBNode*>& childs() { return childs_; }
const vector<BBNode*>& parents (void) const { return parents_; }
const std::vector<BBNode*>& parents() const { return parents_; }
vector<BBNode*>& parents (void) { return parents_; }
std::vector<BBNode*>& parents() { return parents_; }
void addParent (BBNode* p) { parents_.push_back (p); }
void addChild (BBNode* c) { childs_.push_back (c); }
bool isVisited (void) const { return visited_; }
bool isVisited() const { return visited_; }
void setAsVisited (void) { visited_ = true; }
void setAsVisited() { visited_ = true; }
bool isMarkedOnTop (void) const { return markedOnTop_; }
bool isMarkedAbove() const { return markedAbove_; }
void markOnTop (void) { markedOnTop_ = true; }
void markAbove() { markedAbove_ = true; }
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
bool isMarkedBelow() const { return markedBelow_; }
void markOnBottom (void) { markedOnBottom_ = true; }
void markBelow() { markedBelow_ = true; }
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
void clear() { visited_ = markedAbove_ = markedBelow_ = false; }
private:
bool visited_;
bool markedOnTop_;
bool markedOnBottom_;
bool markedAbove_;
bool markedBelow_;
vector<BBNode*> childs_;
vector<BBNode*> parents_;
std::vector<BBNode*> childs_;
std::vector<BBNode*> parents_;
};
class BayesBallGraph
{
class BayesBallGraph {
public:
BayesBallGraph (void) { }
BayesBallGraph() { }
bool empty() const { return nodes_.empty(); }
void addNode (BBNode* n);
@ -64,19 +65,18 @@ class BayesBallGraph
BBNode* getNode (VarId vid);
bool empty (void) const { return nodes_.empty(); }
void setIndexes();
void setIndexes (void);
void clear (void);
void clear();
void exportToGraphViz (const char*);
private:
vector<BBNode*> nodes_;
unordered_map<VarId, BBNode*> varMap_;
std::vector<BBNode*> nodes_;
std::unordered_map<VarId, BBNode*> varMap_;
};
#endif // HORUS_BAYESBALLGRAPH_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_

View File

@ -1,37 +1,39 @@
#include <cassert>
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <sstream>
#include "BeliefProp.h"
#include "Indexer.h"
#include "Horus.h"
double BeliefProp::accuracy_ = 0.0001;
unsigned BeliefProp::maxIter_ = 1000;
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
namespace Horus {
double BeliefProp::accuracy_ = 0.0001;
unsigned BeliefProp::maxIter_ = 1000;
BeliefProp::MsgSchedule BeliefProp::schedule_ =
MsgSchedule::seqFixedSch;
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
BeliefProp::BeliefProp (const FactorGraph& fg)
: GroundSolver (fg), nIters_(0), runned_(false)
{
runned_ = false;
}
BeliefProp::~BeliefProp (void)
BeliefProp::~BeliefProp()
{
for (size_t i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
}
for (size_t i = 0; i < facsI_.size(); i++) {
delete facsI_[i];
}
for (size_t i = 0; i < links_.size(); i++) {
delete links_[i];
}
links_.clear();
}
@ -48,22 +50,22 @@ BeliefProp::solveQuery (VarIds queryVids)
void
BeliefProp::printSolverFlags (void) const
BeliefProp::printSolverFlags() const
{
stringstream ss;
std::stringstream ss;
ss << "belief propagation [" ;
ss << "bp_msg_schedule=" ;
switch (schedule_) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
}
ss << ",bp_max_iter=" << Util::toString (maxIter_);
ss << ",bp_accuracy=" << Util::toString (accuracy_);
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",bp_max_iter=" << Util::toString (maxIter_);
ss << ",bp_accuracy=" << Util::toString (accuracy_);
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
std::cout << ss.str() << std::endl;
}
@ -82,7 +84,7 @@ BeliefProp::getPosterioriOf (VarId vid)
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->range(), LogAware::multIdenty());
const BpLinks& links = ninf(var)->getLinks();
const BpLinks& links = getLinks (var);
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
probs += links[i]->message();
@ -133,7 +135,7 @@ BeliefProp::getFactorJoint (
runSolver();
}
Factor res (fn->factor());
const BpLinks& links = ninf(fn)->getLinks();
const BpLinks& links = getLinks( fn);
for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()},
@ -152,26 +154,119 @@ BeliefProp::getFactorJoint (
BeliefProp::BpLink::BpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0.0;
}
void
BeliefProp::runSolver (void)
BeliefProp::BpLink::clearResidual()
{
residual_ = 0.0;
}
void
BeliefProp::BpLink::updateResidual()
{
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
void
BeliefProp::BpLink::updateMessage()
{
swap (currMsg_, nextMsg_);
}
std::string
BeliefProp::BpLink::toString() const
{
std::stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
void
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
{
if (Globals::verbosity > 2) {
std::cout << "calculating & updating " << link->toString();
std::cout << std::endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void
BeliefProp::calculateMessage (BpLink* link, bool calcResidual)
{
if (Globals::verbosity > 2) {
std::cout << "calculating " << link->toString();
std::cout << std::endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void
BeliefProp::updateMessage (BpLink* link)
{
link->updateMessage();
if (Globals::verbosity > 2) {
std::cout << "updating " << link->toString();
std::cout << std::endl;
}
}
void
BeliefProp::runSolver()
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < maxIter_) {
nIters_ ++;
if (Globals::verbosity > 1) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
Util::printHeader (std::string ("Iteration ")
+ Util::toString (nIters_));
}
switch (schedule_) {
case MsgSchedule::SEQ_RANDOM:
case MsgSchedule::seqRandomSch:
std::random_shuffle (links_.begin(), links_.end());
// no break
case MsgSchedule::SEQ_FIXED:
case MsgSchedule::seqFixedSch:
for (size_t i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case MsgSchedule::PARALLEL:
case MsgSchedule::parallelSch:
for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
@ -179,20 +274,21 @@ BeliefProp::runSolver (void)
updateMessage(links_[i]);
}
break;
case MsgSchedule::MAX_RESIDUAL:
case MsgSchedule::maxResidualSch:
maxResidualSchedule();
break;
}
}
if (Globals::verbosity > 0) {
if (nIters_ < maxIter_) {
cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl;
std::cout << "Belief propagation converged in " ;
std::cout << nIters_ << " iterations" << std::endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
std::cout << "The maximum number of iterations was hit," ;
std::cout << " terminating..." ;
std::cout << std::endl;
}
cout << endl;
std::cout << std::endl;
}
runned_ = true;
}
@ -200,7 +296,7 @@ BeliefProp::runSolver (void)
void
BeliefProp::createLinks (void)
BeliefProp::createLinks()
{
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
@ -214,7 +310,7 @@ BeliefProp::createLinks (void)
void
BeliefProp::maxResidualSchedule (void)
BeliefProp::maxResidualSchedule()
{
if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) {
@ -227,11 +323,13 @@ BeliefProp::maxResidualSchedule (void)
for (size_t c = 0; c < links_.size(); c++) {
if (Globals::verbosity > 1) {
cout << "current residuals:" << endl;
std::cout << "current residuals:" << std::endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->residual() << endl;
std::cout << " " << std::setw (30) << std::left;
std::cout << (*it)->toString();
std::cout << "residual = " << (*it)->residual();
std::cout << std::endl;
}
}
@ -249,7 +347,7 @@ BeliefProp::maxResidualSchedule (void)
const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->facNode()) {
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
const BpLinks& links = getLinks (factorNeighbors[i]);
for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) {
calculateMessage (links[j]);
@ -273,7 +371,7 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
{
FacNode* src = link->facNode();
const VarNode* dst = link->varNode();
const BpLinks& links = ninf(src)->getLinks();
const BpLinks& links = getLinks (src);
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned reps = 1;
@ -282,14 +380,14 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
if (Constants::showBpCalcs) {
std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ;
}
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
if (Constants::showBpCalcs) {
std::cout << std::endl;
}
}
reps *= links[i]->varNode()->range();
@ -297,14 +395,14 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
} else {
for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
if (Constants::showBpCalcs) {
std::cout << " message from " << links[i]->varNode()->label();
std::cout << ": " ;
}
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
if (Constants::showBpCalcs) {
std::cout << std::endl;
}
}
reps *= links[i]->varNode()->range();
@ -313,27 +411,28 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
result.multiply (src->factor());
if (Constants::SHOW_BP_CALCS) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
if (Constants::showBpCalcs) {
std::cout << " message product: " << msgProduct << std::endl;
std::cout << " original factor: " << src->factor().params();
std::cout << std::endl;
std::cout << " factor product: " << result.params() << std::endl;
}
result.sumOutAllExcept (dst->varId());
if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl;
if (Constants::showBpCalcs) {
std::cout << " marginalized: " << result.params() << std::endl;
}
link->nextMessage() = result.params();
LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->message() << endl;
cout << " next msg: " << link->nextMessage() << endl;
if (Constants::showBpCalcs) {
std::cout << " curr msg: " << link->message() << std::endl;
std::cout << " next msg: " << link->nextMessage() << std::endl;
}
}
Params
BeliefProp::getVarToFactorMsg (const BpLink* link) const
BeliefProp::getVarToFactorMsg (const BpLink* link)
{
const VarNode* src = link->varNode();
Params msg;
@ -343,18 +442,18 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
} else {
msg.resize (src->range(), LogAware::one());
}
if (Constants::SHOW_BP_CALCS) {
cout << msg;
if (Constants::showBpCalcs) {
std::cout << msg;
}
BpLinks::const_iterator it;
const BpLinks& links = ninf (src)->getLinks();
const BpLinks& links = getLinks (src);
if (Globals::logDomain) {
for (it = links.begin(); it != links.end(); ++it) {
if (*it != link) {
msg += (*it)->message();
}
if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->message();
if (Constants::showBpCalcs) {
std::cout << " x " << (*it)->message();
}
}
} else {
@ -362,13 +461,13 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
if (*it != link) {
msg *= (*it)->message();
}
if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->message();
if (Constants::showBpCalcs) {
std::cout << " x " << (*it)->message();
}
}
}
if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg;
if (Constants::showBpCalcs) {
std::cout << " = " << msg;
}
return msg;
}
@ -379,37 +478,37 @@ Params
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
{
return GroundSolver::getJointByConditioning (
GroundSolverType::BP, fg, jointVarIds);
GroundSolverType::bpSolver, fg, jointVarIds);
}
void
BeliefProp::initializeSolver (void)
BeliefProp::initializeSolver()
{
const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size());
varsLinks_.reserve (varNodes.size());
for (size_t i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
varsLinks_.push_back (BpLinks());
}
const FacNodes& facNodes = fg.facNodes();
facsI_.reserve (facNodes.size());
facsLinks_.reserve (facNodes.size());
for (size_t i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
facsLinks_.push_back (BpLinks());
}
createLinks();
for (size_t i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->facNode();
VarNode* dst = links_[i]->varNode();
ninf (dst)->addBpLink (links_[i]);
ninf (src)->addBpLink (links_[i]);
getLinks (dst).push_back (links_[i]);
getLinks (src).push_back (links_[i]);
}
}
bool
BeliefProp::converged (void)
BeliefProp::converged()
{
if (links_.empty()) {
return true;
@ -418,16 +517,16 @@ BeliefProp::converged (void)
return false;
}
if (Globals::verbosity > 2) {
cout << endl;
std::cout << std::endl;
}
if (nIters_ == 1) {
if (Globals::verbosity > 1) {
cout << "no residuals" << endl << endl;
std::cout << "no residuals" << std::endl << std::endl;
}
return false;
}
bool converged = true;
if (schedule_ == MsgSchedule::MAX_RESIDUAL) {
if (schedule_ == MsgSchedule::maxResidualSch) {
double maxResidual = (*(sortedOrder_.begin()))->residual();
if (maxResidual > accuracy_) {
converged = false;
@ -438,7 +537,8 @@ BeliefProp::converged (void)
for (size_t i = 0; i < links_.size(); i++) {
double residual = links_[i]->residual();
if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl;
std::cout << links_[i]->toString() + " residual = " << residual;
std::cout << std::endl;
}
if (residual > accuracy_) {
converged = false;
@ -448,7 +548,7 @@ BeliefProp::converged (void)
}
}
if (Globals::verbosity > 1) {
cout << endl;
std::cout << std::endl;
}
}
return converged;
@ -457,8 +557,10 @@ BeliefProp::converged (void)
void
BeliefProp::printLinkInformation (void) const
BeliefProp::printLinkInformation() const
{
using std::cout;
using std::endl;
for (size_t i = 0; i < links_.size(); i++) {
BpLink* l = links_[i];
cout << l->toString() << ":" << endl;
@ -470,3 +572,5 @@ BeliefProp::printLinkInformation (void) const
}
}
} // namespace Horus

View File

@ -1,111 +1,35 @@
#ifndef HORUS_BELIEFPROP_H
#define HORUS_BELIEFPROP_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
#define YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
#include <set>
#include <vector>
#include <sstream>
#include <set>
#include <string>
#include "GroundSolver.h"
#include "FactorGraph.h"
using namespace std;
enum MsgSchedule {
SEQ_FIXED,
SEQ_RANDOM,
PARALLEL,
MAX_RESIDUAL
};
class BpLink
{
public:
BpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0.0;
}
virtual ~BpLink (void) { };
FacNode* facNode (void) const { return fac_; }
VarNode* varNode (void) const { return var_; }
const Params& message (void) const { return *currMsg_; }
Params& nextMessage (void) { return *nextMsg_; }
double residual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
void updateResidual (void)
{
residual_ = LogAware::getMaxNorm (v1_, v2_);
}
virtual void updateMessage (void)
{
swap (currMsg_, nextMsg_);
}
string toString (void) const
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
protected:
FacNode* fac_;
VarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
double residual_;
namespace Horus {
class BeliefProp : public GroundSolver {
private:
DISALLOW_COPY_AND_ASSIGN (BpLink);
};
class SPNodeInfo;
typedef vector<BpLink*> BpLinks;
class SPNodeInfo
{
public:
SPNodeInfo (void) { }
void addBpLink (BpLink* link) { links_.push_back (link); }
const BpLinks& getLinks (void) { return links_; }
private:
BpLinks links_;
DISALLOW_COPY_AND_ASSIGN (SPNodeInfo);
};
enum class MsgSchedule {
seqFixedSch,
seqRandomSch,
parallelSch,
maxResidualSch
};
class BeliefProp : public GroundSolver
{
public:
BeliefProp (const FactorGraph&);
virtual ~BeliefProp (void);
virtual ~BeliefProp();
Params solveQuery (VarIds);
virtual void printSolverFlags (void) const;
virtual void printSolverFlags() const;
virtual Params getPosterioriOf (VarId);
@ -113,105 +37,128 @@ class BeliefProp : public GroundSolver
Params getFactorJoint (FacNode* fn, const VarIds&);
static double accuracy (void) { return accuracy_; }
static double accuracy() { return accuracy_; }
static void setAccuracy (double acc) { accuracy_ = acc; }
static unsigned maxIterations (void) { return maxIter_; }
static unsigned maxIterations() { return maxIter_; }
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
static MsgSchedule msgSchedule (void) { return schedule_; }
static MsgSchedule msgSchedule() { return schedule_; }
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
protected:
SPNodeInfo* ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];
}
class BpLink {
public:
BpLink (FacNode* fn, VarNode* vn);
SPNodeInfo* ninf (const FacNode* fac) const
{
return facsI_[fac->getIndex()];
}
virtual ~BpLink() { };
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
{
if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
FacNode* facNode() const { return fac_; }
void calculateMessage (BpLink* link, bool calcResidual = true)
{
if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl;
}
calcFactorToVarMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
VarNode* varNode() const { return var_; }
void updateMessage (BpLink* link)
{
link->updateMessage();
if (Globals::verbosity > 2) {
cout << "updating " << link->toString() << endl;
}
}
const Params& message() const { return *currMsg_; }
struct CompareResidual
{
inline bool operator() (const BpLink* link1, const BpLink* link2)
{
return link1->residual() > link2->residual();
}
Params& nextMessage() { return *nextMsg_; }
double residual() const { return residual_; }
void clearResidual();
void updateResidual();
virtual void updateMessage();
std::string toString() const;
protected:
FacNode* fac_;
VarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
double residual_;
private:
DISALLOW_COPY_AND_ASSIGN (BpLink);
};
void runSolver (void);
struct CmpResidual {
bool operator() (const BpLink* l1, const BpLink* l2) {
return l1->residual() > l2->residual();
}};
virtual void createLinks (void);
typedef std::vector<BeliefProp::BpLink*> BpLinks;
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
virtual void maxResidualSchedule (void);
BpLinks& getLinks (const VarNode* var);
BpLinks& getLinks (const FacNode* fac);
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true);
void calculateMessage (BpLink* link, bool calcResidual = true);
void updateMessage (BpLink* link);
void runSolver();
virtual void createLinks();
virtual void maxResidualSchedule();
virtual void calcFactorToVarMsg (BpLink*);
virtual Params getVarToFactorMsg (const BpLink*) const;
virtual Params getVarToFactorMsg (const BpLink*);
virtual Params getJointByConditioning (const VarIds&) const;
BpLinks links_;
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
bool runned_;
BpLinks links_;
unsigned nIters_;
bool runned_;
SortedOrder sortedOrder_;
BpLinkMap linkMap_;
typedef multiset<BpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
BpLinkMap linkMap_;
static double accuracy_;
static unsigned maxIter_;
static MsgSchedule schedule_;
static double accuracy_;
private:
void initializeSolver (void);
void initializeSolver();
bool converged (void);
bool converged();
virtual void printLinkInformation (void) const;
virtual void printLinkInformation() const;
std::vector<BpLinks> varsLinks_;
std::vector<BpLinks> facsLinks_;
static unsigned maxIter_;
static MsgSchedule schedule_;
DISALLOW_COPY_AND_ASSIGN (BeliefProp);
};
#endif // HORUS_BELIEFPROP_H
inline BeliefProp::BpLinks&
BeliefProp::getLinks (const VarNode* var)
{
return varsLinks_[var->getIndex()];
}
inline BeliefProp::BpLinks&
BeliefProp::getLinks (const FacNode* fac)
{
return facsLinks_[fac->getIndex()];
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_

View File

@ -1,11 +1,88 @@
#include <queue>
#include <iostream>
#include <ostream>
#include <fstream>
#include "ConstraintTree.h"
#include "Util.h"
namespace Horus {
class CTNode {
public:
CTNode (const CTNode& n, const CTChilds& chs = CTChilds())
: symbol_(n.symbol()), childs_(chs), level_(n.level()) { }
CTNode (Symbol s, unsigned l, const CTChilds& chs = CTChilds())
: symbol_(s), childs_(chs), level_(l) { }
unsigned level() const { return level_; }
void setLevel (unsigned level) { level_ = level; }
Symbol symbol() const { return symbol_; }
void setSymbol (Symbol s) { symbol_ = s; }
CTChilds& childs() { return childs_; }
const CTChilds& childs() const { return childs_; }
size_t nrChilds() const { return childs_.size(); }
bool isRoot() const { return level_ == 0; }
bool isLeaf() const { return childs_.empty(); }
CTChilds::iterator findSymbol (Symbol symb);
void mergeSubtree (CTNode*, bool = true);
void removeChild (CTNode*);
void removeChilds();
void removeAndDeleteChild (CTNode*);
void removeAndDeleteAllChilds();
SymbolSet childSymbols() const;
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
private:
void updateChildLevels (CTNode*, unsigned);
Symbol symbol_;
CTChilds childs_;
unsigned level_;
DISALLOW_ASSIGN (CTNode);
};
inline CTChilds::iterator
CTNode::findSymbol (Symbol symb)
{
CTNode tmp (symb, 0);
return childs_.find (&tmp);
}
inline bool
CmpSymbol::operator() (const CTNode* n1, const CTNode* n2) const
{
return n1->symbol() < n2->symbol();
}
void
CTNode::mergeSubtree (CTNode* n, bool updateLevels)
{
@ -38,7 +115,7 @@ CTNode::removeChild (CTNode* child)
void
CTNode::removeChilds (void)
CTNode::removeChilds()
{
childs_.clear();
}
@ -55,7 +132,7 @@ CTNode::removeAndDeleteChild (CTNode* child)
void
CTNode::removeAndDeleteAllChilds (void)
CTNode::removeAndDeleteAllChilds()
{
for (CTChilds::const_iterator chIt = childs_.begin();
chIt != childs_.end(); ++ chIt) {
@ -67,7 +144,7 @@ CTNode::removeAndDeleteAllChilds (void)
SymbolSet
CTNode::childSymbols (void) const
CTNode::childSymbols() const
{
SymbolSet symbols;
for (CTChilds::const_iterator chIt = childs_.begin();
@ -106,14 +183,14 @@ CTNode::copySubtree (const CTNode* root1)
return new CTNode (*root1);
}
CTNode* root2 = new CTNode (*root1);
typedef pair<const CTNode*, CTNode*> StackPair;
vector<StackPair> stack = { StackPair (root1, root2) };
typedef std::pair<const CTNode*, CTNode*> StackPair;
std::vector<StackPair> stack = { StackPair (root1, root2) };
while (stack.empty() == false) {
const CTNode* n1 = stack.back().first;
CTNode* n2 = stack.back().second;
stack.pop_back();
// cout << "n2 childs: " << n2->childs();
// cout << "n1 childs: " << n1->childs();
// std::cout << "n2 childs: " << n2->childs();
// std::cout << "n1 childs: " << n1->childs();
n2->childs().reserve (n1->nrChilds());
stack.reserve (n1->nrChilds());
for (CTChilds::const_iterator chIt = n1->childs().begin();
@ -144,7 +221,8 @@ CTNode::deleteSubtree (CTNode* n)
ostream& operator<< (ostream &out, const CTNode& n)
std::ostream&
operator<< (std::ostream& out, const CTNode& n)
{
out << "(" << n.level() << ") " ;
out << n.symbol();
@ -187,7 +265,8 @@ ConstraintTree::ConstraintTree (
ConstraintTree::ConstraintTree (vector<vector<string>> names)
ConstraintTree::ConstraintTree (
std::vector<std::vector<std::string>> names)
{
assert (names.empty() == false);
assert (names.front().empty() == false);
@ -216,13 +295,33 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
ConstraintTree::~ConstraintTree (void)
ConstraintTree::ConstraintTree (
const CTChilds& rootChilds,
const LogVars& logVars)
: root_(new CTNode (Symbol (0), unsigned (0), rootChilds)),
logVars_(logVars),
logVarSet_(logVars)
{
}
ConstraintTree::~ConstraintTree()
{
CTNode::deleteSubtree (root_);
}
bool
ConstraintTree::empty() const
{
return root_->childs().empty();
}
void
ConstraintTree::addTuple (const Tuple& tuple)
{
@ -448,7 +547,7 @@ ConstraintTree::ConstraintTree::isSingleton (LogVar X)
LogVarSet
ConstraintTree::singletons (void)
ConstraintTree::singletons()
{
LogVarSet singletons;
for (size_t i = 0; i < logVars_.size(); i++) {
@ -491,7 +590,7 @@ ConstraintTree::tupleSet (const LogVars& originalLvs)
getTuples (root_, Tuples(), stopLevel, tuples, CTNodes() = {});
if (originalLvs.size() != uniqueLvs.size()) {
vector<size_t> indexes;
std::vector<size_t> indexes;
indexes.reserve (originalLvs.size());
for (size_t i = 0; i < originalLvs.size(); i++) {
indexes.push_back (Util::indexOf (uniqueLvs, originalLvs[i]));
@ -519,21 +618,22 @@ ConstraintTree::exportToGraphViz (
const char* fileName,
bool showLogVars) const
{
ofstream out (fileName);
std::ofstream out (fileName);
if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
}
out << "digraph {" << endl;
out << "digraph {" << std::endl;
ConstraintTree copy (*this);
copy.moveToTop (copy.logVarSet_.elements());
CTNodes nodes = getNodesBelow (copy.root_);
out << "\"" << copy.root_ << "\"" << " [label=\"R\"]" << endl;
out << "\"" << copy.root_ << "\"" << " [label=\"R\"]" << std::endl;
for (CTNodes::const_iterator it = ++ nodes.begin();
it != nodes.end(); ++ it) {
out << "\"" << *it << "\"";
out << " [label=\"" << **it << "\"]" ;
out << endl;
out << std::endl;
}
for (CTNodes::const_iterator it = nodes.begin();
it != nodes.end(); ++ it) {
@ -542,24 +642,24 @@ ConstraintTree::exportToGraphViz (
chIt != childs.end(); ++ chIt) {
out << "\"" << *it << "\"" ;
out << " -> " ;
out << "\"" << *chIt << "\"" << endl ;
out << "\"" << *chIt << "\"" << std::endl ;
}
}
if (showLogVars) {
out << "Root [label=\"\", shape=plaintext]" << endl;
out << "Root [label=\"\", shape=plaintext]" << std::endl;
for (size_t i = 0; i < copy.logVars_.size(); i++) {
out << copy.logVars_[i] << " [label=" ;
out << copy.logVars_[i] << ", " ;
out << "shape=plaintext, fontsize=14]" << endl;
out << "shape=plaintext, fontsize=14]" << std::endl;
}
out << "Root -> " << copy.logVars_[0];
out << " [style=invis]" << endl;
out << " [style=invis]" << std::endl;
for (size_t i = 0; i < copy.logVars_.size() - 1; i++) {
out << copy.logVars_[i] << " -> " << copy.logVars_[i + 1];
out << " [style=invis]" << endl;
out << " [style=invis]" << std::endl;
}
}
out << "}" << endl;
out << "}" <<std::endl;
out.close();
}
@ -690,9 +790,9 @@ ConstraintTree::split (
split (root_, ct->root(), commChilds, exclChilds, stopLevel);
ConstraintTree* commCt = new ConstraintTree (commChilds, logVars_);
ConstraintTree* exclCt = new ConstraintTree (exclChilds, logVars_);
// cout << commCt->tupleSet() << " + " ;
// cout << exclCt->tupleSet() << " = " ;
// cout << tupleSet() << endl;
// std::cout << commCt->tupleSet() << " + " ;
// std::cout << exclCt->tupleSet() << " = " ;
// std::cout << tupleSet() << std::endl;
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
return {commCt, exclCt};
@ -710,20 +810,20 @@ ConstraintTree::countNormalize (const LogVarSet& Ys)
}
moveToTop (Zs.elements());
ConstraintTrees cts;
unordered_map<unsigned, ConstraintTree*> countMap;
std::unordered_map<unsigned, ConstraintTree*> countMap;
unsigned stopLevel = getLevel (Zs.back());
const CTChilds& childs = root_->childs();
for (CTChilds::const_iterator chIt = childs.begin();
chIt != childs.end(); ++ chIt) {
const vector<pair<CTNode*, unsigned>>& res =
const std::vector<std::pair<CTNode*, unsigned>>& res =
countNormalize (*chIt, stopLevel);
for (size_t j = 0; j < res.size(); j++) {
unordered_map<unsigned, ConstraintTree*>::iterator it
std::unordered_map<unsigned, ConstraintTree*>::iterator it
= countMap.find (res[j].second);
if (it == countMap.end()) {
ConstraintTree* newCt = new ConstraintTree (logVars_);
it = countMap.insert (make_pair (res[j].second, newCt)).first;
it = countMap.insert (std::make_pair (res[j].second, newCt)).first;
cts.push_back (newCt);
}
it->second->root_->mergeSubtree (res[j].first);
@ -743,31 +843,31 @@ ConstraintTree::jointCountNormalize (
LogVar X_new2)
{
unsigned N = getConditionalCount (X);
// cout << "My tuples: " << tupleSet() << endl;
// cout << "CommCt tuples: " << commCt->tupleSet() << endl;
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
// cout << "Counted Lv: " << X << endl;
// cout << "X_new1: " << X_new1 << endl;
// cout << "X_new2: " << X_new2 << endl;
// cout << "Original N: " << N << endl;
// cout << endl;
// std::cout << "My tuples: " << tupleSet() << std::endl;
// std::cout << "CommCt tuples: " << commCt->tupleSet() << std::endl;
// std::cout << "ExclCt tuples: " << exclCt->tupleSet() << std::endl;
// std::cout << "Counted Lv: " << X << std::endl;
// std::cout << "X_new1: " << X_new1 << std::endl;
// std::cout << "X_new2: " << X_new2 << std::endl;
// std::cout << "Original N: " << N << std::endl;
// std::cout << endl;
ConstraintTrees normCts1 = commCt->countNormalize (X);
vector<unsigned> counts1 (normCts1.size());
std::vector<unsigned> counts1 (normCts1.size());
for (size_t i = 0; i < normCts1.size(); i++) {
counts1[i] = normCts1[i]->getConditionalCount (X);
// cout << "normCts1[" << i << "] #" << counts1[i] ;
// cout << " " << normCts1[i]->tupleSet() << endl;
// std::cout << "normCts1[" << i << "] #" << counts1[i] ;
// std::cout << " " << normCts1[i]->tupleSet() << std::endl;
}
ConstraintTrees normCts2 = exclCt->countNormalize (X);
vector<unsigned> counts2 (normCts2.size());
std::vector<unsigned> counts2 (normCts2.size());
for (size_t i = 0; i < normCts2.size(); i++) {
counts2[i] = normCts2[i]->getConditionalCount (X);
// cout << "normCts2[" << i << "] #" << counts2[i] ;
// cout << " " << normCts2[i]->tupleSet() << endl;
// std::cout << "normCts2[" << i << "] #" << counts2[i] ;
// std::cout << " " << normCts2[i]->tupleSet() << std::endl;
}
// cout << endl;
// std::cout << std::endl;
ConstraintTree* excl1 = 0;
for (size_t i = 0; i < normCts1.size(); i++) {
@ -775,7 +875,7 @@ ConstraintTree::jointCountNormalize (
excl1 = normCts1[i];
normCts1.erase (normCts1.begin() + i);
counts1.erase (counts1.begin() + i);
// cout << "joint-count(" << N << ",0)" << endl;
// std::cout << "joint-count(" << N << ",0)" << std::endl;
break;
}
}
@ -786,7 +886,7 @@ ConstraintTree::jointCountNormalize (
excl2 = normCts2[i];
normCts2.erase (normCts2.begin() + i);
counts2.erase (counts2.begin() + i);
// cout << "joint-count(0," << N << ")" << endl;
// std::cout << "joint-count(0," << N << ")" << std::endl;
break;
}
}
@ -794,8 +894,8 @@ ConstraintTree::jointCountNormalize (
for (size_t i = 0; i < normCts1.size(); i++) {
unsigned j;
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
// cout << "joint-count(" << counts1[i] ;
// cout << "," << counts2[j] << ")" << endl;
// std::cout << "joint-count(" << counts1[i] ;
// std::cout << "," << counts2[j] << ")" << std::endl;
const CTChilds& childs = normCts2[j]->root_->childs();
for (CTChilds::const_iterator chIt = childs.begin();
chIt != childs.end(); ++ chIt) {
@ -930,7 +1030,7 @@ CTNodes
ConstraintTree::getNodesBelow (CTNode* fromHere) const
{
CTNodes nodes;
queue<CTNode*> queue;
std::queue<CTNode*> queue;
queue.push (fromHere);
while (queue.empty() == false) {
CTNode* node = queue.front();
@ -1016,7 +1116,7 @@ ConstraintTree::swapLogVar (LogVar X)
{
size_t pos = Util::indexOf (logVars_, X);
assert (pos != logVars_.size());
const CTNodes& nodes = getNodesAtLevel (pos);
CTNodes nodes = getNodesAtLevel (pos);
for (CTNodes::const_iterator nodeIt = nodes.begin();
nodeIt != nodes.end(); ++ nodeIt) {
CTChilds childsCopy = (*nodeIt)->childs();
@ -1098,7 +1198,7 @@ ConstraintTree::getTuples (
unsigned
ConstraintTree::size (void) const
ConstraintTree::size() const
{
return countTuples (root_);
}
@ -1114,26 +1214,26 @@ ConstraintTree::nrSymbols (LogVar X)
vector<pair<CTNode*, unsigned>>
std::vector<std::pair<CTNode*, unsigned>>
ConstraintTree::countNormalize (
const CTNode* n,
unsigned stopLevel)
{
if (n->level() == stopLevel) {
return vector<pair<CTNode*, unsigned>>() = {
make_pair (CTNode::copySubtree (n), countTuples (n))
return std::vector<std::pair<CTNode*, unsigned>>() = {
std::make_pair (CTNode::copySubtree (n), countTuples (n))
};
}
vector<pair<CTNode*, unsigned>> res;
std::vector<std::pair<CTNode*, unsigned>> res;
const CTChilds& childs = n->childs();
for (CTChilds::const_iterator chIt = childs.begin();
chIt != childs.end(); ++ chIt) {
const vector<pair<CTNode*, unsigned>>& lowerRes =
const std::vector<std::pair<CTNode*, unsigned>>& lowerRes =
countNormalize (*chIt, stopLevel);
for (size_t j = 0; j < lowerRes.size(); j++) {
CTNode* newNode = new CTNode (*n);
newNode->mergeSubtree (lowerRes[j].first);
res.push_back (make_pair (newNode, lowerRes[j].second));
res.push_back (std::make_pair (newNode, lowerRes[j].second));
}
}
return res;
@ -1172,3 +1272,5 @@ ConstraintTree::split (
}
}
} // namespace Horus

View File

@ -1,104 +1,35 @@
#ifndef HORUS_CONSTRAINTTREE_H
#define HORUS_CONSTRAINTTREE_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
#define YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
#include <cassert>
#include <algorithm>
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include <string>
#include "TinySet.h"
#include "LiftedUtils.h"
using namespace std;
namespace Horus {
class CTNode;
typedef vector<CTNode*> CTNodes;
class ConstraintTree;
typedef vector<ConstraintTree*> ConstraintTrees;
class CTNode
{
public:
struct CompareSymbol
{
bool operator() (const CTNode* n1, const CTNode* n2) const
{
return n1->symbol() < n2->symbol();
}
};
typedef std::vector<CTNode*> CTNodes;
typedef std::vector<ConstraintTree*> ConstraintTrees;
private:
typedef TinySet<CTNode*, CompareSymbol> CTChilds_;
public:
CTNode (const CTNode& n, const CTChilds_& chs = CTChilds_())
: symbol_(n.symbol()), childs_(chs), level_(n.level()) { }
CTNode (Symbol s, unsigned l, const CTChilds_& chs = CTChilds_())
: symbol_(s), childs_(chs), level_(l) { }
unsigned level (void) const { return level_; }
void setLevel (unsigned level) { level_ = level; }
Symbol symbol (void) const { return symbol_; }
void setSymbol (const Symbol s) { symbol_ = s; }
CTChilds_& childs (void) { return childs_; }
const CTChilds_& childs (void) const { return childs_; }
size_t nrChilds (void) const { return childs_.size(); }
bool isRoot (void) const { return level_ == 0; }
bool isLeaf (void) const { return childs_.empty(); }
CTChilds_::iterator findSymbol (Symbol symb)
{
CTNode tmp (symb, 0);
return childs_.find (&tmp);
}
void mergeSubtree (CTNode*, bool = true);
void removeChild (CTNode*);
void removeChilds (void);
void removeAndDeleteChild (CTNode*);
void removeAndDeleteAllChilds (void);
SymbolSet childSymbols (void) const;
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
private:
void updateChildLevels (CTNode*, unsigned);
Symbol symbol_;
CTChilds_ childs_;
unsigned level_;
DISALLOW_ASSIGN (CTNode);
struct CmpSymbol {
bool operator() (const CTNode* n1, const CTNode* n2) const;
};
ostream& operator<< (ostream &out, const CTNode&);
typedef TinySet<CTNode*, CmpSymbol> CTChilds;
typedef TinySet<CTNode*, CTNode::CompareSymbol> CTChilds;
class ConstraintTree
{
class ConstraintTree {
public:
ConstraintTree (unsigned);
@ -106,38 +37,23 @@ class ConstraintTree
ConstraintTree (const LogVars&, const Tuples&);
ConstraintTree (vector<vector<string>> names);
ConstraintTree (std::vector<std::vector<std::string>> names);
ConstraintTree (const ConstraintTree&);
ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars)
: root_(new CTNode (0, 0, rootChilds)),
logVars_(logVars),
logVarSet_(logVars) { }
ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars);
~ConstraintTree (void);
~ConstraintTree();
CTNode* root (void) const { return root_; }
CTNode* root() const { return root_; }
bool empty (void) const { return root_->childs().empty(); }
bool empty() const;
const LogVars& logVars (void) const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVars_;
}
const LogVars& logVars() const;
const LogVarSet& logVarSet (void) const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVarSet_;
}
const LogVarSet& logVarSet() const;
size_t nrLogVars (void) const
{
return logVars_.size();
assert (LogVarSet (logVars_) == logVarSet_);
}
size_t nrLogVars() const;
void addTuple (const Tuple&);
@ -163,13 +79,13 @@ class ConstraintTree
bool isSingleton (LogVar);
LogVarSet singletons (void);
LogVarSet singletons();
TupleSet tupleSet (unsigned = 0) const;
TupleSet tupleSet (const LogVars&);
unsigned size (void) const;
unsigned size() const;
unsigned nrSymbols (LogVar);
@ -218,11 +134,10 @@ class ConstraintTree
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
vector<std::pair<CTNode*, unsigned>> countNormalize (
std::vector<std::pair<CTNode*, unsigned>> countNormalize (
const CTNode*, unsigned);
static void split (
CTNode*, CTNode*, CTChilds&, CTChilds&, unsigned);
static void split (CTNode*, CTNode*, CTChilds&, CTChilds&, unsigned);
CTNode* root_;
LogVars logVars_;
@ -230,5 +145,33 @@ class ConstraintTree
};
#endif // HORUS_CONSTRAINTTREE_H
inline const LogVars&
ConstraintTree::logVars() const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVars_;
}
inline const LogVarSet&
ConstraintTree::logVarSet() const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVarSet_;
}
inline size_t
ConstraintTree::nrLogVars() const
{
assert (LogVarSet (logVars_) == logVarSet_);
return logVars_.size();
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_

View File

@ -1,7 +1,62 @@
#include <cassert>
#include <iostream>
#include <sstream>
#include "CountingBp.h"
#include "WeightedBp.h"
namespace Horus {
class VarCluster {
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first() const { return members_.front(); }
const VarNodes& members() const { return members_; }
VarNode* representative() const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
DISALLOW_COPY_AND_ASSIGN (VarCluster);
};
class FacCluster {
private:
typedef std::vector<VarCluster*> VarClusters;
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first() const { return members_.front(); }
const FacNodes& members() const { return members_; }
FacNode* representative() const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
VarClusters& varClusters() { return varClusters_; }
FacNodes members_;
FacNode* repr_;
VarClusters varClusters_;
DISALLOW_COPY_AND_ASSIGN (FacCluster);
};
bool CountingBp::fif_ = true;
@ -17,7 +72,7 @@ CountingBp::CountingBp (const FactorGraph& fg)
CountingBp::~CountingBp (void)
CountingBp::~CountingBp()
{
delete solver_;
delete compressedFg_;
@ -32,23 +87,24 @@ CountingBp::~CountingBp (void)
void
CountingBp::printSolverFlags (void) const
CountingBp::printSolverFlags() const
{
stringstream ss;
std::stringstream ss;
ss << "counting bp [" ;
ss << "bp_msg_schedule=" ;
typedef WeightedBp::MsgSchedule MsgSchedule;
switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
}
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",fif=" << Util::toString (CountingBp::fif_);
ss << "]" ;
cout << ss.str() << endl;
std::cout << ss.str() << std::endl;
}
@ -69,11 +125,10 @@ CountingBp::solveQuery (VarIds queryVids)
idx = i;
break;
}
cout << endl;
}
if (idx == facNodes.size()) {
res = GroundSolver::getJointByConditioning (
GroundSolverType::CBP, fg, queryVids);
GroundSolverType::CbpSolver, fg, queryVids);
} else {
VarIds reprArgs;
for (size_t i = 0; i < queryVids.size(); i++) {
@ -124,7 +179,7 @@ CountingBp::findIdenticalFactors()
void
CountingBp::setInitialColors (void)
CountingBp::setInitialColors()
{
varColors_.resize (fg.nrVarNodes());
facColors_.resize (fg.nrFacNodes());
@ -135,7 +190,7 @@ CountingBp::setInitialColors (void)
unsigned range = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (range);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
it = colorMap.insert (std::make_pair (
range, Colors (range + 1, -1))).first;
}
unsigned idx = varNodes[i]->hasEvidence()
@ -154,7 +209,8 @@ CountingBp::setInitialColors (void)
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getNewColor())).first;
it = distColors.insert (std::make_pair (
distId, getNewColor())).first;
}
setColor (facNodes[i], it->second);
}
@ -163,7 +219,7 @@ CountingBp::setInitialColors (void)
void
CountingBp::createGroups (void)
CountingBp::createGroups()
{
VarSignMap varGroups;
FacSignMap facGroups;
@ -179,10 +235,11 @@ CountingBp::createGroups (void)
size_t prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (size_t i = 0; i < varNodes.size(); i++) {
const VarSignature& signature = getSignature (varNodes[i]);
VarSignature signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first;
it = varGroups.insert (std::make_pair (
signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
}
@ -199,10 +256,11 @@ CountingBp::createGroups (void)
facGroups.clear();
// set a new color to the factors with the same signature
for (size_t i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]);
FacSignature signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
it = facGroups.insert (std::make_pair (
signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
@ -235,7 +293,8 @@ CountingBp::createClusters (
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (size_t i = 0; i < groupVars.size(); i++) {
varClusterMap_.insert (make_pair (groupVars[i]->varId(), vc));
varClusterMap_.insert (std::make_pair (
groupVars[i]->varId(), vc));
}
varClusters_.push_back (vc);
}
@ -257,29 +316,29 @@ CountingBp::createClusters (
VarSignature
CountingBp::VarSignature
CountingBp::getSignature (const VarNode* varNode)
{
const FacNodes& neighs = varNode->neighbors();
VarSignature sign;
const FacNodes& neighs = varNode->neighbors();
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (make_pair (
sign.push_back (std::make_pair (
getColor (neighs[i]),
neighs[i]->factor().indexOf (varNode->varId())));
}
std::sort (sign.begin(), sign.end());
sign.push_back (make_pair (getColor (varNode), 0));
sign.push_back (std::make_pair (getColor (varNode), 0));
return sign;
}
FacSignature
CountingBp::FacSignature
CountingBp::getSignature (const FacNode* facNode)
{
const VarNodes& neighs = facNode->neighbors();
FacSignature sign;
const VarNodes& neighs = facNode->neighbors();
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (getColor (neighs[i]));
@ -314,7 +373,7 @@ CountingBp::getRepresentative (FacNode* fn)
FactorGraph*
CountingBp::getCompressedFactorGraph (void)
CountingBp::getCompressedFactorGraph()
{
FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) {
@ -342,10 +401,10 @@ CountingBp::getCompressedFactorGraph (void)
vector<vector<unsigned>>
CountingBp::getWeights (void) const
std::vector<std::vector<unsigned>>
CountingBp::getWeights() const
{
vector<vector<unsigned>> weights;
std::vector<std::vector<unsigned>> weights;
weights.reserve (facClusters_.size());
for (size_t i = 0; i < facClusters_.size(); i++) {
const VarClusters& neighs = facClusters_[i]->varClusters();
@ -390,32 +449,34 @@ CountingBp::printGroups (
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
std::cout << "variable groups:" << std::endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << count << ": " ;
std::cout << count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->label() << " " ;
std::cout << groupMembers[i]->label() << " " ;
}
count ++;
cout << endl;
std::cout << std::endl;
}
}
count = 1;
cout << endl << "factor groups:" << endl;
std::cout << std::endl << "factor groups:" << std::endl;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
std::cout << ++count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
std::cout << groupMembers[i]->getLabel() << " " ;
}
count ++;
cout << endl;
std::cout << std::endl;
}
}
}
} // namespace Horus

View File

@ -1,155 +1,99 @@
#ifndef HORUS_COUNTINGBP_H
#define HORUS_COUNTINGBP_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
#define YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
#include <vector>
#include <unordered_map>
#include "GroundSolver.h"
#include "FactorGraph.h"
#include "Horus.h"
namespace Horus {
class VarCluster;
class FacCluster;
class WeightedBp;
typedef long Color;
typedef vector<Color> Colors;
typedef vector<std::pair<Color,unsigned>> VarSignature;
typedef vector<Color> FacSignature;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<unsigned, Colors> VarColorMap;
typedef unordered_map<VarSignature, VarNodes> VarSignMap;
typedef unordered_map<FacSignature, FacNodes> FacSignMap;
typedef unordered_map<VarId, VarCluster*> VarClusterMap;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
template <class T>
inline size_t hash_combine (size_t seed, const T& v)
template <class T> inline size_t
hash_combine (size_t seed, const T& v)
{
return seed ^ (hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
return seed ^ (std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}
} // namespace Horus
namespace std {
template <typename T1, typename T2> struct hash<std::pair<T1,T2>>
{
size_t operator() (const std::pair<T1,T2>& p) const
{
return hash_combine (std::hash<T1>()(p.first), p.second);
}
};
template <typename T> struct hash<std::vector<T>>
{
size_t operator() (const std::vector<T>& vec) const
{
size_t h = 0;
typename vector<T>::const_iterator first = vec.begin();
typename vector<T>::const_iterator last = vec.end();
for (; first != last; ++first) {
h = hash_combine (h, *first);
}
return h;
}
};
}
template <typename T1, typename T2> struct hash<std::pair<T1,T2>> {
size_t operator() (const std::pair<T1,T2>& p) const {
return Horus::hash_combine (std::hash<T1>()(p.first), p.second);
}};
class VarCluster
template <typename T> struct hash<std::vector<T>>
{
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first (void) const { return members_.front(); }
const VarNodes& members (void) const { return members_; }
VarNode* representative (void) const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
DISALLOW_COPY_AND_ASSIGN (VarCluster);
size_t operator() (const std::vector<T>& vec) const
{
size_t h = 0;
typename std::vector<T>::const_iterator first = vec.begin();
typename std::vector<T>::const_iterator last = vec.end();
for (; first != last; ++first) {
h = Horus::hash_combine (h, *first);
}
return h;
}
};
class FacCluster
{
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first (void) const { return members_.front(); }
const FacNodes& members (void) const { return members_; }
FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
VarClusters& varClusters (void) { return varClusters_; }
private:
FacNodes members_;
FacNode* repr_;
VarClusters varClusters_;
DISALLOW_COPY_AND_ASSIGN (FacCluster);
};
} // namespace std
class CountingBp : public GroundSolver
{
namespace Horus {
class CountingBp : public GroundSolver {
public:
CountingBp (const FactorGraph& fg);
~CountingBp (void);
~CountingBp();
void printSolverFlags (void) const;
void printSolverFlags() const;
Params solveQuery (VarIds);
static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; }
private:
Color getNewColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
typedef long Color;
typedef std::vector<Color> Colors;
Color getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
typedef std::vector<std::pair<Color,unsigned>> VarSignature;
typedef std::vector<Color> FacSignature;
Color getColor (const FacNode* fn) const
{
return facColors_[fn->getIndex()];
}
typedef std::vector<VarCluster*> VarClusters;
typedef std::vector<FacCluster*> FacClusters;
void setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
typedef std::unordered_map<unsigned, Color> DistColorMap;
typedef std::unordered_map<unsigned, Colors> VarColorMap;
typedef std::unordered_map<VarSignature, VarNodes> VarSignMap;
typedef std::unordered_map<FacSignature, FacNodes> FacSignMap;
typedef std::unordered_map<VarId, VarCluster*> VarClusterMap;
void setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
Color getNewColor();
void findIdenticalFactors (void);
Color getColor (const VarNode* vn) const;
void setInitialColors (void);
Color getColor (const FacNode* fn) const;
void createGroups (void);
void setColor (const VarNode* vn, Color c);
void setColor (const FacNode* fn, Color c);
void findIdenticalFactors();
void setInitialColors();
void createGroups();
void createClusters (const VarSignMap&, const FacSignMap&);
@ -163,12 +107,12 @@ class CountingBp : public GroundSolver
FacNode* getRepresentative (FacNode*);
FactorGraph* getCompressedFactorGraph (void);
FactorGraph* getCompressedFactorGraph();
vector<vector<unsigned>> getWeights (void) const;
std::vector<std::vector<unsigned>> getWeights() const;
unsigned getWeight (const FacCluster*,
const VarCluster*, size_t index) const;
unsigned getWeight (const FacCluster*, const VarCluster*,
size_t index) const;
Color freeColor_;
Colors varColors_;
@ -184,5 +128,48 @@ class CountingBp : public GroundSolver
DISALLOW_COPY_AND_ASSIGN (CountingBp);
};
#endif // HORUS_COUNTINGBP_H
inline CountingBp::Color
CountingBp::getNewColor()
{
++ freeColor_;
return freeColor_ - 1;
}
inline CountingBp::Color
CountingBp::getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
inline CountingBp::Color
CountingBp::getColor (const FacNode* fn) const
{
return facColors_[fn->getIndex()];
}
inline void
CountingBp::setColor (const VarNode* vn, CountingBp::Color c)
{
varColors_[vn->getIndex()] = c;
}
inline void
CountingBp::setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_

View File

@ -1,25 +1,30 @@
#include <iostream>
#include <fstream>
#include "ElimGraph.h"
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
namespace Horus {
ElimGraph::ElimHeuristic ElimGraph::elimHeuristic_ =
ElimHeuristic::minNeighborsEh;
ElimGraph::ElimGraph (const vector<Factor*>& factors)
ElimGraph::ElimGraph (const std::vector<Factor*>& factors)
{
for (size_t i = 0; i < factors.size(); i++) {
if (factors[i]) {
const VarIds& args = factors[i]->arguments();
for (size_t j = 0; j < args.size() - 1; j++) {
EgNode* n1 = getEgNode (args[j]);
EGNode* n1 = getEGNode (args[j]);
if (!n1) {
n1 = new EgNode (args[j], factors[i]->range (j));
n1 = new EGNode (args[j], factors[i]->range (j));
addNode (n1);
}
for (size_t k = j + 1; k < args.size(); k++) {
EgNode* n2 = getEgNode (args[k]);
EGNode* n2 = getEGNode (args[k]);
if (!n2) {
n2 = new EgNode (args[k], factors[i]->range (k));
n2 = new EGNode (args[k], factors[i]->range (k));
addNode (n2);
}
if (!neighbors (n1, n2)) {
@ -27,8 +32,8 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
}
}
}
if (args.size() == 1 && !getEgNode (args[0])) {
addNode (new EgNode (args[0], factors[i]->range (0)));
if (args.size() == 1 && !getEGNode (args[0])) {
addNode (new EGNode (args[0], factors[i]->range (0)));
}
}
}
@ -36,7 +41,7 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
ElimGraph::~ElimGraph (void)
ElimGraph::~ElimGraph()
{
for (size_t i = 0; i < nodes_.size(); i++) {
delete nodes_[i];
@ -57,7 +62,7 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
}
size_t nrVarsToEliminate = nodes_.size() - excludedVids.size();
for (size_t i = 0; i < nrVarsToEliminate; i++) {
EgNode* node = getLowestCostNode();
EGNode* node = getLowestCostNode();
unmarked_.remove (node);
const EGNeighs& neighs = node->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
@ -72,15 +77,15 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
void
ElimGraph::print (void) const
ElimGraph::print() const
{
for (size_t i = 0; i < nodes_.size(); i++) {
cout << "node " << nodes_[i]->label() << " neighs:" ;
std::cout << "node " << nodes_[i]->label() << " neighs:" ;
EGNeighs neighs = nodes_[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
cout << " " << neighs[j]->label();
std::cout << " " << neighs[j]->label();
}
cout << endl;
std::cout << std::endl;
}
}
@ -92,25 +97,27 @@ ElimGraph::exportToGraphViz (
bool showNeighborless,
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
std::ofstream out (fileName);
if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
}
out << "strict graph {" << endl;
out << "strict graph {" << std::endl;
for (size_t i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->neighbors().empty() == false) {
out << '"' << nodes_[i]->label() << '"' << endl;
out << '"' << nodes_[i]->label() << '"' << std::endl;
}
}
for (size_t i = 0; i < highlightVarIds.size(); i++) {
EgNode* node =getEgNode (highlightVarIds[i]);
EGNode* node =getEGNode (highlightVarIds[i]);
if (node) {
out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl;
out << " [shape=box3d]" << std::endl;
} else {
cerr << "Error: invalid variable id: " << highlightVarIds[i] << "." ;
cerr << endl;
std::cerr << "Error: invalid variable id: " ;
std::cerr << highlightVarIds[i] << "." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
}
@ -118,10 +125,10 @@ ElimGraph::exportToGraphViz (
EGNeighs neighs = nodes_[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl;
out << '"' << neighs[j]->label() << '"' << std::endl;
}
}
out << "}" << endl;
out << "}" << std::endl;
out.close();
}
@ -132,7 +139,7 @@ ElimGraph::getEliminationOrder (
const Factors& factors,
VarIds excludedVids)
{
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) {
if (elimHeuristic_ == ElimHeuristic::sequentialEh) {
VarIds allVids;
Factors::const_iterator first = factors.begin();
Factors::const_iterator end = factors.end();
@ -150,33 +157,33 @@ ElimGraph::getEliminationOrder (
void
ElimGraph::addNode (EgNode* n)
ElimGraph::addNode (EGNode* n)
{
nodes_.push_back (n);
n->setIndex (nodes_.size() - 1);
varMap_.insert (make_pair (n->varId(), n));
varMap_.insert (std::make_pair (n->varId(), n));
}
EgNode*
ElimGraph::getEgNode (VarId vid) const
ElimGraph::EGNode*
ElimGraph::getEGNode (VarId vid) const
{
unordered_map<VarId, EgNode*>::const_iterator it;
std::unordered_map<VarId, EGNode*>::const_iterator it;
it = varMap_.find (vid);
return (it != varMap_.end()) ? it->second : 0;
}
EgNode*
ElimGraph::getLowestCostNode (void) const
ElimGraph::EGNode*
ElimGraph::getLowestCostNode() const
{
EgNode* bestNode = 0;
EGNode* bestNode = 0;
unsigned minCost = Util::maxUnsigned();
EGNeighs::const_iterator it;
switch (elimHeuristic_) {
case MIN_NEIGHBORS: {
case ElimHeuristic::minNeighborsEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getNeighborsCost (*it);
if (cost < minCost) {
@ -185,7 +192,7 @@ ElimGraph::getLowestCostNode (void) const
}
}}
break;
case MIN_WEIGHT: {
case ElimHeuristic::minWeightEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getWeightCost (*it);
if (cost < minCost) {
@ -194,7 +201,7 @@ ElimGraph::getLowestCostNode (void) const
}
}}
break;
case MIN_FILL: {
case ElimHeuristic::minFillEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getFillCost (*it);
if (cost < minCost) {
@ -203,7 +210,7 @@ ElimGraph::getLowestCostNode (void) const
}
}}
break;
case WEIGHTED_MIN_FILL: {
case ElimHeuristic::weightedMinFillEh: {
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
unsigned cost = getWeightedFillCost (*it);
if (cost < minCost) {
@ -222,7 +229,7 @@ ElimGraph::getLowestCostNode (void) const
void
ElimGraph::connectAllNeighbors (const EgNode* n)
ElimGraph::connectAllNeighbors (const EGNode* n)
{
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
@ -236,3 +243,5 @@ ElimGraph::connectAllNeighbors (const EgNode* n)
}
}
} // namespace Horus

View File

@ -1,143 +1,177 @@
#ifndef HORUS_ELIMGRAPH_H
#define HORUS_ELIMGRAPH_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
#define YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
#include "unordered_map"
#include <cassert>
#include <vector>
#include <unordered_map>
#include "FactorGraph.h"
#include "TinySet.h"
#include "Horus.h"
using namespace std;
enum ElimHeuristic
{
SEQUENTIAL,
MIN_NEIGHBORS,
MIN_WEIGHT,
MIN_FILL,
WEIGHTED_MIN_FILL
};
namespace Horus {
class EgNode;
typedef TinySet<EgNode*> EGNeighs;
class EgNode : public Var
{
class ElimGraph {
public:
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
enum class ElimHeuristic {
sequentialEh,
minNeighborsEh,
minWeightEh,
minFillEh,
weightedMinFillEh
};
void addNeighbor (EgNode* n) { neighs_.insert (n); }
void removeNeighbor (EgNode* n) { neighs_.remove (n); }
bool isNeighbor (EgNode* n) const { return neighs_.contains (n); }
const EGNeighs& neighbors (void) const { return neighs_; }
private:
EGNeighs neighs_;
};
class ElimGraph
{
public:
ElimGraph (const Factors&);
~ElimGraph (void);
~ElimGraph();
VarIds getEliminatingOrder (const VarIds&);
void print (void) const;
void print() const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
static VarIds getEliminationOrder (const Factors&, VarIds);
static ElimHeuristic elimHeuristic (void) { return elimHeuristic_; }
static ElimHeuristic elimHeuristic() { return elimHeuristic_; }
static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; }
private:
void addEdge (EgNode* n1, EgNode* n2)
{
assert (n1 != n2);
n1->addNeighbor (n2);
n2->addNeighbor (n1);
}
class EGNode;
unsigned getNeighborsCost (const EgNode* n) const
{
return n->neighbors().size();
}
typedef TinySet<EGNode*> EGNeighs;
unsigned getWeightCost (const EgNode* n) const
{
unsigned cost = 1;
const EGNeighs& neighs = n->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
cost *= neighs[i]->range();
}
return cost;
}
class EGNode : public Var {
public:
EGNode (VarId vid, unsigned range) : Var (vid, range) { }
unsigned getFillCost (const EgNode* n) const
{
unsigned cost = 0;
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
for (size_t i = 0; i < neighs.size() - 1; i++) {
for (size_t j = i + 1; j < neighs.size(); j++) {
if ( ! neighbors (neighs[i], neighs[j])) {
cost ++;
}
}
}
}
return cost;
}
void addNeighbor (EGNode* n) { neighs_.insert (n); }
unsigned getWeightedFillCost (const EgNode* n) const
{
unsigned cost = 0;
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
for (size_t i = 0; i < neighs.size() - 1; i++) {
for (size_t j = i + 1; j < neighs.size(); j++) {
if ( ! neighbors (neighs[i], neighs[j])) {
cost += neighs[i]->range() * neighs[j]->range();
}
}
}
}
return cost;
}
void removeNeighbor (EGNode* n) { neighs_.remove (n); }
bool neighbors (EgNode* n1, EgNode* n2) const
{
return n1->isNeighbor (n2);
}
bool isNeighbor (EGNode* n) const { return neighs_.contains (n); }
void addNode (EgNode*);
const EGNeighs& neighbors() const { return neighs_; }
EgNode* getEgNode (VarId) const;
private:
EGNeighs neighs_;
};
EgNode* getLowestCostNode (void) const;
void addEdge (EGNode* n1, EGNode* n2);
void connectAllNeighbors (const EgNode*);
unsigned getNeighborsCost (const EGNode* n) const;
vector<EgNode*> nodes_;
TinySet<EgNode*> unmarked_;
unordered_map<VarId, EgNode*> varMap_;
unsigned getWeightCost (const EGNode* n) const;
unsigned getFillCost (const EGNode* n) const;
unsigned getWeightedFillCost (const EGNode* n) const;
bool neighbors (EGNode* n1, EGNode* n2) const;
void addNode (EGNode*);
EGNode* getEGNode (VarId) const;
EGNode* getLowestCostNode() const;
void connectAllNeighbors (const EGNode*);
std::vector<EGNode*> nodes_;
EGNeighs unmarked_;
std::unordered_map<VarId, EGNode*> varMap_;
static ElimHeuristic elimHeuristic_;
DISALLOW_COPY_AND_ASSIGN (ElimGraph);
};
#endif // HORUS_ELIMGRAPH_H
/* Profiling shows that we should inline the following functions */
inline void
ElimGraph::addEdge (EGNode* n1, EGNode* n2)
{
assert (n1 != n2);
n1->addNeighbor (n2);
n2->addNeighbor (n1);
}
inline unsigned
ElimGraph::getNeighborsCost (const EGNode* n) const
{
return n->neighbors().size();
}
inline unsigned
ElimGraph::getWeightCost (const EGNode* n) const
{
unsigned cost = 1;
const EGNeighs& neighs = n->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
cost *= neighs[i]->range();
}
return cost;
}
inline unsigned
ElimGraph::getFillCost (const EGNode* n) const
{
unsigned cost = 0;
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
for (size_t i = 0; i < neighs.size() - 1; i++) {
for (size_t j = i + 1; j < neighs.size(); j++) {
if ( ! neighbors (neighs[i], neighs[j])) {
cost ++;
}
}
}
}
return cost;
}
inline unsigned
ElimGraph::getWeightedFillCost (const EGNode* n) const
{
unsigned cost = 0;
const EGNeighs& neighs = n->neighbors();
if (neighs.size() > 0) {
for (size_t i = 0; i < neighs.size() - 1; i++) {
for (size_t j = i + 1; j < neighs.size(); j++) {
if ( ! neighbors (neighs[i], neighs[j])) {
cost += neighs[i]->range() * neighs[j]->range();
}
}
}
}
return cost;
}
inline bool
ElimGraph::neighbors (EGNode* n1, EGNode* n2) const
{
return n1->isNeighbor (n2);
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_

View File

@ -1,21 +1,15 @@
#include <cstdlib>
#include <cassert>
#include <algorithm>
#include <iostream>
#include <sstream>
#include "Factor.h"
#include "Indexer.h"
#include "Var.h"
Factor::Factor (const Factor& g)
{
clone (g);
}
namespace Horus {
Factor::Factor (
const VarIds& vids,
@ -77,7 +71,7 @@ Factor::sumOutAllExcept (VarId vid)
void
Factor::sumOutAllExcept (const VarIds& vids)
{
vector<bool> mask (args_.size(), false);
std::vector<bool> mask (args_.size(), false);
for (unsigned i = 0; i < vids.size(); i++) {
assert (indexOf (vids[i]) != args_.size());
mask[indexOf (vids[i])] = true;
@ -91,28 +85,30 @@ void
Factor::sumOutAllExceptIndex (size_t idx)
{
assert (idx < args_.size());
vector<bool> mask (args_.size(), false);
std::vector<bool> mask (args_.size(), false);
mask[idx] = true;
sumOutArgs (mask);
}
void
Factor::multiply (Factor& g)
Factor&
Factor::multiply (const Factor& g)
{
if (args_.empty()) {
clone (g);
operator= (g);
} else {
TFactor<VarId>::multiply (g);
GenericFactor<VarId>::multiply (g);
}
return *this;
}
string
Factor::getLabel (void) const
std::string
Factor::getLabel() const
{
stringstream ss;
std::stringstream ss;
ss << "f(" ;
for (size_t i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
@ -125,19 +121,19 @@ Factor::getLabel (void) const
void
Factor::print (void) const
Factor::print() const
{
Vars vars;
for (size_t i = 0; i < args_.size(); i++) {
vars.push_back (new Var (args_[i], ranges_[i]));
}
vector<string> jointStrings = Util::getStateLines (vars);
std::vector<std::string> jointStrings = Util::getStateLines (vars);
for (size_t i = 0; i < params_.size(); i++) {
// cout << "[" << distId_ << "] " ;
cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl;
std::cout << "f(" << jointStrings[i] << ")" ;
std::cout << " = " << params_[i] << std::endl;
}
cout << endl;
std::cout << std::endl;
for (size_t i = 0; i < vars.size(); i++) {
delete vars[i];
}
@ -146,8 +142,9 @@ Factor::print (void) const
void
Factor::sumOutFirstVariable (void)
Factor::sumOutFirstVariable()
{
assert (ranges_.front() == 2);
size_t sep = params_.size() / 2;
if (Globals::logDomain) {
std::transform (
@ -169,19 +166,21 @@ Factor::sumOutFirstVariable (void)
void
Factor::sumOutLastVariable (void)
Factor::sumOutLastVariable()
{
assert (ranges_.back() == 2);
Params::iterator first1 = params_.begin();
Params::iterator first2 = params_.begin();
Params::iterator last = params_.end();
if (Globals::logDomain) {
while (first2 != last) {
// the arguments can be swaped, but that is ok
*first1++ = Util::logSum (*first2++, *first2++);
double tmp = *first2++;
*first1++ = Util::logSum (tmp, *first2++);
}
} else {
while (first2 != last) {
*first1++ = (*first2++) + (*first2++);
*first1 = *first2++;
*first1++ += *first2++;
}
}
params_.resize (params_.size() / 2);
@ -192,7 +191,7 @@ Factor::sumOutLastVariable (void)
void
Factor::sumOutArgs (const vector<bool>& mask)
Factor::sumOutArgs (const std::vector<bool>& mask)
{
assert (mask.size() == args_.size());
size_t new_size = 1;
@ -224,14 +223,5 @@ Factor::sumOutArgs (const vector<bool>& mask)
params_ = newps;
}
void
Factor::clone (const Factor& g)
{
args_ = g.arguments();
ranges_ = g.ranges();
params_ = g.params();
distId_ = g.distId();
}
} // namespace Horus

View File

@ -1,262 +1,20 @@
#ifndef HORUS_FACTOR_H
#define HORUS_FACTOR_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
#define YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
#include <cassert>
#include <vector>
#include <string>
#include "Indexer.h"
#include "GenericFactor.h"
#include "Util.h"
using namespace std;
namespace Horus {
template <typename T>
class TFactor
{
class Factor : public GenericFactor<VarId> {
public:
const vector<T>& arguments (void) const { return args_; }
vector<T>& arguments (void) { return args_; }
const Ranges& ranges (void) const { return ranges_; }
const Params& params (void) const { return params_; }
Params& params (void) { return params_; }
size_t nrArguments (void) const { return args_.size(); }
size_t size (void) const { return params_.size(); }
unsigned distId (void) const { return distId_; }
void setDistId (unsigned id) { distId_ = id; }
void normalize (void) { LogAware::normalize (params_); }
void randomize (void)
{
for (size_t i = 0; i < params_.size(); ++i) {
params_[i] = (double) std::rand() / RAND_MAX;
}
}
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::sizeExpected (ranges_));
}
size_t indexOf (const T& t) const
{
return Util::indexOf (args_, t);
}
const T& argument (size_t idx) const
{
assert (idx < args_.size());
return args_[idx];
}
T& argument (size_t idx)
{
assert (idx < args_.size());
return args_[idx];
}
unsigned range (size_t idx) const
{
assert (idx < ranges_.size());
return ranges_[idx];
}
void multiply (TFactor<T>& g)
{
if (args_ == g.arguments()) {
// optimization
Globals::logDomain
? params_ += g.params()
: params_ *= g.params();
return;
}
unsigned range_prod = 1;
bool share_arguments = false;
const vector<T>& g_args = g.arguments();
const Ranges& g_ranges = g.ranges();
const Params& g_params = g.params();
for (size_t i = 0; i < g_args.size(); i++) {
size_t idx = indexOf (g_args[i]);
if (idx == args_.size()) {
range_prod *= g_ranges[i];
args_.push_back (g_args[i]);
ranges_.push_back (g_ranges[i]);
} else {
share_arguments = true;
}
}
if (share_arguments == false) {
// optimization
cartesianProduct (g_params.begin(), g_params.end());
} else {
extend (range_prod);
Params::iterator it = params_.begin();
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
if (Globals::logDomain) {
for (; indexer.valid(); ++it, ++indexer) {
*it += g_params[indexer];
}
} else {
for (; indexer.valid(); ++it, ++indexer) {
*it *= g_params[indexer];
}
}
}
}
void sumOutIndex (size_t idx)
{
assert (idx < args_.size());
assert (args_.size() > 1);
size_t new_size = params_.size() / ranges_[idx];
Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end();
MapIndexer indexer (ranges_, idx);
if (Globals::logDomain) {
for (; first != last; ++indexer) {
newps[indexer] = Util::logSum (newps[indexer], *first++);
}
} else {
for (; first != last; ++indexer) {
newps[indexer] += *first++;
}
}
params_ = newps;
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void absorveEvidence (const T& arg, unsigned obsIdx)
{
size_t idx = indexOf (arg);
assert (idx != args_.size());
assert (obsIdx < ranges_[idx]);
Params newps;
newps.reserve (params_.size() / ranges_[idx]);
Indexer indexer (ranges_);
for (unsigned i = 0; i < obsIdx; ++i) {
indexer.incrementDimension (idx);
}
while (indexer.valid()) {
newps.push_back (params_[indexer]);
indexer.incrementExceptDimension (idx);
}
params_ = newps;
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void reorderArguments (const vector<T> new_args)
{
assert (new_args.size() == args_.size());
if (new_args == args_) {
return; // already on the desired order
}
Ranges new_ranges;
for (size_t i = 0; i < new_args.size(); i++) {
size_t idx = indexOf (new_args[i]);
assert (idx != args_.size());
new_ranges.push_back (ranges_[idx]);
}
Params newps;
newps.reserve (params_.size());
MapIndexer indexer (new_args, new_ranges, args_, ranges_);
for (; indexer.valid(); ++indexer) {
newps.push_back (params_[indexer]);
}
params_ = newps;
args_ = new_args;
ranges_ = new_ranges;
}
bool contains (const T& arg) const
{
return Util::contains (args_, arg);
}
bool contains (const vector<T>& args) const
{
for (size_t i = 0; i < args.size(); i++) {
if (contains (args[i]) == false) {
return false;
}
}
return true;
}
double& operator[] (size_t idx)
{
assert (idx < params_.size());
return params_[idx];
}
protected:
vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void extend (unsigned range_prod)
{
Params backup = params_;
params_.clear();
params_.reserve (backup.size() * range_prod);
Params::const_iterator first = backup.begin();
Params::const_iterator last = backup.end();
for (; first != last; ++first) {
for (unsigned reps = 0; reps < range_prod; ++reps) {
params_.push_back (*first);
}
}
}
void cartesianProduct (
Params::const_iterator first2,
Params::const_iterator last2)
{
Params backup = params_;
params_.clear();
params_.reserve (params_.size() * (last2 - first2));
Params::const_iterator first1 = backup.begin();
Params::const_iterator last1 = backup.end();
Params::const_iterator tmp;
if (Globals::logDomain) {
for (; first1 != last1; ++first1) {
for (tmp = first2; tmp != last2; ++tmp) {
params_.push_back ((*first1) + (*tmp));
}
}
} else {
for (; first1 != last1; ++first1) {
for (tmp = first2; tmp != last2; ++tmp) {
params_.push_back ((*first1) * (*tmp));
}
}
}
}
};
class Factor : public TFactor<VarId>
{
public:
Factor (void) { }
Factor (const Factor&);
Factor() { }
Factor (const VarIds&, const Ranges&, const Params&,
unsigned = Util::maxUnsigned());
@ -272,23 +30,21 @@ class Factor : public TFactor<VarId>
void sumOutAllExceptIndex (size_t idx);
void multiply (Factor&);
Factor& multiply (const Factor&);
string getLabel (void) const;
std::string getLabel() const;
void print (void) const;
void print() const;
private:
void sumOutFirstVariable (void);
void sumOutFirstVariable();
void sumOutLastVariable (void);
void sumOutLastVariable();
void sumOutArgs (const vector<bool>& mask);
void clone (const Factor& f);
DISALLOW_ASSIGN (Factor);
void sumOutArgs (const std::vector<bool>& mask);
};
#endif // HORUS_FACTOR_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_

View File

@ -1,17 +1,15 @@
#include <cassert>
#include <algorithm>
#include <set>
#include <vector>
#include <iostream>
#include <sstream>
#include <fstream>
#include "FactorGraph.h"
#include "BayesBall.h"
#include "Util.h"
namespace Horus {
bool FactorGraph::exportLd_ = false;
bool FactorGraph::exportUai_ = false;
bool FactorGraph::exportGv_ = false;
@ -20,25 +18,12 @@ bool FactorGraph::printFg_ = false;
FactorGraph::FactorGraph (const FactorGraph& fg)
{
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i]));
}
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
FacNode* facNode = new FacNode (facNodes[i]->factor());
addFacNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
}
}
bayesFactors_ = fg.bayesianFactors();
clone (fg);
}
FactorGraph::~FactorGraph (void)
FactorGraph::~FactorGraph()
{
for (size_t i = 0; i < varNodes_.size(); i++) {
delete varNodes_[i];
@ -50,152 +35,6 @@ FactorGraph::~FactorGraph (void)
void
FactorGraph::readFromUaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
exit (EXIT_FAILURE);
}
ignoreLines (is);
string line;
getline (is, line);
if (line == "BAYES") {
bayesFactors_ = true;
} else if (line == "MARKOV") {
bayesFactors_ = false;
} else {
cerr << "Error: the type of network is missing." << endl;
exit (EXIT_FAILURE);
}
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
vector<VarIds> allVarIds;
vector<Ranges> allRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
allVarIds.push_back ({ });
allRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
if (vid >= ranges.size()) {
cerr << "Error: invalid variable identifier `" << vid << "'. " ;
cerr << "Identifiers must be between 0 and " << ranges.size() - 1 ;
cerr << "." << endl;
exit (EXIT_FAILURE);
}
allVarIds.back().push_back (vid);
allRanges.back().push_back (ranges[vid]);
}
}
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::sizeExpected (allRanges[i])) {
cerr << "Error: invalid number of parameters for factor nº " << i ;
cerr << ", " << Util::sizeExpected (allRanges[i]);
cerr << " expected, " << nrParams << " given." << endl;
exit (EXIT_FAILURE);
}
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::log (params);
}
Factor f (allVarIds[i], allRanges[i], params);
if (bayesFactors_ && allVarIds[i].size() > 1) {
// In this format the child is the last variable,
// move it to be the first
std::swap (allVarIds[i].front(), allVarIds[i].back());
f.reorderArguments (allVarIds[i]);
}
addFactor (f);
}
is.close();
}
void
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
exit (EXIT_FAILURE);
}
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = getVarNode (vids[j]);
if (var && ranges[j] != var->range()) {
cerr << "Error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with a different range." << endl;
}
}
// read parameters
ignoreLines (is);
unsigned nNonzeros;
is >> nNonzeros;
Params params (Util::sizeExpected (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
is >> index;
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
if (Globals::logDomain) {
Util::log (params);
}
std::reverse (vids.begin(), vids.end());
Factor f (vids, ranges, params);
std::reverse (vids.begin(), vids.end());
f.reorderArguments (vids);
addFactor (f);
}
is.close();
}
void
FactorGraph::addFactor (const Factor& factor)
{
@ -221,7 +60,7 @@ FactorGraph::addVarNode (VarNode* vn)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), vn));
varMap_.insert (std::make_pair (vn->varId(), vn));
}
@ -245,7 +84,7 @@ FactorGraph::addEdge (VarNode* vn, FacNode* fn)
bool
FactorGraph::isTree (void) const
FactorGraph::isTree() const
{
return !containsCycle();
}
@ -253,7 +92,7 @@ FactorGraph::isTree (void) const
BayesBallGraph&
FactorGraph::getStructure (void)
FactorGraph::getStructure()
{
assert (bayesFactors_);
if (structure_.empty()) {
@ -273,8 +112,10 @@ FactorGraph::getStructure (void)
void
FactorGraph::print (void) const
FactorGraph::print() const
{
using std::cout;
using std::endl;
for (size_t i = 0; i < varNodes_.size(); i++) {
cout << "var id = " << varNodes_[i]->varId() << endl;
cout << "label = " << varNodes_[i]->label() << endl;
@ -296,28 +137,29 @@ FactorGraph::print (void) const
void
FactorGraph::exportToLibDai (const char* fileName) const
{
ofstream out (fileName);
std::ofstream out (fileName);
if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
}
out << facNodes_.size() << endl << endl;
out << facNodes_.size() << std::endl << std::endl;
for (size_t i = 0; i < facNodes_.size(); i++) {
Factor f (facNodes_[i]->factor());
out << f.nrArguments() << endl;
out << Util::elementsToString (f.arguments()) << endl;
out << Util::elementsToString (f.ranges()) << endl;
out << f.nrArguments() << std::endl;
out << Util::elementsToString (f.arguments()) << std::endl;
out << Util::elementsToString (f.ranges()) << std::endl;
VarIds args = f.arguments();
std::reverse (args.begin(), args.end());
f.reorderArguments (args);
if (Globals::logDomain) {
Util::exp (f.params());
}
out << f.size() << endl;
out << f.size() << std::endl;
for (size_t j = 0; j < f.size(); j++) {
out << j << " " << f[j] << endl;
out << j << " " << f[j] << std::endl;
}
out << endl;
out << std::endl;
}
out.close();
}
@ -327,28 +169,30 @@ FactorGraph::exportToLibDai (const char* fileName) const
void
FactorGraph::exportToUai (const char* fileName) const
{
ofstream out (fileName);
std::ofstream out (fileName);
if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
}
out << (bayesFactors_ ? "BAYES" : "MARKOV") ;
out << endl << endl;
out << varNodes_.size() << endl;
out << std::endl << std::endl;
out << varNodes_.size() << std::endl;
VarNodes sortedVns = varNodes_;
std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId());
for (size_t i = 0; i < sortedVns.size(); i++) {
out << ((i != 0) ? " " : "") << sortedVns[i]->range();
}
out << endl << facNodes_.size() << endl;
out << std::endl << facNodes_.size() << std::endl;
for (size_t i = 0; i < facNodes_.size(); i++) {
VarIds args = facNodes_[i]->factor().arguments();
if (bayesFactors_) {
std::swap (args.front(), args.back());
}
out << args.size() << " " << Util::elementsToString (args) << endl;
out << args.size() << " " << Util::elementsToString (args);
out << std::endl;
}
out << endl;
out << std::endl;
for (size_t i = 0; i < facNodes_.size(); i++) {
Factor f = facNodes_[i]->factor();
if (bayesFactors_) {
@ -360,8 +204,9 @@ FactorGraph::exportToUai (const char* fileName) const
if (Globals::logDomain) {
Util::exp (params);
}
out << params.size() << endl << " " ;
out << Util::elementsToString (params) << endl << endl;
out << params.size() << std::endl << " " ;
out << Util::elementsToString (params);
out << std::endl << std::endl;
}
out.close();
}
@ -371,53 +216,239 @@ FactorGraph::exportToUai (const char* fileName) const
void
FactorGraph::exportToGraphViz (const char* fileName) const
{
ofstream out (fileName);
std::ofstream out (fileName);
if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
}
out << "graph \"" << fileName << "\" {" << endl;
out << "graph \"" << fileName << "\" {" << std::endl;
for (size_t i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i]->hasEvidence()) {
out << '"' << varNodes_[i]->label() << '"' ;
out << " [style=filled, fillcolor=yellow]" << endl;
out << " [style=filled, fillcolor=yellow]" << std::endl;
}
}
for (size_t i = 0; i < facNodes_.size(); i++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " [label=\"" << facNodes_[i]->getLabel();
out << "\"" << ", shape=box]" << endl;
out << "\"" << ", shape=box]" << std::endl;
}
for (size_t i = 0; i < facNodes_.size(); i++) {
const VarNodes& myVars = facNodes_[i]->neighbors();
for (size_t j = 0; j < myVars.size(); j++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " -- " ;
out << '"' << myVars[j]->label() << '"' << endl;
out << '"' << myVars[j]->label() << '"' << std::endl;
}
}
out << "}" << endl;
out << "}" << std::endl;
out.close();
}
void
FactorGraph::ignoreLines (std::ifstream& is) const
FactorGraph&
FactorGraph::operator= (const FactorGraph& fg)
{
string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
if (this != &fg) {
for (size_t i = 0; i < varNodes_.size(); i++) {
delete varNodes_[i];
}
varNodes_.clear();
for (size_t i = 0; i < facNodes_.size(); i++) {
delete facNodes_[i];
}
facNodes_.clear();
varMap_.clear();
clone (fg);
}
return *this;
}
FactorGraph
FactorGraph::readFromUaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
FactorGraph fg;
ignoreLines (is);
std::string line;
getline (is, line);
if (line == "BAYES") {
fg.bayesFactors_ = true;
} else if (line == "MARKOV") {
fg.bayesFactors_ = false;
} else {
std::cerr << "Error: the type of network is missing." << std::endl;
exit (EXIT_FAILURE);
}
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
std::vector<VarIds> allVarIds;
std::vector<Ranges> allRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
allVarIds.push_back ({ });
allRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
if (vid >= ranges.size()) {
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
std::cerr << ". Identifiers must be between 0 and " ;
std::cerr << ranges.size() - 1 << "." << std::endl;
exit (EXIT_FAILURE);
}
allVarIds.back().push_back (vid);
allRanges.back().push_back (ranges[vid]);
}
}
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::sizeExpected (allRanges[i])) {
std::cerr << "Error: invalid number of parameters for factor nº " ;
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
std::cerr << " expected, " << nrParams << " given." << std::endl;
exit (EXIT_FAILURE);
}
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::log (params);
}
Factor f (allVarIds[i], allRanges[i], params);
if (fg.bayesFactors_ && allVarIds[i].size() > 1) {
// In this format the child is the last variable,
// move it to be the first
std::swap (allVarIds[i].front(), allVarIds[i].back());
f.reorderArguments (allVarIds[i]);
}
fg.addFactor (f);
}
is.close();
return fg;
}
FactorGraph
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
std::ifstream is (fileName);
if (!is.is_open()) {
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
FactorGraph fg;
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = fg.getVarNode (vids[j]);
if (var && ranges[j] != var->range()) {
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
std::cerr << " in two or more factors with a different range." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
}
// read parameters
ignoreLines (is);
unsigned nNonzeros;
is >> nNonzeros;
Params params (Util::sizeExpected (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
is >> index;
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
if (Globals::logDomain) {
Util::log (params);
}
std::reverse (vids.begin(), vids.end());
std::reverse (ranges.begin(), ranges.end());
Factor f (vids, ranges, params);
std::reverse (vids.begin(), vids.end());
f.reorderArguments (vids);
fg.addFactor (f);
}
is.close();
return fg;
}
void
FactorGraph::clone (const FactorGraph& fg)
{
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i]));
}
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
FacNode* facNode = new FacNode (facNodes[i]->factor());
addFacNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors();
for (size_t j = 0; j < neighs.size(); j++) {
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
}
}
bayesFactors_ = fg.bayesianFactors();
}
bool
FactorGraph::containsCycle (void) const
FactorGraph::containsCycle() const
{
vector<bool> visitedVars (varNodes_.size(), false);
vector<bool> visitedFactors (facNodes_.size(), false);
std::vector<bool> visitedVars (varNodes_.size(), false);
std::vector<bool> visitedFactors (facNodes_.size(), false);
for (size_t i = 0; i < varNodes_.size(); i++) {
int v = varNodes_[i]->getIndex();
if (!visitedVars[v]) {
@ -435,8 +466,8 @@ bool
FactorGraph::containsCycle (
const VarNode* v,
const FacNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
std::vector<bool>& visitedVars,
std::vector<bool>& visitedFactors) const
{
visitedVars[v->getIndex()] = true;
const FacNodes& adjacencies = v->neighbors();
@ -460,8 +491,8 @@ bool
FactorGraph::containsCycle (
const FacNode* v,
const VarNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
std::vector<bool>& visitedVars,
std::vector<bool>& visitedFactors) const
{
visitedFactors[v->getIndex()] = true;
const VarNodes& adjacencies = v->neighbors();
@ -479,3 +510,16 @@ FactorGraph::containsCycle (
return false; // no cycle detected in this component
}
void
FactorGraph::ignoreLines (std::ifstream& is)
{
std::string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
}
}
} // namespace Horus

View File

@ -1,28 +1,32 @@
#ifndef HORUS_FACTORGRAPH_H
#define HORUS_FACTORGRAPH_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
#define YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
#include <vector>
#include <unordered_map>
#include <string>
#include <fstream>
#include "Factor.h"
#include "BayesBallGraph.h"
#include "Horus.h"
using namespace std;
namespace Horus {
class FacNode;
class VarNode : public Var
{
class VarNode : public Var {
public:
VarNode (VarId varId, unsigned nrStates,
int evidence = Constants::NO_EVIDENCE)
int evidence = Constants::unobserved)
: Var (varId, nrStates, evidence) { }
VarNode (const Var* v) : Var (v) { }
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
const FacNodes& neighbors (void) const { return neighs_; }
const FacNodes& neighbors() const { return neighs_; }
private:
FacNodes neighs_;
@ -32,24 +36,23 @@ class VarNode : public Var
class FacNode
{
class FacNode {
public:
FacNode (const Factor& f) : factor_(f), index_(-1) { }
const Factor& factor (void) const { return factor_; }
const Factor& factor() const { return factor_; }
Factor& factor (void) { return factor_; }
Factor& factor() { return factor_; }
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
const VarNodes& neighbors (void) const { return neighs_; }
const VarNodes& neighbors() const { return neighs_; }
size_t getIndex (void) const { return index_; }
size_t getIndex() const { return index_; }
void setIndex (size_t index) { index_ = index; }
string getLabel (void) { return factor_.getLabel(); }
std::string getLabel() { return factor_.getLabel(); }
private:
VarNodes neighs_;
@ -61,36 +64,27 @@ class FacNode
class FactorGraph
{
class FactorGraph {
public:
FactorGraph (void) : bayesFactors_(false) { }
FactorGraph() : bayesFactors_(false) { }
FactorGraph (const FactorGraph&);
~FactorGraph (void);
~FactorGraph();
const VarNodes& varNodes (void) const { return varNodes_; }
const VarNodes& varNodes() const { return varNodes_; }
const FacNodes& facNodes (void) const { return facNodes_; }
const FacNodes& facNodes() const { return facNodes_; }
void setFactorsAsBayesian (void) { bayesFactors_ = true; }
void setFactorsAsBayesian() { bayesFactors_ = true; }
bool bayesianFactors (void) const { return bayesFactors_; }
bool bayesianFactors() const { return bayesFactors_; }
size_t nrVarNodes (void) const { return varNodes_.size(); }
size_t nrVarNodes() const { return varNodes_.size(); }
size_t nrFacNodes (void) const { return facNodes_.size(); }
size_t nrFacNodes() const { return facNodes_.size(); }
VarNode* getVarNode (VarId vid) const
{
VarMap::const_iterator it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
void readFromUaiFormat (const char*);
void readFromLibDaiFormat (const char*);
VarNode* getVarNode (VarId vid) const;
void addFactor (const Factor& factor);
@ -100,11 +94,11 @@ class FactorGraph
void addEdge (VarNode*, FacNode*);
bool isTree (void) const;
bool isTree() const;
BayesBallGraph& getStructure (void);
BayesBallGraph& getStructure();
void print (void) const;
void print() const;
void exportToLibDai (const char*) const;
@ -112,67 +106,80 @@ class FactorGraph
void exportToGraphViz (const char*) const;
static bool exportToLibDai (void) { return exportLd_; }
FactorGraph& operator= (const FactorGraph&);
static bool exportToUai (void) { return exportUai_; }
static FactorGraph readFromUaiFormat (const char*);
static bool exportGraphViz (void) { return exportGv_; }
static FactorGraph readFromLibDaiFormat (const char*);
static bool printFactorGraph (void) { return printFg_; }
static bool exportToLibDai() { return exportLd_; }
static void enableExportToLibDai (void) { exportLd_ = true; }
static bool exportToUai() { return exportUai_; }
static void disableExportToLibDai (void) { exportLd_ = false; }
static bool exportGraphViz() { return exportGv_; }
static void enableExportToUai (void) { exportUai_ = true; }
static bool printFactorGraph() { return printFg_; }
static void disableExportToUai (void) { exportUai_ = false; }
static void enableExportToLibDai() { exportLd_ = true; }
static void enableExportToGraphViz (void) { exportGv_ = true; }
static void disableExportToLibDai() { exportLd_ = false; }
static void disableExportToGraphViz (void) { exportGv_ = false; }
static void enableExportToUai() { exportUai_ = true; }
static void enablePrintFactorGraph (void) { printFg_ = true; }
static void disableExportToUai() { exportUai_ = false; }
static void disablePrintFactorGraph (void) { printFg_ = false; }
static void enableExportToGraphViz() { exportGv_ = true; }
static void disableExportToGraphViz() { exportGv_ = false; }
static void enablePrintFactorGraph() { printFg_ = true; }
static void disablePrintFactorGraph() { printFg_ = false; }
private:
void ignoreLines (std::ifstream&) const;
typedef std::unordered_map<unsigned, VarNode*> VarMap;
bool containsCycle (void) const;
void clone (const FactorGraph& fg);
bool containsCycle() const;
bool containsCycle (const VarNode*, const FacNode*,
vector<bool>&, vector<bool>&) const;
std::vector<bool>&, std::vector<bool>&) const;
bool containsCycle (const FacNode*, const VarNode*,
vector<bool>&, vector<bool>&) const;
std::vector<bool>&, std::vector<bool>&) const;
VarNodes varNodes_;
FacNodes facNodes_;
static void ignoreLines (std::ifstream&);
VarNodes varNodes_;
FacNodes facNodes_;
VarMap varMap_;
BayesBallGraph structure_;
bool bayesFactors_;
typedef unordered_map<unsigned, VarNode*> VarMap;
VarMap varMap_;
static bool exportLd_;
static bool exportUai_;
static bool exportGv_;
static bool printFg_;
DISALLOW_ASSIGN (FactorGraph);
static bool exportLd_;
static bool exportUai_;
static bool exportGv_;
static bool printFg_;
};
struct sortByVarId
inline VarNode*
FactorGraph::getVarNode (VarId vid) const
{
VarMap::const_iterator it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
struct sortByVarId {
bool operator()(VarNode* vn1, VarNode* vn2) {
return vn1->varId() < vn2->varId();
}
};
}};
} // namespace Horus
#endif // HORUS_FACTORGRAPH_H
#endif // YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_

View File

@ -0,0 +1,256 @@
#include <cassert>
#include "GenericFactor.h"
#include "ProbFormula.h"
#include "Indexer.h"
namespace Horus {
template <typename T> const T&
GenericFactor<T>::argument (size_t idx) const
{
assert (idx < args_.size());
return args_[idx];
}
template <typename T> T&
GenericFactor<T>::argument (size_t idx)
{
assert (idx < args_.size());
return args_[idx];
}
template <typename T> unsigned
GenericFactor<T>::range (size_t idx) const
{
assert (idx < ranges_.size());
return ranges_[idx];
}
template <typename T> bool
GenericFactor<T>::contains (const T& arg) const
{
return Util::contains (args_, arg);
}
template <typename T> bool
GenericFactor<T>::contains (const std::vector<T>& args) const
{
for (size_t i = 0; i < args.size(); i++) {
if (contains (args[i]) == false) {
return false;
}
}
return true;
}
template <typename T> void
GenericFactor<T>::setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::sizeExpected (ranges_));
}
template <typename T> double
GenericFactor<T>::operator[] (size_t idx) const
{
assert (idx < params_.size());
return params_[idx];
}
template <typename T> double&
GenericFactor<T>::operator[] (size_t idx)
{
assert (idx < params_.size());
return params_[idx];
}
template <typename T> GenericFactor<T>&
GenericFactor<T>::multiply (const GenericFactor<T>& g)
{
if (args_ == g.arguments()) {
// optimization
Globals::logDomain
? params_ += g.params()
: params_ *= g.params();
return *this;
}
unsigned range_prod = 1;
bool share_arguments = false;
const std::vector<T>& g_args = g.arguments();
const Ranges& g_ranges = g.ranges();
const Params& g_params = g.params();
for (size_t i = 0; i < g_args.size(); i++) {
size_t idx = indexOf (g_args[i]);
if (idx == args_.size()) {
range_prod *= g_ranges[i];
args_.push_back (g_args[i]);
ranges_.push_back (g_ranges[i]);
} else {
share_arguments = true;
}
}
if (share_arguments == false) {
// optimization
cartesianProduct (g_params.begin(), g_params.end());
} else {
extend (range_prod);
Params::iterator it = params_.begin();
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
if (Globals::logDomain) {
for (; indexer.valid(); ++it, ++indexer) {
*it += g_params[indexer];
}
} else {
for (; indexer.valid(); ++it, ++indexer) {
*it *= g_params[indexer];
}
}
}
return *this;
}
template <typename T> void
GenericFactor<T>::sumOutIndex (size_t idx)
{
assert (idx < args_.size());
assert (args_.size() > 1);
size_t new_size = params_.size() / ranges_[idx];
Params newps (new_size, LogAware::addIdenty());
Params::const_iterator first = params_.begin();
Params::const_iterator last = params_.end();
MapIndexer indexer (ranges_, idx);
if (Globals::logDomain) {
for (; first != last; ++indexer) {
newps[indexer] = Util::logSum (newps[indexer], *first++);
}
} else {
for (; first != last; ++indexer) {
newps[indexer] += *first++;
}
}
params_ = newps;
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
template <typename T> void
GenericFactor<T>::absorveEvidence (const T& arg, unsigned obsIdx)
{
size_t idx = indexOf (arg);
assert (idx != args_.size());
assert (obsIdx < ranges_[idx]);
Params newps;
newps.reserve (params_.size() / ranges_[idx]);
Indexer indexer (ranges_);
for (unsigned i = 0; i < obsIdx; ++i) {
indexer.incrementDimension (idx);
}
while (indexer.valid()) {
newps.push_back (params_[indexer]);
indexer.incrementExceptDimension (idx);
}
params_ = newps;
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
template <typename T> void
GenericFactor<T>::reorderArguments (const std::vector<T>& new_args)
{
assert (new_args.size() == args_.size());
if (new_args == args_) {
return; // already on the desired order
}
Ranges new_ranges;
for (size_t i = 0; i < new_args.size(); i++) {
size_t idx = indexOf (new_args[i]);
assert (idx != args_.size());
new_ranges.push_back (ranges_[idx]);
}
Params newps;
newps.reserve (params_.size());
MapIndexer indexer (new_args, new_ranges, args_, ranges_);
for (; indexer.valid(); ++indexer) {
newps.push_back (params_[indexer]);
}
params_ = newps;
args_ = new_args;
ranges_ = new_ranges;
}
template <typename T> void
GenericFactor<T>::extend (unsigned range_prod)
{
Params backup = params_;
params_.clear();
params_.reserve (backup.size() * range_prod);
Params::const_iterator first = backup.begin();
Params::const_iterator last = backup.end();
for (; first != last; ++first) {
for (unsigned reps = 0; reps < range_prod; ++reps) {
params_.push_back (*first);
}
}
}
template <typename T> void
GenericFactor<T>::cartesianProduct (
Params::const_iterator first2,
Params::const_iterator last2)
{
Params backup = params_;
params_.clear();
params_.reserve (params_.size() * (last2 - first2));
Params::const_iterator first1 = backup.begin();
Params::const_iterator last1 = backup.end();
Params::const_iterator tmp;
if (Globals::logDomain) {
for (; first1 != last1; ++first1) {
for (tmp = first2; tmp != last2; ++tmp) {
params_.push_back ((*first1) + (*tmp));
}
}
} else {
for (; first1 != last1; ++first1) {
for (tmp = first2; tmp != last2; ++tmp) {
params_.push_back ((*first1) * (*tmp));
}
}
}
}
template class GenericFactor<VarId>;
template class GenericFactor<ProbFormula>;
} // namespace Horus

View File

@ -0,0 +1,76 @@
#ifndef YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
#define YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
#include <vector>
#include "Util.h"
namespace Horus {
template <typename T>
class GenericFactor {
public:
const std::vector<T>& arguments() const { return args_; }
std::vector<T>& arguments() { return args_; }
const Ranges& ranges() const { return ranges_; }
const Params& params() const { return params_; }
Params& params() { return params_; }
size_t nrArguments() const { return args_.size(); }
size_t size() const { return params_.size(); }
unsigned distId() const { return distId_; }
void setDistId (unsigned id) { distId_ = id; }
void normalize() { LogAware::normalize (params_); }
size_t indexOf (const T& t) const { return Util::indexOf (args_, t); }
const T& argument (size_t idx) const;
T& argument (size_t idx);
unsigned range (size_t idx) const;
bool contains (const T& arg) const;
bool contains (const std::vector<T>& args) const;
void setParams (const Params& newParams);
double operator[] (size_t idx) const;
double& operator[] (size_t idx);
GenericFactor<T>& multiply (const GenericFactor<T>& g);
void sumOutIndex (size_t idx);
void absorveEvidence (const T& arg, unsigned obsIdx);
void reorderArguments (const std::vector<T>& new_args);
protected:
std::vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void extend (unsigned range_prod);
void cartesianProduct (
Params::const_iterator first2, Params::const_iterator last2);
};
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_

View File

@ -1,10 +1,20 @@
#include <cassert>
#include <vector>
#include <string>
#include <iostream>
#include <iomanip>
#include "GroundSolver.h"
#include "VarElim.h"
#include "BeliefProp.h"
#include "CountingBp.h"
#include "Indexer.h"
#include "Util.h"
namespace Horus {
void
GroundSolver::printAnswer (const VarIds& vids)
{
@ -19,20 +29,21 @@ GroundSolver::printAnswer (const VarIds& vids)
}
if (unobservedVids.empty() == false) {
Params res = solveQuery (unobservedVids);
vector<string> stateLines = Util::getStateLines (unobservedVars);
std::vector<std::string> stateLines =
Util::getStateLines (unobservedVars);
for (size_t i = 0; i < res.size(); i++) {
cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl;
std::cout << "P(" << stateLines[i] << ") = " ;
std::cout << std::setprecision (Constants::precision) << res[i];
std::cout << std::endl;
}
cout << endl;
std::cout << std::endl;
}
}
void
GroundSolver::printAllPosterioris (void)
GroundSolver::printAllPosterioris()
{
VarNodes vars = fg.varNodes();
std::sort (vars.begin(), vars.end(), sortByVarId());
@ -57,9 +68,9 @@ GroundSolver::getJointByConditioning (
GroundSolver* solver = 0;
switch (solverType) {
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
case GroundSolverType::VE: solver = new VarElim (fg); break;
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
}
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
VarIds observedVids = {jointVars[0]->varId()};
@ -80,9 +91,9 @@ GroundSolver::getJointByConditioning (
}
delete solver;
switch (solverType) {
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
case GroundSolverType::VE: solver = new VarElim (fg); break;
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
}
Params beliefs = solver->solveQuery ({jointVarIds[i]});
for (size_t k = 0; k < beliefs.size(); k++) {
@ -105,3 +116,5 @@ GroundSolver::getJointByConditioning (
return prevBeliefs;
}
} // namespace Horus

View File

@ -1,16 +1,13 @@
#ifndef HORUS_GROUNDSOLVER_H
#define HORUS_GROUNDSOLVER_H
#include <iomanip>
#ifndef YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
#define YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
#include "FactorGraph.h"
#include "Horus.h"
using namespace std;
namespace Horus {
class GroundSolver
{
class GroundSolver {
public:
GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { }
@ -18,11 +15,11 @@ class GroundSolver
virtual Params solveQuery (VarIds queryVids) = 0;
virtual void printSolverFlags (void) const = 0;
virtual void printSolverFlags() const = 0;
void printAnswer (const VarIds& vids);
void printAllPosterioris (void);
void printAllPosterioris();
static Params getJointByConditioning (GroundSolverType,
FactorGraph, const VarIds& jointVarIds);
@ -30,8 +27,11 @@ class GroundSolver
protected:
const FactorGraph& fg;
private:
DISALLOW_COPY_AND_ASSIGN (GroundSolver);
};
#endif // HORUS_GROUNDSOLVER_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_

View File

@ -7,6 +7,8 @@
#include "Util.h"
namespace Horus {
HistogramSet::HistogramSet (unsigned size, unsigned range)
{
size_ = size;
@ -17,7 +19,7 @@ HistogramSet::HistogramSet (unsigned size, unsigned range)
void
HistogramSet::nextHistogram (void)
HistogramSet::nextHistogram()
{
for (size_t i = hist_.size() - 1; i-- > 0; ) {
if (hist_[i] > 0) {
@ -43,7 +45,7 @@ HistogramSet::operator[] (size_t idx) const
unsigned
HistogramSet::nrHistograms (void) const
HistogramSet::nrHistograms() const
{
return HistogramSet::nrHistograms (size_, hist_.size());
}
@ -51,7 +53,7 @@ HistogramSet::nrHistograms (void) const
void
HistogramSet::reset (void)
HistogramSet::reset()
{
std::fill (hist_.begin() + 1, hist_.end(), 0);
hist_[0] = size_;
@ -59,12 +61,12 @@ HistogramSet::reset (void)
vector<Histogram>
std::vector<Histogram>
HistogramSet::getHistograms (unsigned N, unsigned R)
{
HistogramSet hs (N, R);
unsigned H = hs.nrHistograms();
vector<Histogram> histograms;
std::vector<Histogram> histograms;
histograms.reserve (H);
for (unsigned i = 0; i < H; i++) {
histograms.push_back (hs.hist_);
@ -86,9 +88,9 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
size_t
HistogramSet::findIndex (
const Histogram& h,
const vector<Histogram>& hists)
const std::vector<Histogram>& hists)
{
vector<Histogram>::const_iterator it = std::lower_bound (
std::vector<Histogram>::const_iterator it = std::lower_bound (
hists.begin(), hists.end(), h, std::greater<Histogram>());
assert (it != hists.end() && *it == h);
return std::distance (hists.begin(), it);
@ -96,13 +98,13 @@ HistogramSet::findIndex (
vector<double>
std::vector<double>
HistogramSet::getNumAssigns (unsigned N, unsigned R)
{
HistogramSet hs (N, R);
double N_fac = Util::logFactorial (N);
unsigned H = hs.nrHistograms();
vector<double> numAssigns;
std::vector<double> numAssigns;
numAssigns.reserve (H);
for (unsigned h = 0; h < H; h++) {
double prod = 0.0;
@ -118,14 +120,6 @@ HistogramSet::getNumAssigns (unsigned N, unsigned R)
ostream& operator<< (ostream &os, const HistogramSet& hs)
{
os << "#" << hs.hist_;
return os;
}
unsigned
HistogramSet::maxCount (size_t idx) const
{
@ -144,3 +138,14 @@ HistogramSet::clearAfter (size_t idx)
std::fill (hist_.begin() + idx + 1, hist_.end(), 0);
}
std::ostream&
operator<< (std::ostream& os, const HistogramSet& hs)
{
os << "#" << hs.hist_;
return os;
}
} // namespace Horus

View File

@ -1,50 +1,51 @@
#ifndef HORUS_HISTOGRAM_H
#define HORUS_HISTOGRAM_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
#define YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
#include <vector>
#include <ostream>
#include "Horus.h"
using namespace std;
typedef std::vector<unsigned> Histogram;
typedef vector<unsigned> Histogram;
class HistogramSet
{
namespace Horus {
class HistogramSet {
public:
HistogramSet (unsigned, unsigned);
void nextHistogram (void);
void nextHistogram();
unsigned operator[] (size_t idx) const;
unsigned nrHistograms (void) const;
unsigned nrHistograms() const;
void reset (void);
void reset();
static vector<Histogram> getHistograms (unsigned, unsigned);
static std::vector<Histogram> getHistograms (unsigned, unsigned);
static unsigned nrHistograms (unsigned, unsigned);
static size_t findIndex (
const Histogram&, const vector<Histogram>&);
const Histogram&, const std::vector<Histogram>&);
static vector<double> getNumAssigns (unsigned, unsigned);
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
static std::vector<double> getNumAssigns (unsigned, unsigned);
private:
unsigned maxCount (size_t) const;
void clearAfter (size_t);
friend std::ostream& operator<< (std::ostream&, const HistogramSet&);
unsigned size_;
Histogram hist_;
DISALLOW_COPY_AND_ASSIGN (HistogramSet);
};
#endif // HORUS_HISTOGRAM_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_

View File

@ -1,5 +1,5 @@
#ifndef HORUS_HORUS_H
#define HORUS_HORUS_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
#define YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&); \
@ -14,6 +14,9 @@
#include <vector>
#include <string>
namespace Horus {
class Var;
class Factor;
class VarNode;
@ -31,19 +34,17 @@ typedef std::vector<unsigned> Ranges;
typedef unsigned long long ullong;
enum LiftedSolverType
{
LVE, // generalized counting first-order variable elimination (GC-FOVE)
LBP, // lifted first-order belief propagation
LKC // lifted first-order knowledge compilation
enum class LiftedSolverType {
lveSolver, // generalized counting first-order variable elimination
lbpSolver, // lifted first-order belief propagation
lkcSolver // lifted first-order knowledge compilation
};
enum GroundSolverType
{
VE, // variable elimination
BP, // belief propagation
CBP // counting belief propagation
enum class GroundSolverType {
veSolver, // variable elimination
bpSolver, // belief propagation
CbpSolver // counting belief propagation
};
@ -57,20 +58,22 @@ extern unsigned verbosity;
extern LiftedSolverType liftedSolver;
extern GroundSolverType groundSolver;
};
}
namespace Constants {
// show message calculation for belief propagation
const bool SHOW_BP_CALCS = false;
const bool showBpCalcs = false;
const int NO_EVIDENCE = -1;
const int unobserved = -1;
// number of digits to show when printing a parameter
const unsigned PRECISION = 6;
const unsigned precision = 8;
};
}
#endif // HORUS_HORUS_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_HORUS_H_

View File

@ -1,53 +1,61 @@
#include <cstdlib>
#include <cassert>
#include <string>
#include <iostream>
#include <sstream>
#include "FactorGraph.h"
#include "VarElim.h"
#include "BeliefProp.h"
#include "CountingBp.h"
using namespace std;
namespace {
int readHorusFlags (int, const char* []);
void readFactorGraph (FactorGraph&, const char*);
VarIds readQueryAndEvidence (FactorGraph&, int, const char* [], int);
void runSolver (const FactorGraph&, const VarIds&);
void readFactorGraph (Horus::FactorGraph&, const char*);
const string USAGE = "usage: ./hcli [solver=hve|bp|cbp] \
Horus::VarIds readQueryAndEvidence (
Horus::FactorGraph&, int, const char* [], int);
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
}
int
main (int argc, const char* argv[])
{
if (argc <= 1) {
cerr << "Error: no probabilistic graphical model was given." << endl;
cerr << USAGE << endl;
std::cerr << "Error: no probabilistic graphical model was given." ;
std::cerr << std::endl << usage << std::endl;
exit (EXIT_FAILURE);
}
int idx = readHorusFlags (argc, argv);
FactorGraph fg;
Horus::FactorGraph fg;
readFactorGraph (fg, argv[idx]);
VarIds queryIds = readQueryAndEvidence (fg, argc, argv, idx + 1);
if (FactorGraph::exportToLibDai()) {
Horus::VarIds queryIds
= readQueryAndEvidence (fg, argc, argv, idx + 1);
if (Horus::FactorGraph::exportToLibDai()) {
fg.exportToLibDai ("model.fg");
}
if (FactorGraph::exportToUai()) {
if (Horus::FactorGraph::exportToUai()) {
fg.exportToUai ("model.uai");
}
if (FactorGraph::exportGraphViz()) {
if (Horus::FactorGraph::exportGraphViz()) {
fg.exportToGraphViz ("model.dot");
}
if (FactorGraph::printFactorGraph()) {
if (Horus::FactorGraph::printFactorGraph()) {
fg.print();
}
if (Globals::verbosity > 0) {
cout << "factor graph contains " ;
cout << fg.nrVarNodes() << " variables and " ;
cout << fg.nrFacNodes() << " factors " << endl;
if (Horus::Globals::verbosity > 0) {
std::cout << "factor graph contains " ;
std::cout << fg.nrVarNodes() << " variables and " ;
std::cout << fg.nrFacNodes() << " factors " << std::endl;
}
runSolver (fg, queryIds);
return 0;
@ -55,29 +63,31 @@ main (int argc, const char* argv[])
namespace {
int
readHorusFlags (int argc, const char* argv[])
{
int i = 1;
for (; i < argc; i++) {
const string& arg = argv[i];
const std::string& arg = argv[i];
size_t pos = arg.find ('=');
if (pos == std::string::npos) {
return i;
}
string leftArg = arg.substr (0, pos);
string rightArg = arg.substr (pos + 1);
std::string leftArg = arg.substr (0, pos);
std::string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) {
cerr << "Error: missing left argument." << endl;
cerr << USAGE << endl;
std::cerr << "Error: missing left argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
}
if (rightArg.empty()) {
cerr << "Error: missing right argument." << endl;
cerr << USAGE << endl;
std::cerr << "Error: missing right argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
}
Util::setHorusFlag (leftArg, rightArg);
Horus::Util::setHorusFlag (leftArg, rightArg);
}
return i + 1;
}
@ -85,84 +95,84 @@ readHorusFlags (int argc, const char* argv[])
void
readFactorGraph (FactorGraph& fg, const char* s)
readFactorGraph (Horus::FactorGraph& fg, const char* s)
{
string fileName (s);
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
std::string fileName (s);
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
if (extension == "uai") {
fg.readFromUaiFormat (fileName.c_str());
fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") {
fg.readFromLibDaiFormat (fileName.c_str());
fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
} else {
cerr << "Error: the probabilistic graphical model must be " ;
cerr << "defined either in a UAI or libDAI file." << endl;
std::cerr << "Error: the probabilistic graphical model must be " ;
std::cerr << "defined either in a UAI or libDAI file." << std::endl;
exit (EXIT_FAILURE);
}
}
VarIds
Horus::VarIds
readQueryAndEvidence (
FactorGraph& fg,
Horus::FactorGraph& fg,
int argc,
const char* argv[],
int start)
{
VarIds queryIds;
Horus::VarIds queryIds;
for (int i = start; i < argc; i++) {
const string& arg = argv[i];
const std::string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
if (Util::isInteger (arg) == false) {
cerr << "Error: `" << arg << "' " ;
cerr << "is not a variable id." ;
cerr << endl;
if (Horus::Util::isInteger (arg) == false) {
std::cerr << "Error: `" << arg << "' " ;
std::cerr << "is not a variable id." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
VarId vid = Util::stringToUnsigned (arg);
VarNode* queryVar = fg.getVarNode (vid);
Horus::VarId vid = Horus::Util::stringToUnsigned (arg);
Horus::VarNode* queryVar = fg.getVarNode (vid);
if (queryVar == false) {
cerr << "Error: unknow variable with id " ;
cerr << "`" << vid << "'." << endl;
std::cerr << "Error: unknow variable with id " ;
std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE);
}
queryIds.push_back (vid);
} else {
size_t pos = arg.find ('=');
string leftArg = arg.substr (0, pos);
string rightArg = arg.substr (pos + 1);
std::string leftArg = arg.substr (0, pos);
std::string rightArg = arg.substr (pos + 1);
if (leftArg.empty()) {
cerr << "Error: missing left argument." << endl;
cerr << USAGE << endl;
std::cerr << "Error: missing left argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
}
if (Util::isInteger (leftArg) == false) {
cerr << "Error: `" << leftArg << "' " ;
cerr << "is not a variable id." << endl ;
if (Horus::Util::isInteger (leftArg) == false) {
std::cerr << "Error: `" << leftArg << "' " ;
std::cerr << "is not a variable id." << std::endl;
exit (EXIT_FAILURE);
}
VarId vid = Util::stringToUnsigned (leftArg);
VarNode* observedVar = fg.getVarNode (vid);
Horus::VarId vid = Horus::Util::stringToUnsigned (leftArg);
Horus::VarNode* observedVar = fg.getVarNode (vid);
if (observedVar == false) {
cerr << "Error: unknow variable with id " ;
cerr << "`" << vid << "'." << endl;
std::cerr << "Error: unknow variable with id " ;
std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE);
}
if (rightArg.empty()) {
cerr << "Error: missing right argument." << endl;
cerr << USAGE << endl;
std::cerr << "Error: missing right argument." << std::endl;
std::cerr << usage << std::endl;
exit (EXIT_FAILURE);
}
if (Util::isInteger (rightArg) == false) {
cerr << "Error: `" << rightArg << "' " ;
cerr << "is not a state index." << endl ;
if (Horus::Util::isInteger (rightArg) == false) {
std::cerr << "Error: `" << rightArg << "' " ;
std::cerr << "is not a state index." << std::endl;
exit (EXIT_FAILURE);
}
unsigned stateIdx = Util::stringToUnsigned (rightArg);
unsigned stateIdx = Horus::Util::stringToUnsigned (rightArg);
if (observedVar->isValidState (stateIdx) == false) {
cerr << "Error: `" << stateIdx << "' " ;
cerr << "is not a valid state index for variable with id " ;
cerr << "`" << vid << "'." << endl;
std::cerr << "Error: `" << stateIdx << "' " ;
std::cerr << "is not a valid state index for variable with id " ;
std::cerr << "`" << vid << "'." << std::endl;
exit (EXIT_FAILURE);
}
observedVar->setEvidence (stateIdx);
@ -174,25 +184,27 @@ readQueryAndEvidence (
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
runSolver (
const Horus::FactorGraph& fg,
const Horus::VarIds& queryIds)
{
GroundSolver* solver = 0;
switch (Globals::groundSolver) {
case GroundSolverType::VE:
solver = new VarElim (fg);
Horus::GroundSolver* solver = 0;
switch (Horus::Globals::groundSolver) {
case Horus::GroundSolverType::veSolver:
solver = new Horus::VarElim (fg);
break;
case GroundSolverType::BP:
solver = new BeliefProp (fg);
case Horus::GroundSolverType::bpSolver:
solver = new Horus::BeliefProp (fg);
break;
case GroundSolverType::CBP:
solver = new CountingBp (fg);
case Horus::GroundSolverType::CbpSolver:
solver = new Horus::CountingBp (fg);
break;
default:
assert (false);
}
if (Globals::verbosity > 0) {
if (Horus::Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
std::cout << std::endl;
}
if (queryIds.empty()) {
solver->printAllPosterioris();
@ -202,3 +214,5 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
delete solver;
}
}

View File

@ -1,7 +1,8 @@
#include <cstdlib>
#include <cassert>
#include <vector>
#include <unordered_map>
#include <string>
#include <iostream>
#include <sstream>
@ -20,31 +21,35 @@
#include "BayesBall.h"
using namespace std;
namespace Horus {
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
namespace {
Parfactor* readParfactor (YAP_Term);
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
ObservedFormulas* readLiftedEvidence (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term list);
std::vector<unsigned> readUnsignedList (YAP_Term);
Params readParameters (YAP_Term);
YAP_Term fillAnswersPrologList (vector<Params>& results);
YAP_Term fillSolutionList (const std::vector<Params>&);
}
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
int
createLiftedNetwork (void)
createLiftedNetwork()
{
Parfactors parfactors;
YAP_Term parfactorList = YAP_ARG1;
while (parfactorList != YAP_TermNil()) {
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
parfactors.push_back (readParfactor (pfTerm));
parfactorList = YAP_TailOfTerm (parfactorList);
parfactorList = YAP_TailOfTerm (parfactorList);
}
// LiftedUtils::printSymbolDictionary();
@ -52,7 +57,7 @@ createLiftedNetwork (void)
Util::printHeader ("INITIAL PARFACTORS");
for (size_t i = 0; i < parfactors.size(); i++) {
parfactors[i]->print();
cout << endl;
std::cout << std::endl;
}
}
@ -64,21 +69,20 @@ createLiftedNetwork (void)
}
// read evidence
ObservedFormulas* obsFormulas = new ObservedFormulas();
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
ObservedFormulas* obsFormulas = readLiftedEvidence (YAP_ARG2);
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
LiftedNetwork* network = new LiftedNetwork (pfList, obsFormulas);
YAP_Int p = (YAP_Int) (net);
YAP_Int p = (YAP_Int) (network);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
}
int
createGroundNetwork (void)
createGroundNetwork()
{
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
std::string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
FactorGraph* fg = new FactorGraph();
if (factorsType == "bayes") {
fg->setFactorsAsBayesian();
@ -121,9 +125,9 @@ createGroundNetwork (void)
fg->print();
}
if (Globals::verbosity > 0) {
cout << "factor graph contains " ;
cout << fg->nrVarNodes() << " variables and " ;
cout << fg->nrFacNodes() << " factors " << endl;
std::cout << "factor graph contains " ;
std::cout << fg->nrVarNodes() << " variables and " ;
std::cout << fg->nrFacNodes() << " factors " << std::endl;
}
YAP_Int p = (YAP_Int) (fg);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
@ -132,45 +136,46 @@ createGroundNetwork (void)
int
runLiftedSolver (void)
runLiftedSolver()
{
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
ParfactorList pfListCopy (*network->first);
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
ParfactorList copy (*network->first);
LiftedOperations::absorveEvidence (copy, *network->second);
LiftedSolver* solver = 0;
switch (Globals::liftedSolver) {
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
case LiftedSolverType::lveSolver: solver = new LiftedVe (copy); break;
case LiftedSolverType::lbpSolver: solver = new LiftedBp (copy); break;
case LiftedSolverType::lkcSolver: solver = new LiftedKc (copy); break;
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
std::cout << std::endl;
}
YAP_Term taskList = YAP_ARG2;
vector<Params> results;
std::vector<Params> results;
while (taskList != YAP_TermNil()) {
Grounds queryVars;
YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) {
YAP_Term ground = YAP_HeadOfTerm (jointList);
if (YAP_IsAtomTerm (ground)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
queryVars.push_back (Ground (LiftedUtils::getSymbol (name)));
} else {
assert (YAP_IsApplTerm (ground));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
std::string name ((char*) (YAP_AtomName (
YAP_NameOfFunctor (yapFunctor))));
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
Symbol functor = LiftedUtils::getSymbol (name);
Symbols args;
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, ground);
assert (YAP_IsAtomTerm (ti));
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
std::string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
args.push_back (LiftedUtils::getSymbol (arg));
}
queryVars.push_back (Ground (functor, args));
@ -183,17 +188,17 @@ runLiftedSolver (void)
delete solver;
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
return YAP_Unify (fillSolutionList (results), YAP_ARG3);
}
int
runGroundSolver (void)
runGroundSolver()
{
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
vector<VarIds> tasks;
std::vector<VarIds> tasks;
YAP_Term taskList = YAP_ARG2;
while (taskList != YAP_TermNil()) {
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
@ -213,17 +218,17 @@ runGroundSolver (void)
GroundSolver* solver = 0;
CountingBp::setFindIdenticalFactorsFlag (false);
switch (Globals::groundSolver) {
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
case GroundSolverType::veSolver: solver = new VarElim (*mfg); break;
case GroundSolverType::bpSolver: solver = new BeliefProp (*mfg); break;
case GroundSolverType::CbpSolver: solver = new CountingBp (*mfg); break;
}
if (Globals::verbosity > 0) {
solver->printSolverFlags();
cout << endl;
std::cout << std::endl;
}
vector<Params> results;
std::vector<Params> results;
results.reserve (tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
@ -234,19 +239,19 @@ runGroundSolver (void)
delete mfg;
}
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
return YAP_Unify (fillSolutionList (results), YAP_ARG3);
}
int
setParfactorsParams (void)
setParfactorsParams()
{
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
ParfactorList* pfList = network->first;
YAP_Term distIdsList = YAP_ARG2;
YAP_Term paramsList = YAP_ARG3;
unordered_map<unsigned, Params> paramsMap;
std::unordered_map<unsigned, Params> paramsMap;
while (distIdsList != YAP_TermNil()) {
unsigned distId = (unsigned) YAP_IntOfTerm (
YAP_HeadOfTerm (distIdsList));
@ -267,12 +272,12 @@ setParfactorsParams (void)
int
setFactorsParams (void)
setFactorsParams()
{
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distIdsList = YAP_ARG2;
YAP_Term paramsList = YAP_ARG3;
unordered_map<unsigned, Params> paramsMap;
std::unordered_map<unsigned, Params> paramsMap;
while (distIdsList != YAP_TermNil()) {
unsigned distId = (unsigned) YAP_IntOfTerm (
YAP_HeadOfTerm (distIdsList));
@ -293,10 +298,10 @@ setFactorsParams (void)
int
setVarsInformation (void)
setVarsInformation()
{
Var::clearVarsInfo();
vector<string> labels;
std::vector<std::string> labels;
YAP_Term labelsL = YAP_ARG1;
while (labelsL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
@ -323,20 +328,20 @@ setVarsInformation (void)
int
setHorusFlag (void)
setHorusFlag()
{
string option ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
string value;
std::string option ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
std::string value;
if (option == "verbosity") {
stringstream ss;
std::stringstream ss;
ss << (int) YAP_IntOfTerm (YAP_ARG2);
ss >> value;
} else if (option == "bp_accuracy") {
stringstream ss;
std::stringstream ss;
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
ss >> value;
} else if (option == "bp_max_iter") {
stringstream ss;
std::stringstream ss;
ss << (int) YAP_IntOfTerm (YAP_ARG2);
ss >> value;
} else {
@ -348,7 +353,7 @@ setHorusFlag (void)
int
freeGroundNetwork (void)
freeGroundNetwork()
{
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
return TRUE;
@ -357,7 +362,7 @@ freeGroundNetwork (void)
int
freeLiftedNetwork (void)
freeLiftedNetwork()
{
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
delete network->first;
@ -368,6 +373,8 @@ freeLiftedNetwork (void)
namespace {
Parfactor*
readParfactor (YAP_Term pfTerm)
{
@ -386,23 +393,24 @@ readParfactor (YAP_Term pfTerm)
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
std::unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
std::string name ((char*) YAP_AtomName (
YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
std::unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
@ -418,7 +426,7 @@ readParfactor (YAP_Term pfTerm)
}
// read the parameters
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
Params params = readParameters (YAP_ArgOfTerm (4, pfTerm));
// read the constraint
Tuples tuples;
@ -434,10 +442,11 @@ readParfactor (YAP_Term pfTerm)
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "Error: the constraint contains free variables." << endl;
std::cerr << "Error: the constraint contains free variables." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
@ -449,55 +458,56 @@ readParfactor (YAP_Term pfTerm)
void
readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
ObservedFormulas*
readLiftedEvidence (YAP_Term observedList)
{
ObservedFormulas* obsFormulas = new ObservedFormulas();
while (observedList != YAP_TermNil()) {
YAP_Term pair = YAP_HeadOfTerm (observedList);
YAP_Term ground = YAP_ArgOfTerm (1, pair);
Symbol functor;
Symbols args;
if (YAP_IsAtomTerm (ground)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
functor = LiftedUtils::getSymbol (name);
} else {
assert (YAP_IsApplTerm (ground));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
std::string name ((char*) (YAP_AtomName (
YAP_NameOfFunctor (yapFunctor))));
functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, ground);
assert (YAP_IsAtomTerm (ti));
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
std::string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
args.push_back (LiftedUtils::getSymbol (arg));
}
}
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
bool found = false;
for (size_t i = 0; i < obsFormulas.size(); i++) {
if (obsFormulas[i].functor() == functor &&
obsFormulas[i].arity() == args.size() &&
obsFormulas[i].evidence() == evidence) {
obsFormulas[i].addTuple (args);
for (size_t i = 0; i < obsFormulas->size(); i++) {
if ((*obsFormulas)[i].functor() == functor &&
(*obsFormulas)[i].arity() == args.size() &&
(*obsFormulas)[i].evidence() == evidence) {
(*obsFormulas)[i].addTuple (args);
found = true;
}
}
if (found == false) {
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
obsFormulas->push_back (ObservedFormula (functor, evidence, args));
}
observedList = YAP_TailOfTerm (observedList);
}
return obsFormulas;
}
vector<unsigned>
std::vector<unsigned>
readUnsignedList (YAP_Term list)
{
vector<unsigned> vec;
std::vector<unsigned> vec;
while (list != YAP_TermNil()) {
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
list = YAP_TailOfTerm (list);
@ -513,7 +523,12 @@ readParameters (YAP_Term paramL)
Params params;
assert (YAP_IsPairTerm (paramL));
while (paramL != YAP_TermNil()) {
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
YAP_Term hd = YAP_HeadOfTerm (paramL);
if (YAP_IsFloatTerm (hd)) {
params.push_back ((double) YAP_FloatOfTerm (hd));
} else {
params.push_back ((double) YAP_IntOfTerm (hd));
}
paramL = YAP_TailOfTerm (paramL);
}
if (Globals::logDomain) {
@ -525,17 +540,17 @@ readParameters (YAP_Term paramL)
YAP_Term
fillAnswersPrologList (vector<Params>& results)
fillSolutionList (const std::vector<Params>& results)
{
YAP_Term list = YAP_TermNil();
for (size_t i = results.size(); i-- > 0; ) {
const Params& beliefs = results[i];
YAP_Term queryBeliefsL = YAP_TermNil();
for (size_t j = beliefs.size(); j-- > 0; ) {
YAP_Int sl1 = YAP_InitSlot (list);
YAP_Int sl = YAP_InitSlot (list);
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
list = YAP_GetFromSlot (sl1);
list = YAP_GetFromSlot (sl);
YAP_RecoverSlots (1);
}
list = YAP_MkPairTerm (queryBeliefsL, list);
@ -543,10 +558,12 @@ fillAnswersPrologList (vector<Params>& results)
return list;
}
}
extern "C" void
init_predicates (void)
init_predicates()
{
YAP_UserCPredicate ("cpp_create_lifted_network",
createLiftedNetwork, 3);
@ -579,3 +596,5 @@ init_predicates (void)
freeGroundNetwork, 1);
}
} // namespace Horus

View File

@ -0,0 +1,32 @@
#include <sstream>
#include <iomanip>
#include "Indexer.h"
namespace Horus {
std::ostream&
operator<< (std::ostream& os, const Indexer& indexer)
{
os << "(" ;
os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ;
os << indexer.indices_;
return os;
}
std::ostream&
operator<< (std::ostream &os, const MapIndexer& indexer)
{
os << "(" ;
os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ;
os << indexer.indices_;
return os;
}
} // namespace Horus

View File

@ -1,262 +1,343 @@
#ifndef HORUS_INDEXER_H
#define HORUS_INDEXER_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
#define YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
#include <vector>
#include <algorithm>
#include <numeric>
#include <sstream>
#include <iomanip>
#include "Util.h"
class Indexer
{
namespace Horus {
class Indexer {
public:
Indexer (const Ranges& ranges, bool calcOffsets = true)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
size_(Util::sizeExpected (ranges))
{
if (calcOffsets) {
calculateOffsets();
}
}
Indexer (const Ranges& ranges, bool calcOffsets = true);
void increment (void)
{
for (size_t i = ranges_.size(); i-- > 0; ) {
indices_[i] ++;
if (indices_[i] != ranges_[i]) {
break;
} else {
indices_[i] = 0;
}
}
index_ ++;
}
void increment();
void incrementDimension (size_t dim)
{
assert (dim < ranges_.size());
assert (ranges_.size() == offsets_.size());
assert (indices_[dim] < ranges_[dim]);
indices_[dim] ++;
index_ += offsets_[dim];
}
void incrementDimension (size_t dim);
void incrementExceptDimension (size_t dim)
{
assert (ranges_.size() == offsets_.size());
for (size_t i = ranges_.size(); i-- > 0; ) {
if (i != dim) {
indices_[i] ++;
index_ += offsets_[i];
if (indices_[i] != ranges_[i]) {
return;
} else {
indices_[i] = 0;
index_ -= offsets_[i] * ranges_[i];
}
}
}
index_ = size_;
}
void incrementExceptDimension (size_t dim);
Indexer& operator++ (void)
{
increment();
return *this;
}
Indexer& operator++();
operator size_t (void) const
{
return index_;
}
operator size_t() const;
unsigned operator[] (size_t dim) const
{
assert (valid());
assert (dim < ranges_.size());
return indices_[dim];
}
unsigned operator[] (size_t dim) const;
bool valid (void) const
{
return index_ < size_;
}
bool valid() const;
void reset (void)
{
std::fill (indices_.begin(), indices_.end(), 0);
index_ = 0;
}
void reset();
void resetDimension (size_t dim)
{
indices_[dim] = 0;
index_ -= offsets_[dim] * ranges_[dim];
}
void resetDimension (size_t dim);
size_t size (void) const
{
return size_ ;
}
size_t size() const;
private:
void calculateOffsets();
friend std::ostream& operator<< (std::ostream&, const Indexer&);
private:
void calculateOffsets (void)
{
size_t prod = 1;
offsets_.resize (ranges_.size());
for (size_t i = ranges_.size(); i-- > 0; ) {
offsets_[i] = prod;
prod *= ranges_[i];
}
}
size_t index_;
Ranges indices_;
const Ranges& ranges_;
size_t size_;
vector<size_t> offsets_;
size_t index_;
Ranges indices_;
const Ranges& ranges_;
size_t size_;
std::vector<size_t> offsets_;
DISALLOW_COPY_AND_ASSIGN (Indexer);
};
inline std::ostream&
operator<< (std::ostream& os, const Indexer& indexer)
inline
Indexer::Indexer (const Ranges& ranges, bool calcOffsets)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
size_(Util::sizeExpected (ranges))
{
os << "(" ;
os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ;
os << indexer.indices_;
return os;
if (calcOffsets) {
calculateOffsets();
}
}
class MapIndexer
inline void
Indexer::increment()
{
for (size_t i = ranges_.size(); i-- > 0; ) {
indices_[i] ++;
if (indices_[i] != ranges_[i]) {
break;
} else {
indices_[i] = 0;
}
}
index_ ++;
}
inline void
Indexer::incrementDimension (size_t dim)
{
assert (dim < ranges_.size());
assert (ranges_.size() == offsets_.size());
assert (indices_[dim] < ranges_[dim]);
indices_[dim] ++;
index_ += offsets_[dim];
}
inline void
Indexer::incrementExceptDimension (size_t dim)
{
assert (ranges_.size() == offsets_.size());
for (size_t i = ranges_.size(); i-- > 0; ) {
if (i != dim) {
indices_[i] ++;
index_ += offsets_[i];
if (indices_[i] != ranges_[i]) {
return;
} else {
indices_[i] = 0;
index_ -= offsets_[i] * ranges_[i];
}
}
}
index_ = size_;
}
inline Indexer&
Indexer::operator++()
{
increment();
return *this;
}
inline
Indexer::operator size_t() const
{
return index_;
}
inline unsigned
Indexer::operator[] (size_t dim) const
{
assert (valid());
assert (dim < ranges_.size());
return indices_[dim];
}
inline bool
Indexer::valid() const
{
return index_ < size_;
}
inline void
Indexer::reset()
{
index_ = 0;
std::fill (indices_.begin(), indices_.end(), 0);
}
inline void
Indexer::resetDimension (size_t dim)
{
indices_[dim] = 0;
index_ -= offsets_[dim] * ranges_[dim];
}
inline size_t
Indexer::size() const
{
return size_ ;
}
inline void
Indexer::calculateOffsets()
{
size_t prod = 1;
offsets_.resize (ranges_.size());
for (size_t i = ranges_.size(); i-- > 0; ) {
offsets_[i] = prod;
prod *= ranges_[i];
}
}
class MapIndexer {
public:
MapIndexer (const Ranges& ranges, const vector<bool>& mask)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true)
{
size_t prod = 1;
offsets_.resize (ranges.size(), 0);
for (size_t i = ranges.size(); i-- > 0; ) {
if (mask[i]) {
offsets_[i] = prod;
prod *= ranges[i];
}
}
assert (ranges.size() == mask.size());
}
MapIndexer (const Ranges& ranges, const std::vector<bool>& mask);
MapIndexer (const Ranges& ranges, size_t dim)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true)
{
size_t prod = 1;
offsets_.resize (ranges.size(), 0);
for (size_t i = ranges.size(); i-- > 0; ) {
if (i != dim) {
offsets_[i] = prod;
prod *= ranges[i];
}
}
}
MapIndexer (const Ranges& ranges, size_t dim);
template <typename T>
MapIndexer (
const vector<T>& allArgs,
const Ranges& allRanges,
const vector<T>& wantedArgs,
const Ranges& wantedRanges)
: index_(0), indices_(allArgs.size(), 0), ranges_(allRanges),
valid_(true)
{
size_t prod = 1;
vector<size_t> offsets (wantedRanges.size());
for (size_t i = wantedRanges.size(); i-- > 0; ) {
offsets[i] = prod;
prod *= wantedRanges[i];
}
offsets_.reserve (allArgs.size());
for (size_t i = 0; i < allArgs.size(); i++) {
size_t idx = Util::indexOf (wantedArgs, allArgs[i]);
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
}
}
template <typename T>
MapIndexer (
const std::vector<T>& allArgs,
const Ranges& allRanges,
const std::vector<T>& wantedArgs,
const Ranges& wantedRanges);
MapIndexer& operator++ (void)
{
assert (valid_);
for (size_t i = ranges_.size(); i-- > 0; ) {
indices_[i] ++;
index_ += offsets_[i];
if (indices_[i] != ranges_[i]) {
return *this;
} else {
indices_[i] = 0;
index_ -= offsets_[i] * ranges_[i];
}
}
valid_ = false;
return *this;
}
MapIndexer& operator++();
operator size_t (void) const
{
assert (valid());
return index_;
}
operator size_t() const;
unsigned operator[] (size_t dim) const
{
assert (valid());
assert (dim < ranges_.size());
return indices_[dim];
}
unsigned operator[] (size_t dim) const;
bool valid (void) const
{
return valid_;
}
bool valid() const;
void reset (void)
{
std::fill (indices_.begin(), indices_.end(), 0);
index_ = 0;
}
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
void reset();
private:
size_t index_;
Ranges indices_;
const Ranges& ranges_;
bool valid_;
vector<size_t> offsets_;
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
size_t index_;
Ranges indices_;
const Ranges& ranges_;
bool valid_;
std::vector<size_t> offsets_;
DISALLOW_COPY_AND_ASSIGN (MapIndexer);
};
inline std::ostream&
operator<< (std::ostream &os, const MapIndexer& indexer)
inline
MapIndexer::MapIndexer (
const Ranges& ranges,
const std::vector<bool>& mask)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true)
{
os << "(" ;
os << std::setw (2) << std::setfill('0') << indexer.index_;
os << ") " ;
os << indexer.indices_;
return os;
size_t prod = 1;
offsets_.resize (ranges.size(), 0);
for (size_t i = ranges.size(); i-- > 0; ) {
if (mask[i]) {
offsets_[i] = prod;
prod *= ranges[i];
}
}
assert (ranges.size() == mask.size());
}
#endif // HORUS_INDEXER_H
inline
MapIndexer::MapIndexer (const Ranges& ranges, size_t dim)
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
valid_(true)
{
size_t prod = 1;
offsets_.resize (ranges.size(), 0);
for (size_t i = ranges.size(); i-- > 0; ) {
if (i != dim) {
offsets_[i] = prod;
prod *= ranges[i];
}
}
}
template <typename T> inline
MapIndexer::MapIndexer (
const std::vector<T>& allArgs,
const Ranges& allRanges,
const std::vector<T>& wantedArgs,
const Ranges& wantedRanges)
: index_(0), indices_(allArgs.size(), 0), ranges_(allRanges),
valid_(true)
{
size_t prod = 1;
std::vector<size_t> offsets (wantedRanges.size());
for (size_t i = wantedRanges.size(); i-- > 0; ) {
offsets[i] = prod;
prod *= wantedRanges[i];
}
offsets_.reserve (allArgs.size());
for (size_t i = 0; i < allArgs.size(); i++) {
size_t idx = Util::indexOf (wantedArgs, allArgs[i]);
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
}
}
inline MapIndexer&
MapIndexer::operator++()
{
assert (valid_);
for (size_t i = ranges_.size(); i-- > 0; ) {
indices_[i] ++;
index_ += offsets_[i];
if (indices_[i] != ranges_[i]) {
return *this;
} else {
indices_[i] = 0;
index_ -= offsets_[i] * ranges_[i];
}
}
valid_ = false;
return *this;
}
inline
MapIndexer::operator size_t() const
{
assert (valid());
return index_;
}
inline unsigned
MapIndexer::operator[] (size_t dim) const
{
assert (valid());
assert (dim < ranges_.size());
return indices_[dim];
}
inline bool
MapIndexer::valid() const
{
return valid_;
}
inline void
MapIndexer::reset()
{
index_ = 0;
std::fill (indices_.begin(), indices_.end(), 0);
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_

View File

@ -1,9 +1,15 @@
#include <cassert>
#include <sstream>
#include "LiftedBp.h"
#include "LiftedOperations.h"
#include "WeightedBp.h"
#include "FactorGraph.h"
namespace Horus {
LiftedBp::LiftedBp (const ParfactorList& parfactorList)
: LiftedSolver (parfactorList)
{
@ -14,7 +20,7 @@ LiftedBp::LiftedBp (const ParfactorList& parfactorList)
LiftedBp::~LiftedBp (void)
LiftedBp::~LiftedBp()
{
delete solver_;
delete fg_;
@ -27,7 +33,7 @@ LiftedBp::solveQuery (const Grounds& query)
{
assert (query.empty() == false);
Params res;
vector<PrvGroup> groups = getQueryGroups (query);
std::vector<PrvGroup> groups = getQueryGroups (query);
if (query.size() == 1) {
res = solver_->getPosterioriOf (groups[0]);
} else {
@ -58,28 +64,29 @@ LiftedBp::solveQuery (const Grounds& query)
void
LiftedBp::printSolverFlags (void) const
LiftedBp::printSolverFlags() const
{
stringstream ss;
std::stringstream ss;
ss << "lifted bp [" ;
ss << "bp_msg_schedule=" ;
typedef WeightedBp::MsgSchedule MsgSchedule;
switch (WeightedBp::msgSchedule()) {
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
case MsgSchedule::PARALLEL: ss << "parallel"; break;
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
case MsgSchedule::parallelSch: ss << "parallel"; break;
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
}
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
ss << ",bp_accuracy=" << WeightedBp::accuracy();
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
std::cout << ss.str() << std::endl;
}
void
LiftedBp::refineParfactors (void)
LiftedBp::refineParfactors()
{
pfList_ = parfactorList;
while (iterate() == false);
@ -93,7 +100,7 @@ LiftedBp::refineParfactors (void)
bool
LiftedBp::iterate (void)
LiftedBp::iterate()
{
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
@ -114,10 +121,10 @@ LiftedBp::iterate (void)
vector<PrvGroup>
std::vector<PrvGroup>
LiftedBp::getQueryGroups (const Grounds& query)
{
vector<PrvGroup> queryGroups;
std::vector<PrvGroup> queryGroups;
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
@ -134,12 +141,12 @@ LiftedBp::getQueryGroups (const Grounds& query)
void
LiftedBp::createFactorGraph (void)
LiftedBp::createFactorGraph()
{
fg_ = new FactorGraph();
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
vector<PrvGroup> groups = (*it)->getAllGroups();
std::vector<PrvGroup> groups = (*it)->getAllGroups();
VarIds varIds;
for (size_t i = 0; i < groups.size(); i++) {
varIds.push_back (groups[i]);
@ -150,10 +157,10 @@ LiftedBp::createFactorGraph (void)
vector<vector<unsigned>>
LiftedBp::getWeights (void) const
std::vector<std::vector<unsigned>>
LiftedBp::getWeights() const
{
vector<vector<unsigned>> weights;
std::vector<std::vector<unsigned>> weights;
weights.reserve (pfList_.size());
ParfactorList::const_iterator it = pfList_.begin();
for (; it != pfList_.end(); ++it) {
@ -196,7 +203,7 @@ LiftedBp::getJointByConditioning (
Grounds obsGrounds = {query[0]};
for (size_t i = 1; i < query.size(); i++) {
Params newBeliefs;
vector<ObservedFormula> obsFs;
std::vector<ObservedFormula> obsFs;
Ranges obsRanges;
for (size_t j = 0; j < obsGrounds.size(); j++) {
obsFs.push_back (ObservedFormula (
@ -231,3 +238,5 @@ LiftedBp::getJointByConditioning (
return prevBeliefs;
}
} // namespace Horus

View File

@ -1,33 +1,38 @@
#ifndef HORUS_LIFTEDBP_H
#define HORUS_LIFTEDBP_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
#include <vector>
#include "LiftedSolver.h"
#include "ParfactorList.h"
#include "Indexer.h"
namespace Horus {
class FactorGraph;
class WeightedBp;
class LiftedBp : public LiftedSolver
{
class LiftedBp : public LiftedSolver{
public:
LiftedBp (const ParfactorList& pfList);
LiftedBp (const ParfactorList& pfList);
~LiftedBp (void);
~LiftedBp();
Params solveQuery (const Grounds&);
Params solveQuery (const Grounds&);
void printSolverFlags (void) const;
void printSolverFlags() const;
private:
void refineParfactors (void);
void refineParfactors();
bool iterate (void);
bool iterate();
vector<PrvGroup> getQueryGroups (const Grounds&);
std::vector<PrvGroup> getQueryGroups (const Grounds&);
void createFactorGraph (void);
void createFactorGraph();
vector<vector<unsigned>> getWeights (void) const;
std::vector<std::vector<unsigned>> getWeights() const;
unsigned rangeOfGround (const Ground&);
@ -38,8 +43,9 @@ class LiftedBp : public LiftedSolver
FactorGraph* fg_;
DISALLOW_COPY_AND_ASSIGN (LiftedBp);
};
#endif // HORUS_LIFTEDBP_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_

View File

@ -1,11 +1,283 @@
#include <cassert>
#include <vector>
#include <unordered_map>
#include <string>
#include <fstream>
#include <iostream>
#include "LiftedKc.h"
#include "LiftedWCNF.h"
#include "LiftedOperations.h"
#include "Indexer.h"
OrNode::~OrNode (void)
namespace Horus {
enum class CircuitNodeType {
orCnt,
andCnt,
setOrCnt,
setAndCnt,
incExcCnt,
leafCnt,
smoothCnt,
trueCnt,
compilationFailedCnt
};
class CircuitNode {
public:
CircuitNode() { }
virtual ~CircuitNode() { }
virtual double weight() const = 0;
};
class OrNode : public CircuitNode {
public:
OrNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
~OrNode();
CircuitNode** leftBranch () { return &leftBranch_; }
CircuitNode** rightBranch() { return &rightBranch_; }
double weight() const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class AndNode : public CircuitNode {
public:
AndNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
: CircuitNode(), leftBranch_(leftBranch),
rightBranch_(rightBranch) { }
~AndNode();
CircuitNode** leftBranch () { return &leftBranch_; }
CircuitNode** rightBranch() { return &rightBranch_; }
double weight() const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class SetOrNode : public CircuitNode {
public:
SetOrNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetOrNode();
CircuitNode** follow() { return &follow_; }
static unsigned nrPositives() { return nrPos_; }
static unsigned nrNegatives() { return nrNeg_; }
static bool isSet() { return nrPos_ >= 0; }
double weight() const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
static int nrPos_;
static int nrNeg_;
};
class SetAndNode : public CircuitNode {
public:
SetAndNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetAndNode();
CircuitNode** follow() { return &follow_; }
double weight() const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
};
class IncExcNode : public CircuitNode {
public:
IncExcNode()
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
~IncExcNode();
CircuitNode** plus1Branch() { return &plus1Branch_; }
CircuitNode** plus2Branch() { return &plus2Branch_; }
CircuitNode** minusBranch() { return &minusBranch_; }
double weight() const;
private:
CircuitNode* plus1Branch_;
CircuitNode* plus2Branch_;
CircuitNode* minusBranch_;
};
class LeafNode : public CircuitNode {
public:
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
~LeafNode();
const Clause* clause() const { return clause_; }
Clause* clause() { return clause_; }
double weight() const;
private:
Clause* clause_;
const LiftedWCNF& lwcnf_;
};
class SmoothNode : public CircuitNode {
public:
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
~SmoothNode();
const Clauses& clauses() const { return clauses_; }
Clauses clauses() { return clauses_; }
double weight() const;
private:
Clauses clauses_;
const LiftedWCNF& lwcnf_;
};
class TrueNode : public CircuitNode {
public:
TrueNode() : CircuitNode() { }
double weight() const;
};
class CompilationFailedNode : public CircuitNode {
public:
CompilationFailedNode() : CircuitNode() { }
double weight() const;
};
class LiftedCircuit {
public:
LiftedCircuit (const LiftedWCNF* lwcnf);
~LiftedCircuit();
bool isCompilationSucceeded() const;
double getWeightedModelCount() const;
void exportToGraphViz (const char*);
private:
void compile (CircuitNode** follow, Clauses& clauses);
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
LogVars& rootLogVars);
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
void shatterCountedLogVars (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses,
size_t idx1, size_t idx2);
bool independentClause (Clause& clause, Clauses& otherClauses) const;
bool independentLiteral (const Literal& lit,
const Literals& otherLits) const;
LitLvTypesSet smoothCircuit (CircuitNode* node);
void createSmoothNode (const LitLvTypesSet& lids,
CircuitNode** prev);
std::vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
bool containsTypes (const LogVarTypes& typesA,
const LogVarTypes& typesB) const;
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
void exportToGraphViz (CircuitNode* node, std::ofstream&);
void printClauses (CircuitNode* node, std::ofstream&,
std::string extraOptions = "");
std::string escapeNode (const CircuitNode* node) const;
std::string getExplanationString (CircuitNode* node);
CircuitNode* root_;
const LiftedWCNF* lwcnf_;
bool compilationSucceeded_;
Clauses backupClauses_;
std::unordered_map<CircuitNode*, Clauses> originClausesMap_;
std::unordered_map<CircuitNode*, std::string> explanationMap_;
DISALLOW_COPY_AND_ASSIGN (LiftedCircuit);
};
OrNode::~OrNode()
{
delete leftBranch_;
delete rightBranch_;
@ -14,7 +286,7 @@ OrNode::~OrNode (void)
double
OrNode::weight (void) const
OrNode::weight() const
{
double lw = leftBranch_->weight();
double rw = rightBranch_->weight();
@ -23,7 +295,7 @@ OrNode::weight (void) const
AndNode::~AndNode (void)
AndNode::~AndNode()
{
delete leftBranch_;
delete rightBranch_;
@ -32,7 +304,7 @@ AndNode::~AndNode (void)
double
AndNode::weight (void) const
AndNode::weight() const
{
double lw = leftBranch_->weight();
double rw = rightBranch_->weight();
@ -46,7 +318,7 @@ int SetOrNode::nrNeg_ = -1;
SetOrNode::~SetOrNode (void)
SetOrNode::~SetOrNode()
{
delete follow_;
}
@ -54,7 +326,7 @@ SetOrNode::~SetOrNode (void)
double
SetOrNode::weight (void) const
SetOrNode::weight() const
{
double weightSum = LogAware::addIdenty();
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
@ -76,7 +348,7 @@ SetOrNode::weight (void) const
SetAndNode::~SetAndNode (void)
SetAndNode::~SetAndNode()
{
delete follow_;
}
@ -84,14 +356,14 @@ SetAndNode::~SetAndNode (void)
double
SetAndNode::weight (void) const
SetAndNode::weight() const
{
return LogAware::pow (follow_->weight(), nrGroundings_);
}
IncExcNode::~IncExcNode (void)
IncExcNode::~IncExcNode()
{
delete plus1Branch_;
delete plus2Branch_;
@ -101,7 +373,7 @@ IncExcNode::~IncExcNode (void)
double
IncExcNode::weight (void) const
IncExcNode::weight() const
{
double w = 0.0;
if (Globals::logDomain) {
@ -116,7 +388,7 @@ IncExcNode::weight (void) const
LeafNode::~LeafNode (void)
LeafNode::~LeafNode()
{
delete clause_;
}
@ -124,7 +396,7 @@ LeafNode::~LeafNode (void)
double
LeafNode::weight (void) const
LeafNode::weight() const
{
assert (clause_->isUnit());
if (clause_->posCountedLogVars().empty() == false
@ -161,7 +433,7 @@ LeafNode::weight (void) const
SmoothNode::~SmoothNode (void)
SmoothNode::~SmoothNode()
{
Clause::deleteClauses (clauses_);
}
@ -169,7 +441,7 @@ SmoothNode::~SmoothNode (void)
double
SmoothNode::weight (void) const
SmoothNode::weight() const
{
Clauses cs = clauses();
double totalWeight = LogAware::multIdenty();
@ -204,7 +476,7 @@ SmoothNode::weight (void) const
double
TrueNode::weight (void) const
TrueNode::weight() const
{
return LogAware::multIdenty();
}
@ -212,7 +484,7 @@ TrueNode::weight (void) const
double
CompilationFailedNode::weight (void) const
CompilationFailedNode::weight() const
{
// weighted model counting in compilation
// failed nodes should give NaN
@ -234,21 +506,22 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
if (Globals::verbosity > 1) {
if (compilationSucceeded_) {
double wmc = LogAware::exp (getWeightedModelCount());
cout << "Weighted model count = " << wmc << endl << endl;
std::cout << "Weighted model count = " << wmc;
std::cout << std::endl << std::endl;
}
cout << "Exporting circuit to graphviz (circuit.dot)..." ;
cout << endl << endl;
std::cout << "Exporting circuit to graphviz (circuit.dot)..." ;
std::cout << std::endl << std::endl;
exportToGraphViz ("circuit.dot");
}
}
LiftedCircuit::~LiftedCircuit (void)
LiftedCircuit::~LiftedCircuit()
{
delete root_;
unordered_map<CircuitNode*, Clauses>::iterator it;
it = originClausesMap_.begin();
std::unordered_map<CircuitNode*, Clauses>::iterator it
= originClausesMap_.begin();
while (it != originClausesMap_.end()) {
Clause::deleteClauses (it->second);
++ it;
@ -258,7 +531,7 @@ LiftedCircuit::~LiftedCircuit (void)
bool
LiftedCircuit::isCompilationSucceeded (void) const
LiftedCircuit::isCompilationSucceeded() const
{
return compilationSucceeded_;
}
@ -266,7 +539,7 @@ LiftedCircuit::isCompilationSucceeded (void) const
double
LiftedCircuit::getWeightedModelCount (void) const
LiftedCircuit::getWeightedModelCount() const
{
assert (compilationSucceeded_);
return root_->weight();
@ -277,15 +550,16 @@ LiftedCircuit::getWeightedModelCount (void) const
void
LiftedCircuit::exportToGraphViz (const char* fileName)
{
ofstream out (fileName);
std::ofstream out (fileName);
if (!out.is_open()) {
cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
std::cerr << std::endl;
return;
}
out << "digraph {" << endl;
out << "ranksep=1" << endl;
out << "digraph {" << std::endl;
out << "ranksep=1" << std::endl;
exportToGraphViz (root_, out);
out << "}" << endl;
out << "}" << std::endl;
out.close();
}
@ -389,7 +663,7 @@ LiftedCircuit::tryUnitPropagation (
AndNode* andNode = new AndNode();
if (Globals::verbosity > 1) {
originClausesMap_[andNode] = backupClauses_;
stringstream explanation;
std::stringstream explanation;
explanation << " UP on " << clauses[i]->literals()[0];
explanationMap_[andNode] = explanation.str();
}
@ -478,7 +752,7 @@ LiftedCircuit::tryShannonDecomp (
OrNode* orNode = new OrNode();
if (Globals::verbosity > 1) {
originClausesMap_[orNode] = backupClauses_;
stringstream explanation;
std::stringstream explanation;
explanation << " SD on " << literals[j];
explanationMap_[orNode] = explanation.str();
}
@ -558,7 +832,7 @@ LiftedCircuit::tryInclusionExclusion (
IncExcNode* ieNode = new IncExcNode();
if (Globals::verbosity > 1) {
originClausesMap_[ieNode] = backupClauses_;
stringstream explanation;
std::stringstream explanation;
explanation << " IncExc on clause nº " << i + 1;
explanationMap_[ieNode] = explanation.str();
}
@ -635,13 +909,13 @@ LiftedCircuit::tryIndepPartialGroundingAux (
}
}
// verifies if the IPG logical vars appear in the same positions
unordered_map<LiteralId, size_t> positions;
std::unordered_map<LiteralId, size_t> positions;
for (size_t i = 0; i < clauses.size(); i++) {
const Literals& literals = clauses[i]->literals();
for (size_t j = 0; j < literals.size(); j++) {
size_t idx = literals[j].indexOfLogVar (rootLogVars[i]);
assert (idx != literals[j].nrLogVars());
unordered_map<LiteralId, size_t>::iterator it;
std::unordered_map<LiteralId, size_t>::iterator it;
it = positions.find (literals[j].lid());
if (it != positions.end()) {
if (it->second != idx) {
@ -810,7 +1084,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
switch (getCircuitNodeType (node)) {
case CircuitNodeType::OR_NODE: {
case CircuitNodeType::orCnt: {
OrNode* casted = dynamic_cast<OrNode*>(node);
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
@ -823,7 +1097,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break;
}
case CircuitNodeType::AND_NODE: {
case CircuitNodeType::andCnt: {
AndNode* casted = dynamic_cast<AndNode*>(node);
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
@ -832,17 +1106,18 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break;
}
case CircuitNodeType::SET_OR_NODE: {
case CircuitNodeType::setOrCnt: {
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
propagLits = smoothCircuit (*casted->follow());
TinySet<pair<LiteralId,unsigned>> litSet;
TinySet<std::pair<LiteralId,unsigned>> litSet;
for (size_t i = 0; i < propagLits.size(); i++) {
litSet.insert (make_pair (propagLits[i].lid(),
litSet.insert (std::make_pair (propagLits[i].lid(),
propagLits[i].logVarTypes().size()));
}
LitLvTypesSet missingLids;
for (size_t i = 0; i < litSet.size(); i++) {
vector<LogVarTypes> allTypes = getAllPossibleTypes (litSet[i].second);
std::vector<LogVarTypes> allTypes
= getAllPossibleTypes (litSet[i].second);
for (size_t j = 0; j < allTypes.size(); j++) {
bool typeFound = false;
for (size_t k = 0; k < propagLits.size(); k++) {
@ -869,13 +1144,13 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break;
}
case CircuitNodeType::SET_AND_NODE: {
case CircuitNodeType::setAndCnt: {
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
propagLits = smoothCircuit (*casted->follow());
break;
}
case CircuitNodeType::INC_EXC_NODE: {
case CircuitNodeType::incExcCnt: {
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
@ -888,7 +1163,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
break;
}
case CircuitNodeType::LEAF_NODE: {
case CircuitNodeType::leafCnt: {
LeafNode* casted = dynamic_cast<LeafNode*>(node);
propagLits.insert (LitLvTypes (
casted->clause()->literals()[0].lid(),
@ -911,8 +1186,8 @@ LiftedCircuit::createSmoothNode (
{
if (missingLits.empty() == false) {
if (Globals::verbosity > 1) {
unordered_map<CircuitNode*, Clauses>::iterator it;
it = originClausesMap_.find (*prev);
std::unordered_map<CircuitNode*, Clauses>::iterator it
= originClausesMap_.find (*prev);
if (it != originClausesMap_.end()) {
backupClauses_ = it->second;
} else {
@ -927,9 +1202,9 @@ LiftedCircuit::createSmoothNode (
Clause* c = lwcnf_->createClause (lid);
for (size_t j = 0; j < types.size(); j++) {
LogVar X = c->literals().front().logVars()[j];
if (types[j] == LogVarType::POS_LV) {
if (types[j] == LogVarType::posLvt) {
c->addPosCountedLogVar (X);
} else if (types[j] == LogVarType::NEG_LV) {
} else if (types[j] == LogVarType::negLvt) {
c->addNegCountedLogVar (X);
}
}
@ -947,15 +1222,15 @@ LiftedCircuit::createSmoothNode (
vector<LogVarTypes>
std::vector<LogVarTypes>
LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
{
vector<LogVarTypes> res;
std::vector<LogVarTypes> res;
if (nrLogVars == 0) {
// do nothing
} else if (nrLogVars == 1) {
res.push_back ({ LogVarType::POS_LV });
res.push_back ({ LogVarType::NEG_LV });
res.push_back ({ LogVarType::posLvt });
res.push_back ({ LogVarType::negLvt });
} else {
Ranges ranges (nrLogVars, 2);
Indexer indexer (ranges);
@ -963,9 +1238,9 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
LogVarTypes types;
for (size_t i = 0; i < nrLogVars; i++) {
if (indexer[i] == 0) {
types.push_back (LogVarType::POS_LV);
types.push_back (LogVarType::posLvt);
} else {
types.push_back (LogVarType::NEG_LV);
types.push_back (LogVarType::negLvt);
}
}
res.push_back (types);
@ -983,13 +1258,13 @@ LiftedCircuit::containsTypes (
const LogVarTypes& typesB) const
{
for (size_t i = 0; i < typesA.size(); i++) {
if (typesA[i] == LogVarType::FULL_LV) {
if (typesA[i] == LogVarType::fullLvt) {
} else if (typesA[i] == LogVarType::POS_LV
&& typesB[i] == LogVarType::POS_LV) {
} else if (typesA[i] == LogVarType::posLvt
&& typesB[i] == LogVarType::posLvt) {
} else if (typesA[i] == LogVarType::NEG_LV
&& typesB[i] == LogVarType::NEG_LV) {
} else if (typesA[i] == LogVarType::negLvt
&& typesB[i] == LogVarType::negLvt) {
} else {
return false;
@ -1003,25 +1278,25 @@ LiftedCircuit::containsTypes (
CircuitNodeType
LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
{
CircuitNodeType type = CircuitNodeType::OR_NODE;
CircuitNodeType type = CircuitNodeType::orCnt;
if (dynamic_cast<const OrNode*>(node)) {
type = CircuitNodeType::OR_NODE;
type = CircuitNodeType::orCnt;
} else if (dynamic_cast<const AndNode*>(node)) {
type = CircuitNodeType::AND_NODE;
type = CircuitNodeType::andCnt;
} else if (dynamic_cast<const SetOrNode*>(node)) {
type = CircuitNodeType::SET_OR_NODE;
type = CircuitNodeType::setOrCnt;
} else if (dynamic_cast<const SetAndNode*>(node)) {
type = CircuitNodeType::SET_AND_NODE;
type = CircuitNodeType::setAndCnt;
} else if (dynamic_cast<const IncExcNode*>(node)) {
type = CircuitNodeType::INC_EXC_NODE;
type = CircuitNodeType::incExcCnt;
} else if (dynamic_cast<const LeafNode*>(node)) {
type = CircuitNodeType::LEAF_NODE;
type = CircuitNodeType::leafCnt;
} else if (dynamic_cast<const SmoothNode*>(node)) {
type = CircuitNodeType::SMOOTH_NODE;
type = CircuitNodeType::smoothCnt;
} else if (dynamic_cast<const TrueNode*>(node)) {
type = CircuitNodeType::TRUE_NODE;
type = CircuitNodeType::trueCnt;
} else if (dynamic_cast<const CompilationFailedNode*>(node)) {
type = CircuitNodeType::COMPILATION_FAILED_NODE;
type = CircuitNodeType::compilationFailedCnt;
} else {
assert (false);
}
@ -1031,127 +1306,131 @@ LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
void
LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
{
assert (node);
static unsigned nrAuxNodes = 0;
stringstream ss;
std::stringstream ss;
ss << "n" << nrAuxNodes;
string auxNode = ss.str();
std::string auxNode = ss.str();
nrAuxNodes ++;
string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ;
std::string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ;
switch (getCircuitNodeType (node)) {
case OR_NODE: {
case CircuitNodeType::orCnt: {
OrNode* casted = dynamic_cast<OrNode*>(node);
printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"\"]" << endl;
os << auxNode << " [" << opStyle << "label=\"\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl;
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->leftBranch());
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
os << endl;
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->rightBranch());
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
os << endl;
os << std::endl;
exportToGraphViz (*casted->leftBranch(), os);
exportToGraphViz (*casted->rightBranch(), os);
break;
}
case AND_NODE: {
case CircuitNodeType::andCnt: {
AndNode* casted = dynamic_cast<AndNode*>(node);
printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"\"]" << endl;
os << auxNode << " [" << opStyle << "label=\"\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl;
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->leftBranch());
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
os << endl;
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->rightBranch()) << endl;
os << escapeNode (*casted->rightBranch());
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
os << endl;
os << std::endl;
exportToGraphViz (*casted->leftBranch(), os);
exportToGraphViz (*casted->rightBranch(), os);
break;
}
case SET_OR_NODE: {
case CircuitNodeType::setOrCnt: {
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"(X)\"]" << endl;
os << auxNode << " [" << opStyle << "label=\"(X)\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl;
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->follow());
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
os << endl;
os << std::endl;
exportToGraphViz (*casted->follow(), os);
break;
}
case SET_AND_NODE: {
case CircuitNodeType::setAndCnt: {
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" << endl;
os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" ;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl;
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->follow());
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
os << endl;
os << std::endl;
exportToGraphViz (*casted->follow(), os);
break;
}
case INC_EXC_NODE: {
case CircuitNodeType::incExcCnt: {
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
printClauses (casted, os);
os << auxNode << " [" << opStyle << "label=\"+ - +\"]" ;
os << endl;
os << std::endl;
os << escapeNode (node) << " -> " << auxNode;
os << " [label=\"" << getExplanationString (node) << "\"]" ;
os << endl;
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->plus1Branch());
os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ;
os << endl;
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->minusBranch()) << endl;
os << escapeNode (*casted->minusBranch()) << std::endl;
os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ;
os << endl;
os << std::endl;
os << auxNode << " -> " ;
os << escapeNode (*casted->plus2Branch());
os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ;
os << endl;
os << std::endl;
exportToGraphViz (*casted->plus1Branch(), os);
exportToGraphViz (*casted->plus2Branch(), os);
@ -1159,24 +1438,24 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
break;
}
case LEAF_NODE: {
case CircuitNodeType::leafCnt: {
printClauses (node, os, "style=filled,fillcolor=palegreen,");
break;
}
case SMOOTH_NODE: {
case CircuitNodeType::smoothCnt: {
printClauses (node, os, "style=filled,fillcolor=lightblue,");
break;
}
case TRUE_NODE: {
case CircuitNodeType::trueCnt: {
os << escapeNode (node);
os << " [shape=box,label=\"\"]" ;
os << endl;
os << std::endl;
break;
}
case COMPILATION_FAILED_NODE: {
case CircuitNodeType::compilationFailedCnt: {
printClauses (node, os, "style=filled,fillcolor=salmon,");
break;
}
@ -1188,17 +1467,17 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
string
std::string
LiftedCircuit::escapeNode (const CircuitNode* node) const
{
stringstream ss;
std::stringstream ss;
ss << "\"" << node << "\"" ;
return ss.str();
}
string
std::string
LiftedCircuit::getExplanationString (CircuitNode* node)
{
return Util::contains (explanationMap_, node)
@ -1211,15 +1490,15 @@ LiftedCircuit::getExplanationString (CircuitNode* node)
void
LiftedCircuit::printClauses (
CircuitNode* node,
ofstream& os,
string extraOptions)
std::ofstream& os,
std::string extraOptions)
{
Clauses clauses;
if (Util::contains (originClausesMap_, node)) {
clauses = originClausesMap_[node];
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) {
} else if (getCircuitNodeType (node) == CircuitNodeType::leafCnt) {
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
} else if (getCircuitNodeType (node) == CircuitNodeType::SMOOTH_NODE) {
} else if (getCircuitNodeType (node) == CircuitNodeType::smoothCnt) {
clauses = (dynamic_cast<SmoothNode*>(node))->clauses();
}
assert (clauses.empty() == false);
@ -1230,15 +1509,7 @@ LiftedCircuit::printClauses (
os << *clauses[i];
}
os << "\"]" ;
os << endl;
}
LiftedKc::~LiftedKc (void)
{
delete lwcnf_;
delete circuit_;
os << std::endl;
}
@ -1246,20 +1517,21 @@ LiftedKc::~LiftedKc (void)
Params
LiftedKc::solveQuery (const Grounds& query)
{
pfList_ = parfactorList;
LiftedOperations::shatterAgainstQuery (pfList_, query);
LiftedOperations::runWeakBayesBall (pfList_, query);
lwcnf_ = new LiftedWCNF (pfList_);
circuit_ = new LiftedCircuit (lwcnf_);
if (circuit_->isCompilationSucceeded() == false) {
cerr << "Error: the circuit compilation has failed." << endl;
ParfactorList pfList (parfactorList);
LiftedOperations::shatterAgainstQuery (pfList, query);
LiftedOperations::runWeakBayesBall (pfList, query);
LiftedWCNF lwcnf (pfList);
LiftedCircuit circuit (&lwcnf);
if (circuit.isCompilationSucceeded() == false) {
std::cerr << "Error: the circuit compilation has failed." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
vector<PrvGroup> groups;
std::vector<PrvGroup> groups;
Ranges ranges;
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
size_t idx = (*it)->indexOfGround (query[i]);
if (idx != (*it)->nrArguments()) {
groups.push_back ((*it)->argument (idx).group());
@ -1274,18 +1546,18 @@ LiftedKc::solveQuery (const Grounds& query)
Indexer indexer (ranges);
while (indexer.valid()) {
for (size_t i = 0; i < groups.size(); i++) {
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
std::vector<LiteralId> litIds = lwcnf.prvGroupLiterals (groups[i]);
for (size_t j = 0; j < litIds.size(); j++) {
if (indexer[i] == j) {
lwcnf_->addWeight (litIds[j], LogAware::one(),
lwcnf.addWeight (litIds[j], LogAware::one(),
LogAware::one());
} else {
lwcnf_->addWeight (litIds[j], LogAware::zero(),
lwcnf.addWeight (litIds[j], LogAware::zero(),
LogAware::one());
}
}
}
params.push_back (circuit_->getWeightedModelCount());
params.push_back (circuit.getWeightedModelCount());
++ indexer;
}
LogAware::normalize (params);
@ -1298,12 +1570,14 @@ LiftedKc::solveQuery (const Grounds& query)
void
LiftedKc::printSolverFlags (void) const
LiftedKc::printSolverFlags() const
{
stringstream ss;
std::stringstream ss;
ss << "lifted kc [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
std::cout << ss.str() << std::endl;
}
} // namespace Horus

View File

@ -1,302 +1,26 @@
#ifndef HORUS_LIFTEDKC_H
#define HORUS_LIFTEDKC_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
#include "LiftedSolver.h"
#include "LiftedWCNF.h"
#include "ParfactorList.h"
enum CircuitNodeType {
OR_NODE,
AND_NODE,
SET_OR_NODE,
SET_AND_NODE,
INC_EXC_NODE,
LEAF_NODE,
SMOOTH_NODE,
TRUE_NODE,
COMPILATION_FAILED_NODE
};
namespace Horus {
class CircuitNode
{
public:
CircuitNode (void) { }
virtual ~CircuitNode (void) { }
virtual double weight (void) const = 0;
};
class OrNode : public CircuitNode
{
public:
OrNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
~OrNode (void);
CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class AndNode : public CircuitNode
{
public:
AndNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
: CircuitNode(), leftBranch_(leftBranch), rightBranch_(rightBranch) { }
~AndNode (void);
CircuitNode** leftBranch (void) { return &leftBranch_; }
CircuitNode** rightBranch (void) { return &rightBranch_; }
double weight (void) const;
private:
CircuitNode* leftBranch_;
CircuitNode* rightBranch_;
};
class SetOrNode : public CircuitNode
{
public:
SetOrNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetOrNode (void);
CircuitNode** follow (void) { return &follow_; }
static unsigned nrPositives (void) { return nrPos_; }
static unsigned nrNegatives (void) { return nrNeg_; }
static bool isSet (void) { return nrPos_ >= 0; }
double weight (void) const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
static int nrPos_;
static int nrNeg_;
};
class SetAndNode : public CircuitNode
{
public:
SetAndNode (unsigned nrGroundings)
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
~SetAndNode (void);
CircuitNode** follow (void) { return &follow_; }
double weight (void) const;
private:
CircuitNode* follow_;
unsigned nrGroundings_;
};
class IncExcNode : public CircuitNode
{
public:
IncExcNode (void)
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
~IncExcNode (void);
CircuitNode** plus1Branch (void) { return &plus1Branch_; }
CircuitNode** plus2Branch (void) { return &plus2Branch_; }
CircuitNode** minusBranch (void) { return &minusBranch_; }
double weight (void) const;
private:
CircuitNode* plus1Branch_;
CircuitNode* plus2Branch_;
CircuitNode* minusBranch_;
};
class LeafNode : public CircuitNode
{
public:
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
~LeafNode (void);
const Clause* clause (void) const { return clause_; }
Clause* clause (void) { return clause_; }
double weight (void) const;
private:
Clause* clause_;
const LiftedWCNF& lwcnf_;
};
class SmoothNode : public CircuitNode
{
public:
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
~SmoothNode (void);
const Clauses& clauses (void) const { return clauses_; }
Clauses clauses (void) { return clauses_; }
double weight (void) const;
private:
Clauses clauses_;
const LiftedWCNF& lwcnf_;
};
class TrueNode : public CircuitNode
{
public:
TrueNode (void) : CircuitNode() { }
double weight (void) const;
};
class CompilationFailedNode : public CircuitNode
{
public:
CompilationFailedNode (void) : CircuitNode() { }
double weight (void) const;
};
class LiftedCircuit
{
public:
LiftedCircuit (const LiftedWCNF* lwcnf);
~LiftedCircuit (void);
bool isCompilationSucceeded (void) const;
double getWeightedModelCount (void) const;
void exportToGraphViz (const char*);
private:
void compile (CircuitNode** follow, Clauses& clauses);
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
LogVars& rootLogVars);
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
void shatterCountedLogVars (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses);
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
bool independentClause (Clause& clause, Clauses& otherClauses) const;
bool independentLiteral (const Literal& lit,
const Literals& otherLits) const;
LitLvTypesSet smoothCircuit (CircuitNode* node);
void createSmoothNode (const LitLvTypesSet& lids,
CircuitNode** prev);
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
bool containsTypes (const LogVarTypes& typesA,
const LogVarTypes& typesB) const;
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
void exportToGraphViz (CircuitNode* node, ofstream&);
void printClauses (CircuitNode* node, ofstream&,
string extraOptions = "");
string escapeNode (const CircuitNode* node) const;
string getExplanationString (CircuitNode* node);
CircuitNode* root_;
const LiftedWCNF* lwcnf_;
bool compilationSucceeded_;
Clauses backupClauses_;
unordered_map<CircuitNode*, Clauses> originClausesMap_;
unordered_map<CircuitNode*, string> explanationMap_;
DISALLOW_COPY_AND_ASSIGN (LiftedCircuit);
};
class LiftedKc : public LiftedSolver
{
class LiftedKc : public LiftedSolver {
public:
LiftedKc (const ParfactorList& pfList)
: LiftedSolver(pfList) { }
~LiftedKc (void);
Params solveQuery (const Grounds&);
void printSolverFlags (void) const;
void printSolverFlags() const;
private:
LiftedWCNF* lwcnf_;
LiftedCircuit* circuit_;
ParfactorList pfList_;
DISALLOW_COPY_AND_ASSIGN (LiftedKc);
};
#endif // HORUS_LIFTEDKC_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_

View File

@ -1,10 +1,22 @@
#include <vector>
#include <queue>
#include <iostream>
#include "LiftedOperations.h"
namespace Horus {
namespace LiftedOperations {
namespace {
Parfactors absorve (ObservedFormula& obsFormula, Parfactor* g);
}
void
LiftedOperations::shatterAgainstQuery (
ParfactorList& pfList,
const Grounds& query)
shatterAgainstQuery (ParfactorList& pfList, const Grounds& query)
{
for (size_t i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
@ -35,17 +47,17 @@ LiftedOperations::shatterAgainstQuery (
}
}
if (found == false) {
cerr << "Error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'." << endl;
std::cerr << "Error: could not find a parfactor with ground " ;
std::cerr << "`" << query[i] << "'." << std::endl;
exit (EXIT_FAILURE);
}
pfList.add (newPfs);
}
if (Globals::verbosity > 2) {
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
for (size_t i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
std::cout << " -> " << query[i] << std::endl;
}
Util::printAsteriskLine();
pfList.print();
@ -55,12 +67,10 @@ LiftedOperations::shatterAgainstQuery (
void
LiftedOperations::runWeakBayesBall (
ParfactorList& pfList,
const Grounds& query)
runWeakBayesBall (ParfactorList& pfList, const Grounds& query)
{
queue<PrvGroup> todo; // groups to process
set<PrvGroup> done; // processed or in queue
std::queue<PrvGroup> todo; // groups to process
std::set<PrvGroup> done; // processed or in queue
for (size_t i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
@ -74,14 +84,14 @@ LiftedOperations::runWeakBayesBall (
}
}
set<Parfactor*> requiredPfs;
std::set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
PrvGroup group = todo.front();
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<PrvGroup> groups = (*it)->getAllGroups();
std::vector<PrvGroup> groups = (*it)->getAllGroups();
for (size_t i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
@ -116,9 +126,7 @@ LiftedOperations::runWeakBayesBall (
void
LiftedOperations::absorveEvidence (
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
absorveEvidence (ParfactorList& pfList, ObservedFormulas& obsFormulas)
{
for (size_t i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
@ -143,9 +151,9 @@ LiftedOperations::absorveEvidence (
}
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
for (size_t i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << obsFormulas[i] << endl;
std::cout << " -> " << obsFormulas[i] << std::endl;
}
Util::printAsteriskLine();
pfList.print();
@ -155,9 +163,7 @@ LiftedOperations::absorveEvidence (
Parfactors
LiftedOperations::countNormalize (
Parfactor* g,
const LogVarSet& set)
countNormalize (Parfactor* g, const LogVarSet& set)
{
Parfactors normPfs;
if (set.empty()) {
@ -174,7 +180,7 @@ LiftedOperations::countNormalize (
Parfactor
LiftedOperations::calcGroundMultiplication (Parfactor pf)
calcGroundMultiplication (Parfactor pf)
{
LogVarSet lvs = pf.constr()->logVarSet();
lvs -= pf.constr()->singletons();
@ -206,10 +212,10 @@ LiftedOperations::calcGroundMultiplication (Parfactor pf)
namespace {
Parfactors
LiftedOperations::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
absorve (ObservedFormula& obsFormula, Parfactor* g)
{
Parfactors absorvedPfs;
const ProbFormulas& formulas = g->arguments();
@ -269,3 +275,9 @@ LiftedOperations::absorve (
return absorvedPfs;
}
}
} // namespace LiftedOperations
} // namespace Horus

View File

@ -1,29 +1,26 @@
#ifndef HORUS_LIFTEDOPERATIONS_H
#define HORUS_LIFTEDOPERATIONS_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
#include "ParfactorList.h"
class LiftedOperations
{
public:
static void shatterAgainstQuery (
ParfactorList& pfList, const Grounds& query);
static void runWeakBayesBall (
ParfactorList& pfList, const Grounds&);
namespace Horus {
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
namespace LiftedOperations {
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
void shatterAgainstQuery (ParfactorList& pfList, const Grounds& query);
static Parfactor calcGroundMultiplication (Parfactor pf);
void runWeakBayesBall (ParfactorList& pfList, const Grounds& query);
private:
static Parfactors absorve (ObservedFormula&, Parfactor*);
void absorveEvidence (ParfactorList& pfList, ObservedFormulas&);
DISALLOW_COPY_AND_ASSIGN (LiftedOperations);
};
Parfactors countNormalize (Parfactor*, const LogVarSet&);
#endif // HORUS_LIFTEDOPERATIONS_H
Parfactor calcGroundMultiplication (Parfactor pf);
} // namespace LiftedOperations
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_

View File

@ -1,14 +1,12 @@
#ifndef HORUS_LIFTEDSOLVER_H
#define HORUS_LIFTEDSOLVER_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
#include "ParfactorList.h"
#include "Horus.h"
using namespace std;
namespace Horus {
class LiftedSolver
{
class LiftedSolver {
public:
LiftedSolver (const ParfactorList& pfList)
: parfactorList(pfList) { }
@ -17,7 +15,7 @@ class LiftedSolver
virtual Params solveQuery (const Grounds& query) = 0;
virtual void printSolverFlags (void) const = 0;
virtual void printSolverFlags() const = 0;
protected:
const ParfactorList& parfactorList;
@ -26,5 +24,7 @@ class LiftedSolver
DISALLOW_COPY_AND_ASSIGN (LiftedSolver);
};
#endif // HORUS_LIFTEDSOLVER_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_

View File

@ -1,22 +1,21 @@
#include <cassert>
#include <iostream>
#include <sstream>
#include "LiftedUtils.h"
#include "ConstraintTree.h"
namespace Horus {
namespace LiftedUtils {
unordered_map<string, unsigned> symbolDict;
std::unordered_map<std::string, unsigned> symbolDict;
Symbol
getSymbol (const string& symbolName)
getSymbol (const std::string& symbolName)
{
unordered_map<string, unsigned>::iterator it
std::unordered_map<std::string, unsigned>::iterator it
= symbolDict.find (symbolName);
if (it != symbolDict.end()) {
return it->second;
@ -29,12 +28,12 @@ getSymbol (const string& symbolName)
void
printSymbolDictionary (void)
printSymbolDictionary()
{
unordered_map<string, unsigned>::const_iterator it
std::unordered_map<std::string, unsigned>::const_iterator it
= symbolDict.begin();
while (it != symbolDict.end()) {
cout << it->first << " -> " << it->second << endl;
std::cout << it->first << " -> " << it->second << std::endl;
++ it;
}
}
@ -43,9 +42,10 @@ printSymbolDictionary (void)
ostream& operator<< (ostream &os, const Symbol& s)
std::ostream&
operator<< (std::ostream& os, const Symbol& s)
{
unordered_map<string, unsigned>::const_iterator it
std::unordered_map<std::string, unsigned>::const_iterator it
= LiftedUtils::symbolDict.begin();
while (it != LiftedUtils::symbolDict.end() && it->second != s) {
++ it;
@ -57,9 +57,10 @@ ostream& operator<< (ostream &os, const Symbol& s)
ostream& operator<< (ostream &os, const LogVar& X)
std::ostream&
operator<< (std::ostream& os, const LogVar& X)
{
const string labels[] = {
const std::string labels[] = {
"A", "B", "C", "D", "E", "F",
"G", "H", "I", "J", "K", "M" };
(X >= 12) ? os << "X_" << X.id_ : os << labels[X];
@ -68,7 +69,8 @@ ostream& operator<< (ostream &os, const LogVar& X)
ostream& operator<< (ostream &os, const Tuple& t)
std::ostream&
operator<< (std::ostream& os, const Tuple& t)
{
os << "(" ;
for (size_t i = 0; i < t.size(); i++) {
@ -80,7 +82,8 @@ ostream& operator<< (ostream &os, const Tuple& t)
ostream& operator<< (ostream &os, const Ground& gr)
std::ostream&
operator<< (std::ostream& os, const Ground& gr)
{
os << gr.functor();
os << "(" ;
@ -95,12 +98,12 @@ ostream& operator<< (ostream &os, const Ground& gr)
LogVars
Substitution::getDiscardedLogVars (void) const
Substitution::getDiscardedLogVars() const
{
LogVars discardedLvs;
set<LogVar> doneLvs;
unordered_map<LogVar, LogVar>::const_iterator it;
it = subs_.begin();
std::set<LogVar> doneLvs;
std::unordered_map<LogVar, LogVar>::const_iterator it
= subs_.begin();
while (it != subs_.end()) {
if (Util::contains (doneLvs, it->second)) {
discardedLvs.push_back (it->first);
@ -114,9 +117,10 @@ Substitution::getDiscardedLogVars (void) const
ostream& operator<< (ostream &os, const Substitution& theta)
std::ostream&
operator<< (std::ostream& os, const Substitution& theta)
{
unordered_map<LogVar, LogVar>::const_iterator it;
std::unordered_map<LogVar, LogVar>::const_iterator it;
os << "[" ;
it = theta.subs_.begin();
while (it != theta.subs_.end()) {
@ -128,3 +132,5 @@ ostream& operator<< (ostream &os, const Substitution& theta)
return os;
}
} // namespace Horus

View File

@ -1,164 +1,210 @@
#ifndef HORUS_LIFTEDUTILS_H
#define HORUS_LIFTEDUTILS_H
#include <string>
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
#include <vector>
#include <unordered_map>
#include <string>
#include <ostream>
#include "TinySet.h"
#include "Util.h"
using namespace std;
namespace Horus {
class Symbol
{
class Symbol {
public:
Symbol (void) : id_(Util::maxUnsigned()) { }
Symbol() : id_(Util::maxUnsigned()) { }
Symbol (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; }
operator unsigned() const { return id_; }
bool valid (void) const { return id_ != Util::maxUnsigned(); }
bool valid() const { return id_ != Util::maxUnsigned(); }
static Symbol invalid (void) { return Symbol(); }
friend ostream& operator<< (ostream &os, const Symbol& s);
static Symbol invalid() { return Symbol(); }
private:
friend std::ostream& operator<< (std::ostream&, const Symbol&);
unsigned id_;
};
class LogVar
{
class LogVar {
public:
LogVar (void) : id_(Util::maxUnsigned()) { }
LogVar() : id_(Util::maxUnsigned()) { }
LogVar (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; }
operator unsigned() const { return id_; }
LogVar& operator++ (void)
{
assert (valid());
id_ ++;
return *this;
}
LogVar& operator++();
bool valid (void) const
{
return id_ != Util::maxUnsigned();
}
friend ostream& operator<< (ostream &os, const LogVar& X);
bool valid() const;
private:
friend std::ostream& operator<< (std::ostream&, const LogVar&);
unsigned id_;
};
namespace std {
template <> struct hash<Symbol> {
size_t operator() (const Symbol& s) const {
return std::hash<unsigned>() (s);
}};
template <> struct hash<LogVar> {
size_t operator() (const LogVar& X) const {
return std::hash<unsigned>() (X);
}};
};
typedef vector<Symbol> Symbols;
typedef vector<Symbol> Tuple;
typedef vector<Tuple> Tuples;
typedef vector<LogVar> LogVars;
typedef TinySet<Symbol> SymbolSet;
typedef TinySet<LogVar> LogVarSet;
typedef TinySet<Tuple> TupleSet;
ostream& operator<< (ostream &os, const Tuple& t);
namespace LiftedUtils {
Symbol getSymbol (const string&);
void printSymbolDictionary (void);
inline LogVar&
LogVar::operator++()
{
assert (valid());
id_ ++;
return *this;
}
class Ground
inline bool
LogVar::valid() const
{
return id_ != Util::maxUnsigned();
}
} // namespace Horus
namespace std {
template <> struct hash<Horus::Symbol> {
size_t operator() (const Horus::Symbol& s) const {
return std::hash<unsigned>() (s);
}};
template <> struct hash<Horus::LogVar> {
size_t operator() (const Horus::LogVar& X) const {
return std::hash<unsigned>() (X);
}};
} // namespace std
namespace Horus {
typedef std::vector<Symbol> Symbols;
typedef std::vector<Symbol> Tuple;
typedef std::vector<Tuple> Tuples;
typedef std::vector<LogVar> LogVars;
typedef TinySet<Symbol> SymbolSet;
typedef TinySet<LogVar> LogVarSet;
typedef TinySet<Tuple> TupleSet;
std::ostream& operator<< (std::ostream&, const Tuple&);
namespace LiftedUtils {
Symbol getSymbol (const std::string&);
void printSymbolDictionary();
}
class Ground {
public:
Ground (Symbol f) : functor_(f) { }
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
Ground (Symbol f, const Symbols& args)
: functor_(f), args_(args) { }
Symbol functor (void) const { return functor_; }
Symbol functor() const { return functor_; }
Symbols args (void) const { return args_; }
Symbols args() const { return args_; }
size_t arity (void) const { return args_.size(); }
size_t arity() const { return args_.size(); }
bool isAtom (void) const { return args_.empty(); }
friend ostream& operator<< (ostream &os, const Ground& gr);
bool isAtom() const { return args_.empty(); }
private:
friend std::ostream& operator<< (std::ostream&, const Ground&);
Symbol functor_;
Symbols args_;
};
typedef vector<Ground> Grounds;
typedef std::vector<Ground> Grounds;
class Substitution
{
class Substitution {
public:
void add (LogVar X_old, LogVar X_new)
{
assert (Util::contains (subs_, X_old) == false);
subs_.insert (make_pair (X_old, X_new));
}
void add (LogVar X_old, LogVar X_new);
void rename (LogVar X_old, LogVar X_new)
{
assert (Util::contains (subs_, X_old));
subs_.find (X_old)->second = X_new;
}
void rename (LogVar X_old, LogVar X_new);
LogVar newNameFor (LogVar X) const
{
unordered_map<LogVar, LogVar>::const_iterator it;
it = subs_.find (X);
if (it != subs_.end()) {
return subs_.find (X)->second;
}
return X;
}
LogVar newNameFor (LogVar X) const;
bool containsReplacementFor (LogVar X) const
{
return Util::contains (subs_, X);
}
bool containsReplacementFor (LogVar X) const;
size_t nrReplacements (void) const { return subs_.size(); }
size_t nrReplacements() const;
LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta);
LogVars getDiscardedLogVars() const;
private:
unordered_map<LogVar, LogVar> subs_;
friend std::ostream& operator<< (
std::ostream&, const Substitution&);
std::unordered_map<LogVar, LogVar> subs_;
};
#endif // HORUS_LIFTEDUTILS_H
inline void
Substitution::add (LogVar X_old, LogVar X_new)
{
assert (Util::contains (subs_, X_old) == false);
subs_.insert (std::make_pair (X_old, X_new));
}
inline void
Substitution::rename (LogVar X_old, LogVar X_new)
{
assert (Util::contains (subs_, X_old));
subs_.find (X_old)->second = X_new;
}
inline LogVar
Substitution::newNameFor (LogVar X) const
{
std::unordered_map<LogVar, LogVar>::const_iterator it;
it = subs_.find (X);
if (it != subs_.end()) {
return subs_.find (X)->second;
}
return X;
}
inline bool
Substitution::containsReplacementFor (LogVar X) const
{
return Util::contains (subs_, X);
}
inline size_t
Substitution::nrReplacements() const
{
return subs_.size();
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_

View File

@ -1,6 +1,12 @@
#include <algorithm>
#include <cassert>
#include <vector>
#include <set>
#include <queue>
#include <algorithm>
#include <string>
#include <iostream>
#include <sstream>
#include "LiftedVe.h"
#include "LiftedOperations.h"
@ -8,21 +14,158 @@
#include "Util.h"
vector<LiftedOperator*>
namespace Horus {
class LiftedOperator {
public:
virtual ~LiftedOperator() { }
virtual double getLogCost() = 0;
virtual void apply() = 0;
virtual std::string toString() = 0;
static std::vector<LiftedOperator*> getValidOps (
ParfactorList&, const Grounds&);
static void printValidOps (ParfactorList&, const Grounds&);
static std::vector<ParfactorList::iterator> getParfactorsWithGroup (
ParfactorList&, PrvGroup group);
private:
DISALLOW_ASSIGN (LiftedOperator);
};
class ProductOperator : public LiftedOperator {
public:
ProductOperator (
ParfactorList::iterator g1,
ParfactorList::iterator g2,
ParfactorList& pfList)
: g1_(g1), g2_(g2), pfList_(pfList) { }
double getLogCost();
void apply();
static std::vector<ProductOperator*> getValidOps (ParfactorList&);
std::string toString();
private:
static bool validOp (Parfactor*, Parfactor*);
ParfactorList::iterator g1_;
ParfactorList::iterator g2_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (ProductOperator);
};
class SumOutOperator : public LiftedOperator {
public:
SumOutOperator (PrvGroup group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { }
double getLogCost();
void apply();
static std::vector<SumOutOperator*> getValidOps (
ParfactorList&, const Grounds&);
std::string toString();
private:
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
PrvGroup group_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (SumOutOperator);
};
class CountingOperator : public LiftedOperator {
public:
CountingOperator (
ParfactorList::iterator pfIter,
LogVar X,
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
double getLogCost();
void apply();
static std::vector<CountingOperator*> getValidOps (ParfactorList&);
std::string toString();
private:
static bool validOp (Parfactor*, LogVar);
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (CountingOperator);
};
class GroundOperator : public LiftedOperator {
public:
GroundOperator (
PrvGroup group,
unsigned lvIndex,
ParfactorList& pfList)
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
double getLogCost();
void apply();
static std::vector<GroundOperator*> getValidOps (ParfactorList&);
std::string toString();
private:
std::vector<std::pair<PrvGroup, unsigned>> getAffectedFormulas();
PrvGroup group_;
unsigned lvIndex_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (GroundOperator);
};
std::vector<LiftedOperator*>
LiftedOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<LiftedOperator*> validOps;
vector<ProductOperator*> multOps;
std::vector<LiftedOperator*> validOps;
std::vector<ProductOperator*> multOps;
multOps = ProductOperator::getValidOps (pfList);
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
if (Globals::verbosity > 1 || multOps.empty()) {
vector<SumOutOperator*> sumOutOps;
vector<CountingOperator*> countOps;
vector<GroundOperator*> groundOps;
std::vector<SumOutOperator*> sumOutOps;
std::vector<CountingOperator*> countOps;
std::vector<GroundOperator*> groundOps;
sumOutOps = SumOutOperator::getValidOps (pfList, query);
countOps = CountingOperator::getValidOps (pfList);
groundOps = GroundOperator::getValidOps (pfList);
@ -41,21 +184,21 @@ LiftedOperator::printValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<LiftedOperator*> validOps;
std::vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList, query);
for (size_t i = 0; i < validOps.size(); i++) {
cout << "-> " << validOps[i]->toString();
std::cout << "-> " << validOps[i]->toString();
delete validOps[i];
}
}
vector<ParfactorList::iterator>
std::vector<ParfactorList::iterator>
LiftedOperator::getParfactorsWithGroup (
ParfactorList& pfList, PrvGroup group)
{
vector<ParfactorList::iterator> iters;
std::vector<ParfactorList::iterator> iters;
ParfactorList::iterator pflIt = pfList.begin();
while (pflIt != pfList.end()) {
if ((*pflIt)->containsGroup (group)) {
@ -69,7 +212,7 @@ LiftedOperator::getParfactorsWithGroup (
double
ProductOperator::getLogCost (void)
ProductOperator::getLogCost()
{
return std::log (0.0);
}
@ -77,7 +220,7 @@ ProductOperator::getLogCost (void)
void
ProductOperator::apply (void)
ProductOperator::apply()
{
Parfactor* g1 = *g1_;
Parfactor* g2 = *g2_;
@ -89,13 +232,13 @@ ProductOperator::apply (void)
vector<ProductOperator*>
std::vector<ProductOperator*>
ProductOperator::getValidOps (ParfactorList& pfList)
{
vector<ProductOperator*> validOps;
std::vector<ProductOperator*> validOps;
ParfactorList::iterator it1 = pfList.begin();
ParfactorList::iterator penultimate = -- pfList.end();
set<Parfactor*> pfs;
std::set<Parfactor*> pfs;
while (it1 != penultimate) {
if (Util::contains (pfs, *it1)) {
++ it1;
@ -128,15 +271,15 @@ ProductOperator::getValidOps (ParfactorList& pfList)
string
ProductOperator::toString (void)
std::string
ProductOperator::toString()
{
stringstream ss;
std::stringstream ss;
ss << "just multiplicate " ;
ss << (*g1_)->getAllGroups();
ss << " x " ;
ss << (*g2_)->getAllGroups();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
return ss.str();
}
@ -168,14 +311,14 @@ ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
double
SumOutOperator::getLogCost (void)
SumOutOperator::getLogCost()
{
TinySet<PrvGroup> groupSet;
ParfactorList::const_iterator pfIter = pfList_.begin();
unsigned nrProdFactors = 0;
while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (group_)) {
vector<PrvGroup> groups = (*pfIter)->getAllGroups();
std::vector<PrvGroup> groups = (*pfIter)->getAllGroups();
groupSet |= TinySet<PrvGroup> (groups);
++ nrProdFactors;
}
@ -203,9 +346,9 @@ SumOutOperator::getLogCost (void)
void
SumOutOperator::apply (void)
SumOutOperator::apply()
{
vector<ParfactorList::iterator> iters;
std::vector<ParfactorList::iterator> iters;
iters = getParfactorsWithGroup (pfList_, group_);
Parfactor* product = *(iters[0]);
pfList_.remove (iters[0]);
@ -234,13 +377,13 @@ SumOutOperator::apply (void)
vector<SumOutOperator*>
std::vector<SumOutOperator*>
SumOutOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<SumOutOperator*> validOps;
set<PrvGroup> allGroups;
std::vector<SumOutOperator*> validOps;
std::set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
@ -249,7 +392,7 @@ SumOutOperator::getValidOps (
}
++ it;
}
set<PrvGroup>::const_iterator groupIt = allGroups.begin();
std::set<PrvGroup>::const_iterator groupIt = allGroups.begin();
while (groupIt != allGroups.end()) {
if (validOp (*groupIt, pfList, query)) {
validOps.push_back (new SumOutOperator (*groupIt, pfList));
@ -261,18 +404,18 @@ SumOutOperator::getValidOps (
string
SumOutOperator::toString (void)
std::string
SumOutOperator::toString()
{
stringstream ss;
vector<ParfactorList::iterator> pfIters;
std::stringstream ss;
std::vector<ParfactorList::iterator> pfIters;
pfIters = getParfactorsWithGroup (pfList_, group_);
size_t idx = (*pfIters[0])->indexOfGroup (group_);
ProbFormula f = (*pfIters[0])->argument (idx);
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
return ss.str();
}
@ -284,7 +427,7 @@ SumOutOperator::validOp (
ParfactorList& pfList,
const Grounds& query)
{
vector<ParfactorList::iterator> pfIters;
std::vector<ParfactorList::iterator> pfIters;
pfIters = getParfactorsWithGroup (pfList, group);
if (isToEliminate (*pfIters[0], group, query) == false) {
return false;
@ -335,7 +478,7 @@ SumOutOperator::isToEliminate (
double
CountingOperator::getLogCost (void)
CountingOperator::getLogCost()
{
double cost = 0.0;
size_t fIdx = (*pfIter_)->indexOfLogVar (X_);
@ -370,7 +513,7 @@ CountingOperator::getLogCost (void)
void
CountingOperator::apply (void)
CountingOperator::apply()
{
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
(*pfIter_)->countConvert (X_);
@ -393,10 +536,10 @@ CountingOperator::apply (void)
vector<CountingOperator*>
std::vector<CountingOperator*>
CountingOperator::getValidOps (ParfactorList& pfList)
{
vector<CountingOperator*> validOps;
std::vector<CountingOperator*> validOps;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
LogVarSet candidates = (*it)->uncountedLogVars();
@ -414,17 +557,17 @@ CountingOperator::getValidOps (ParfactorList& pfList)
string
CountingOperator::toString (void)
std::string
CountingOperator::toString()
{
stringstream ss;
std::stringstream ss;
ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel();
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (size_t i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getLabel() << endl;
ss << " º " << pfs[i]->getLabel() << std::endl;
}
}
for (size_t i = 0; i < pfs.size(); i++) {
@ -455,16 +598,16 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
double
GroundOperator::getLogCost (void)
GroundOperator::getLogCost()
{
vector<pair<PrvGroup, unsigned>> affectedFormulas;
std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas = getAffectedFormulas();
// cout << "affected formulas: " ;
// std::cout << "affected formulas: " ;
// for (size_t i = 0; i < affectedFormulas.size(); i++) {
// cout << affectedFormulas[i].first << ":" ;
// cout << affectedFormulas[i].second << " " ;
// std::cout << affectedFormulas[i].first << ":" ;
// std::cout << affectedFormulas[i].second << " " ;
// }
// cout << "cost =" ;
// std::cout << "cost =" ;
double totalCost = std::log (0.0);
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
@ -495,20 +638,20 @@ GroundOperator::getLogCost (void)
}
}
if (willBeAffected) {
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
// std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
double pfCost = reps + pfSize;
totalCost = Util::logSum (totalCost, pfCost);
}
++ pflIt;
}
// cout << endl;
// std::cout << std::endl;
return totalCost + 3;
}
void
GroundOperator::apply (void)
GroundOperator::apply()
{
ParfactorList::iterator pfIter;
pfIter = getParfactorsWithGroup (pfList_, group_).front();
@ -537,11 +680,11 @@ GroundOperator::apply (void)
vector<GroundOperator*>
std::vector<GroundOperator*>
GroundOperator::getValidOps (ParfactorList& pfList)
{
vector<GroundOperator*> validOps;
set<PrvGroup> allGroups;
std::vector<GroundOperator*> validOps;
std::set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
@ -564,18 +707,18 @@ GroundOperator::getValidOps (ParfactorList& pfList)
string
GroundOperator::toString (void)
std::string
GroundOperator::toString()
{
stringstream ss;
vector<ParfactorList::iterator> pfIters;
std::stringstream ss;
std::vector<ParfactorList::iterator> pfIters;
pfIters = getParfactorsWithGroup (pfList_, group_);
Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
size_t idx = pf->indexOfGroup (group_);
ProbFormula f = pf->argument (idx);
LogVar lv = f.logVars()[lvIndex_];
TupleSet tupleSet = pf->constr()->tupleSet ({lv});
string pos = "th";
std::string pos = "th";
if (lvIndex_ == 0) {
pos = "st" ;
} else if (lvIndex_ == 1) {
@ -586,21 +729,21 @@ GroundOperator::toString (void)
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
ss << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
return ss.str();
}
vector<pair<PrvGroup, unsigned>>
GroundOperator::getAffectedFormulas (void)
std::vector<std::pair<PrvGroup, unsigned>>
GroundOperator::getAffectedFormulas()
{
vector<pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas.push_back (make_pair (group_, lvIndex_));
queue<pair<PrvGroup, unsigned>> q;
q.push (make_pair (group_, lvIndex_));
std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
affectedFormulas.push_back (std::make_pair (group_, lvIndex_));
std::queue<std::pair<PrvGroup, unsigned>> q;
q.push (std::make_pair (group_, lvIndex_));
while (q.empty() == false) {
pair<PrvGroup, unsigned> front = q.front();
std::pair<PrvGroup, unsigned> front = q.front();
ParfactorList::iterator pflIt = pfList_.begin();
while (pflIt != pfList_.end()) {
size_t idx = (*pflIt)->indexOfGroup (front.first);
@ -610,7 +753,7 @@ GroundOperator::getAffectedFormulas (void)
const ProbFormulas& fs = (*pflIt)->arguments();
for (size_t i = 0; i < fs.size(); i++) {
if (i != idx && fs[i].contains (X)) {
pair<PrvGroup, unsigned> pair = make_pair (
std::pair<PrvGroup, unsigned> pair = std::make_pair (
fs[i].group(), fs[i].indexOf (X));
if (Util::contains (affectedFormulas, pair) == false) {
q.push (pair);
@ -645,13 +788,13 @@ LiftedVe::solveQuery (const Grounds& query)
void
LiftedVe::printSolverFlags (void) const
LiftedVe::printSolverFlags() const
{
stringstream ss;
std::stringstream ss;
ss << "lve [" ;
ss << "log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
std::cout << ss.str() << std::endl;
}
@ -675,9 +818,9 @@ LiftedVe::runSolver (const Grounds& query)
break;
}
if (Globals::verbosity > 1) {
cout << "best operation: " << op->toString();
std::cout << "best operation: " << op->toString();
if (Globals::verbosity > 2) {
cout << endl;
std::cout << std::endl;
}
}
op->apply();
@ -693,8 +836,9 @@ LiftedVe::runSolver (const Grounds& query)
}
}
if (Globals::verbosity > 0) {
cout << "largest cost = " << std::exp (largestCost_) << endl;
cout << endl;
std::cout << "largest cost = " << std::exp (largestCost_);
std::cout << std::endl;
std::cout << std::endl;
}
(*pfList_.begin())->simplifyGrounds();
(*pfList_.begin())->reorderAccordingGrounds (query);
@ -707,7 +851,7 @@ LiftedVe::getBestOperation (const Grounds& query)
{
double bestCost = 0.0;
LiftedOperator* bestOp = 0;
vector<LiftedOperator*> validOps;
std::vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList_, query);
for (size_t i = 0; i < validOps.size(); i++) {
double cost = validOps[i]->getLogCost();
@ -727,3 +871,5 @@ LiftedVe::getBestOperation (const Grounds& query)
return bestOp;
}
} // namespace Horus

View File

@ -1,157 +1,23 @@
#ifndef HORUS_LIFTEDVE_H
#define HORUS_LIFTEDVE_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
#include "LiftedSolver.h"
#include "ParfactorList.h"
class LiftedOperator
{
public:
virtual ~LiftedOperator (void) { }
namespace Horus {
virtual double getLogCost (void) = 0;
virtual void apply (void) = 0;
virtual string toString (void) = 0;
static vector<LiftedOperator*> getValidOps (
ParfactorList&, const Grounds&);
static void printValidOps (ParfactorList&, const Grounds&);
static vector<ParfactorList::iterator> getParfactorsWithGroup (
ParfactorList&, PrvGroup group);
private:
DISALLOW_ASSIGN (LiftedOperator);
};
class LiftedOperator;
class ProductOperator : public LiftedOperator
{
public:
ProductOperator (
ParfactorList::iterator g1, ParfactorList::iterator g2,
ParfactorList& pfList) : g1_(g1), g2_(g2), pfList_(pfList) { }
double getLogCost (void);
void apply (void);
static vector<ProductOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
static bool validOp (Parfactor*, Parfactor*);
ParfactorList::iterator g1_;
ParfactorList::iterator g2_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (ProductOperator);
};
class SumOutOperator : public LiftedOperator
{
public:
SumOutOperator (PrvGroup group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { }
double getLogCost (void);
void apply (void);
static vector<SumOutOperator*> getValidOps (
ParfactorList&, const Grounds&);
string toString (void);
private:
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
PrvGroup group_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (SumOutOperator);
};
class CountingOperator : public LiftedOperator
{
public:
CountingOperator (
ParfactorList::iterator pfIter,
LogVar X,
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
double getLogCost (void);
void apply (void);
static vector<CountingOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
static bool validOp (Parfactor*, LogVar);
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (CountingOperator);
};
class GroundOperator : public LiftedOperator
{
public:
GroundOperator (
PrvGroup group,
unsigned lvIndex,
ParfactorList& pfList)
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
double getLogCost (void);
void apply (void);
static vector<GroundOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
PrvGroup group_;
unsigned lvIndex_;
ParfactorList& pfList_;
DISALLOW_COPY_AND_ASSIGN (GroundOperator);
};
class LiftedVe : public LiftedSolver
{
class LiftedVe : public LiftedSolver {
public:
LiftedVe (const ParfactorList& pfList)
: LiftedSolver(pfList) { }
Params solveQuery (const Grounds&);
void printSolverFlags (void) const;
void printSolverFlags() const;
private:
void runSolver (const Grounds&);
@ -164,5 +30,7 @@ class LiftedVe : public LiftedSolver
DISALLOW_COPY_AND_ASSIGN (LiftedVe);
};
#endif // HORUS_LIFTEDVE_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_

View File

@ -1,10 +1,20 @@
#include <cassert>
#include <iostream>
#include <sstream>
#include "LiftedWCNF.h"
#include "ParfactorList.h"
#include "ConstraintTree.h"
#include "Indexer.h"
namespace Horus {
bool
Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const
Literal::isGround (
ConstraintTree constr,
const LogVarSet& ipgLogVars) const
{
if (logVars_.empty()) {
return true;
@ -24,13 +34,13 @@ Literal::indexOfLogVar (LogVar X) const
string
std::string
Literal::toString (
LogVarSet ipgLogVars,
LogVarSet posCountedLvs,
LogVarSet negCountedLvs) const
{
stringstream ss;
std::stringstream ss;
negated_ ? ss << "¬" : ss << "" ;
ss << "λ" ;
ss << lid_ ;
@ -44,7 +54,7 @@ Literal::toString (
ss << "-" << logVars_[i];
} else if (ipgLogVars.contains (logVars_[i])) {
LogVar X = logVars_[i];
const string labels[] = {
const std::string labels[] = {
"a", "b", "c", "d", "e", "f",
"g", "h", "i", "j", "k", "m" };
(X >= 12) ? ss << "x_" << X : ss << labels[X];
@ -60,7 +70,7 @@ Literal::toString (
std::ostream&
operator<< (ostream &os, const Literal& lit)
operator<< (std::ostream& os, const Literal& lit)
{
os << lit.toString();
return os;
@ -216,7 +226,7 @@ Clause::isIpgLogVar (LogVar X) const
TinySet<LiteralId>
Clause::lidSet (void) const
Clause::lidSet() const
{
TinySet<LiteralId> lidSet;
for (size_t i = 0; i < literals_.size(); i++) {
@ -228,7 +238,7 @@ Clause::lidSet (void) const
LogVarSet
Clause::ipgCandidates (void) const
Clause::ipgCandidates() const
{
LogVarSet candidates;
LogVarSet allLvs = constr_.logVarSet();
@ -259,11 +269,11 @@ Clause::logVarTypes (size_t litIdx) const
const LogVars& lvs = literals_[litIdx].logVars();
for (size_t i = 0; i < lvs.size(); i++) {
if (posCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::POS_LV);
types.push_back (LogVarType::posLvt);
} else if (negCountedLvs_.contains (lvs[i])) {
types.push_back (LogVarType::NEG_LV);
types.push_back (LogVarType::negLvt);
} else {
types.push_back (LogVarType::FULL_LV);
types.push_back (LogVarType::fullLvt);
}
}
return types;
@ -320,7 +330,7 @@ void
Clause::printClauses (const Clauses& clauses)
{
for (size_t i = 0; i < clauses.size(); i++) {
cout << *clauses[i] << endl;
std::cout << *clauses[i] << std::endl;
}
}
@ -337,7 +347,7 @@ Clause::deleteClauses (Clauses& clauses)
std::ostream&
operator<< (ostream &os, const Clause& clause)
operator<< (std::ostream& os, const Clause& clause)
{
for (unsigned i = 0; i < clause.literals_.size(); i++) {
if (i != 0) os << " v " ;
@ -369,14 +379,14 @@ Clause::getLogVarSetExcluding (size_t idx) const
std::ostream&
operator<< (std::ostream &os, const LitLvTypes& lit)
operator<< (std::ostream& os, const LitLvTypes& lit)
{
os << lit.lid_ << "<" ;
for (size_t i = 0; i < lit.lvTypes_.size(); i++) {
switch (lit.lvTypes_[i]) {
case LogVarType::FULL_LV: os << "F" ; break;
case LogVarType::POS_LV: os << "P" ; break;
case LogVarType::NEG_LV: os << "N" ; break;
case LogVarType::fullLvt: os << "F" ; break;
case LogVarType::posLvt: os << "P" ; break;
case LogVarType::negLvt: os << "N" ; break;
}
}
os << ">" ;
@ -385,6 +395,14 @@ operator<< (std::ostream &os, const LitLvTypes& lit)
void
LitLvTypes::setAllFullLogVars()
{
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::fullLvt);
}
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
: freeLiteralId_(0), pfList_(pfList)
{
@ -394,7 +412,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
/*
// INCLUSION-EXCLUSION TEST
clauses_.clear();
vector<vector<string>> names = {
std::vector<std::vector<string>> names = {
{"a1","b1"},{"a2","b2"}
};
Clause* c1 = new Clause (names);
@ -406,7 +424,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
/*
// INDEPENDENT PARTIAL GROUND TEST
clauses_.clear();
vector<vector<string>> names = {
std::vector<std::vector<string>> names = {
{"a1","b1"},{"a2","b2"}
};
Clause* c1 = new Clause (names);
@ -422,7 +440,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
/*
// ATOM-COUNTING TEST
clauses_.clear();
vector<vector<string>> names = {
std::vector<std::vector<string>> names = {
{"p1","p1"},{"p1","p2"},{"p1","p3"},
{"p2","p1"},{"p2","p2"},{"p2","p3"},
{"p3","p1"},{"p3","p2"},{"p3","p3"}
@ -438,21 +456,21 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
*/
if (Globals::verbosity > 1) {
cout << "FORMULA INDICATORS:" << endl;
std::cout << "FORMULA INDICATORS:" << std::endl;
printFormulaIndicators();
cout << endl;
cout << "WEIGHTED INDICATORS:" << endl;
std::cout << std::endl;
std::cout << "WEIGHTED INDICATORS:" << std::endl;
printWeights();
cout << endl;
cout << "CLAUSES:" << endl;
std::cout << std::endl;
std::cout << "CLAUSES:" << std::endl;
printClauses();
cout << endl;
std::cout << std::endl;
}
}
LiftedWCNF::~LiftedWCNF (void)
LiftedWCNF::~LiftedWCNF()
{
Clause::deleteClauses (clauses_);
}
@ -462,7 +480,7 @@ LiftedWCNF::~LiftedWCNF (void)
void
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
{
weights_[lid] = make_pair (posW, negW);
weights_[lid] = std::make_pair (posW, negW);
}
@ -470,8 +488,8 @@ LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
double
LiftedWCNF::posWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
= weights_.find (lid);
return it != weights_.end() ? it->second.first : LogAware::one();
}
@ -480,14 +498,14 @@ LiftedWCNF::posWeight (LiteralId lid) const
double
LiftedWCNF::negWeight (LiteralId lid) const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.find (lid);
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
= weights_.find (lid);
return it != weights_.end() ? it->second.second : LogAware::one();
}
vector<LiteralId>
std::vector<LiteralId>
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
{
assert (Util::contains (map_, prvGroup));
@ -536,9 +554,10 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
ConstraintTree tempConstr = (*it)->constr()->projectedCopy(
formulas[i].logVars());
Clause* clause = new Clause (tempConstr);
vector<LiteralId> lids;
std::vector<LiteralId> lids;
for (size_t j = 0; j < formulas[i].range(); j++) {
clause->addLiteral (Literal (freeLiteralId_, formulas[i].logVars()));
clause->addLiteral (Literal (
freeLiteralId_, formulas[i].logVars()));
lids.push_back (freeLiteralId_);
freeLiteralId_ ++;
}
@ -568,7 +587,7 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
Indexer indexer ((*it)->ranges());
vector<PrvGroup> groups = (*it)->getAllGroups();
std::vector<PrvGroup> groups = (*it)->getAllGroups();
while (indexer.valid()) {
LiteralId paramVarLid = freeLiteralId_;
// λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un
@ -606,26 +625,26 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
void
LiftedWCNF::printFormulaIndicators (void) const
LiftedWCNF::printFormulaIndicators() const
{
if (map_.empty()) {
return;
}
set<PrvGroup> allGroups;
std::set<PrvGroup> allGroups;
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
const ProbFormulas& formulas = (*it)->arguments();
for (size_t i = 0; i < formulas.size(); i++) {
if (Util::contains (allGroups, formulas[i].group()) == false) {
allGroups.insert (formulas[i].group());
cout << formulas[i] << " | " ;
std::cout << formulas[i] << " | " ;
ConstraintTree tempCt = (*it)->constr()->projectedCopy (
formulas[i].logVars());
cout << tempCt.tupleSet();
cout << " indicators => " ;
vector<LiteralId> indicators =
std::cout << tempCt.tupleSet();
std::cout << " indicators => " ;
std::vector<LiteralId> indicators =
(map_.find (formulas[i].group()))->second;
cout << indicators << endl;
std::cout << indicators << std::endl;
}
}
++ it;
@ -635,14 +654,14 @@ LiftedWCNF::printFormulaIndicators (void) const
void
LiftedWCNF::printWeights (void) const
LiftedWCNF::printWeights() const
{
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
it = weights_.begin();
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
= weights_.begin();
while (it != weights_.end()) {
cout << "λ" << it->first << " weights: " ;
cout << it->second.first << " " << it->second.second;
cout << endl;
std::cout << "λ" << it->first << " weights: " ;
std::cout << it->second.first << " " << it->second.second;
std::cout << std::endl;
++ it;
}
}
@ -650,8 +669,10 @@ LiftedWCNF::printWeights (void) const
void
LiftedWCNF::printClauses (void) const
LiftedWCNF::printClauses() const
{
Clause::printClauses (clauses_);
}
}

View File

@ -1,90 +1,95 @@
#ifndef HORUS_LIFTEDWCNF_H
#define HORUS_LIFTEDWCNF_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
#include <vector>
#include <unordered_map>
#include <string>
#include <ostream>
#include "ParfactorList.h"
#include "ConstraintTree.h"
#include "ProbFormula.h"
#include "LiftedUtils.h"
using namespace std;
class ConstraintTree;
namespace Horus {
enum LogVarType
{
FULL_LV,
POS_LV,
NEG_LV
class ParfactorList;
enum class LogVarType {
fullLvt,
posLvt,
negLvt
};
typedef long LiteralId;
typedef vector<LogVarType> LogVarTypes;
typedef long LiteralId;
typedef std::vector<LogVarType> LogVarTypes;
class Literal
{
class Literal {
public:
Literal (LiteralId lid, const LogVars& lvs) :
lid_(lid), logVars_(lvs), negated_(false) { }
Literal (LiteralId lid, const LogVars& lvs)
: lid_(lid), logVars_(lvs), negated_(false) { }
Literal (const Literal& lit, bool negated) :
lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
Literal (const Literal& lit, bool negated)
: lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
LiteralId lid (void) const { return lid_; }
LiteralId lid() const { return lid_; }
LogVars logVars (void) const { return logVars_; }
LogVars logVars() const { return logVars_; }
size_t nrLogVars (void) const { return logVars_.size(); }
size_t nrLogVars() const { return logVars_.size(); }
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
LogVarSet logVarSet() const { return LogVarSet (logVars_); }
void complement (void) { negated_ = !negated_; }
void complement() { negated_ = !negated_; }
bool isPositive (void) const { return negated_ == false; }
bool isPositive() const { return negated_ == false; }
bool isNegative (void) const { return negated_; }
bool isNegative() const { return negated_; }
bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const;
bool isGround (ConstraintTree constr, const LogVarSet& ipgLogVars) const;
size_t indexOfLogVar (LogVar X) const;
string toString (LogVarSet ipgLogVars = LogVarSet(),
LogVarSet posCountedLvs = LogVarSet(),
LogVarSet negCountedLvs = LogVarSet()) const;
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
std::string toString (
LogVarSet ipgLogVars = LogVarSet(),
LogVarSet posCountedLvs = LogVarSet(),
LogVarSet negCountedLvs = LogVarSet()) const;
private:
friend std::ostream& operator<< (std::ostream&, const Literal&);
LiteralId lid_;
LogVars logVars_;
bool negated_;
};
typedef vector<Literal> Literals;
typedef std::vector<Literal> Literals;
class Clause
{
class Clause {
public:
Clause (const ConstraintTree& ct = ConstraintTree({})) : constr_(ct) { }
Clause (vector<vector<string>> names) : constr_(ConstraintTree (names)) { }
Clause (std::vector<std::vector<std::string>> names) :
constr_(ConstraintTree (names)) { }
void addLiteral (const Literal& l) { literals_.push_back (l); }
const Literals& literals (void) const { return literals_; }
const Literals& literals() const { return literals_; }
Literals& literals (void) { return literals_; }
Literals& literals() { return literals_; }
size_t nrLiterals (void) const { return literals_.size(); }
size_t nrLiterals() const { return literals_.size(); }
const ConstraintTree& constr (void) const { return constr_; }
const ConstraintTree& constr() const { return constr_; }
ConstraintTree constr (void) { return constr_; }
ConstraintTree constr() { return constr_; }
bool isUnit (void) const { return literals_.size() == 1; }
bool isUnit() const { return literals_.size() == 1; }
LogVarSet ipgLogVars (void) const { return ipgLvs_; }
LogVarSet ipgLogVars() const { return ipgLvs_; }
void addIpgLogVar (LogVar X) { ipgLvs_.insert (X); }
@ -92,13 +97,13 @@ class Clause
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; }
LogVarSet posCountedLogVars() const { return posCountedLvs_; }
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; }
LogVarSet negCountedLogVars() const { return negCountedLvs_; }
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); }
unsigned nrPosCountedLogVars() const { return posCountedLvs_.size(); }
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); }
unsigned nrNegCountedLogVars() const { return negCountedLvs_.size(); }
void addLiteralComplemented (const Literal& lit);
@ -122,9 +127,9 @@ class Clause
bool isIpgLogVar (LogVar X) const;
TinySet<LiteralId> lidSet (void) const;
TinySet<LiteralId> lidSet() const;
LogVarSet ipgCandidates (void) const;
LogVarSet ipgCandidates() const;
LogVarTypes logVarTypes (size_t litIdx) const;
@ -132,78 +137,78 @@ class Clause
static bool independentClauses (Clause& c1, Clause& c2);
static vector<Clause*> copyClauses (const vector<Clause*>& clauses);
static std::vector<Clause*> copyClauses (
const std::vector<Clause*>& clauses);
static void printClauses (const vector<Clause*>& clauses);
static void printClauses (const std::vector<Clause*>& clauses);
static void deleteClauses (vector<Clause*>& clauses);
friend std::ostream& operator<< (ostream &os, const Clause& clause);
static void deleteClauses (std::vector<Clause*>& clauses);
private:
LogVarSet getLogVarSetExcluding (size_t idx) const;
Literals literals_;
LogVarSet ipgLvs_;
LogVarSet posCountedLvs_;
LogVarSet negCountedLvs_;
ConstraintTree constr_;
friend std::ostream& operator<< (std::ostream&, const Clause&);
Literals literals_;
LogVarSet ipgLvs_;
LogVarSet posCountedLvs_;
LogVarSet negCountedLvs_;
ConstraintTree constr_;
DISALLOW_ASSIGN (Clause);
};
typedef vector<Clause*> Clauses;
typedef std::vector<Clause*> Clauses;
class LitLvTypes
{
class LitLvTypes {
public:
struct CompareLitLvTypes
{
bool operator() (
const LitLvTypes& types1,
const LitLvTypes& types2) const
{
if (types1.lid_ < types2.lid_) {
return true;
}
if (types1.lid_ == types2.lid_) {
return types1.lvTypes_ < types2.lvTypes_;
}
return false;
}
};
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
lid_(lid), lvTypes_(lvTypes) { }
LiteralId lid (void) const { return lid_; }
LiteralId lid() const { return lid_; }
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
const LogVarTypes& logVarTypes() const { return lvTypes_; }
void setAllFullLogVars (void) {
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
friend std::ostream& operator<< (std::ostream &os, const LitLvTypes& lit);
void setAllFullLogVars();
private:
friend std::ostream& operator<< (std::ostream&, const LitLvTypes&);
LiteralId lid_;
LogVarTypes lvTypes_;
};
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
class LiftedWCNF
struct CmpLitLvTypes
{
bool operator() (
const LitLvTypes& types1,
const LitLvTypes& types2) const
{
if (types1.lid() < types2.lid()) {
return true;
}
// vsc if (types1.lid() == types2.lid()){
// return types1.logVarTypes() < types2.logVarTypes();
//}
return false;
}
};
typedef TinySet<LitLvTypes, CmpLitLvTypes> LitLvTypesSet;
class LiftedWCNF {
public:
LiftedWCNF (const ParfactorList& pfList);
~LiftedWCNF (void);
~LiftedWCNF();
const Clauses& clauses (void) const { return clauses_; }
const Clauses& clauses() const { return clauses_; }
void addWeight (LiteralId lid, double posW, double negW);
@ -211,15 +216,15 @@ class LiftedWCNF
double negWeight (LiteralId lid) const;
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
std::vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
Clause* createClause (LiteralId lid) const;
void printFormulaIndicators (void) const;
void printFormulaIndicators() const;
void printWeights (void) const;
void printWeights() const;
void printClauses (void) const;
void printClauses() const;
private:
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
@ -228,14 +233,16 @@ class LiftedWCNF
void addParameterClauses (const ParfactorList& pfList);
Clauses clauses_;
LiteralId freeLiteralId_;
const ParfactorList& pfList_;
unordered_map<PrvGroup, vector<LiteralId>> map_;
unordered_map<LiteralId, std::pair<double,double>> weights_;
Clauses clauses_;
LiteralId freeLiteralId_;
const ParfactorList& pfList_;
std::unordered_map<PrvGroup, std::vector<LiteralId>> map_;
std::unordered_map<LiteralId, std::pair<double,double>> weights_;
DISALLOW_COPY_AND_ASSIGN (LiftedWCNF);
};
#endif // HORUS_LIFTEDWCNF_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_

View File

@ -43,9 +43,9 @@ SO=@SO@
#4.1VPATH=@srcdir@:@srcdir@/OPTYap
CWD=$(PWD)
HCLI = $(srcdir)/hcli
utestsdir=@srcdir@/unit_tests
HEADERS = \
MAIN_HEADERS = \
$(srcdir)/BayesBall.h \
$(srcdir)/BayesBallGraph.h \
$(srcdir)/BeliefProp.h \
@ -54,6 +54,8 @@ HEADERS = \
$(srcdir)/ElimGraph.h \
$(srcdir)/Factor.h \
$(srcdir)/FactorGraph.h \
$(srcdir)/GenericFactor.h \
$(srcdir)/GroundSolver.h \
$(srcdir)/Histogram.h \
$(srcdir)/Horus.h \
$(srcdir)/Indexer.h \
@ -67,14 +69,20 @@ HEADERS = \
$(srcdir)/Parfactor.h \
$(srcdir)/ParfactorList.h \
$(srcdir)/ProbFormula.h \
$(srcdir)/GroundSolver.h \
$(srcdir)/TinySet.h \
$(srcdir)/Util.h \
$(srcdir)/Var.h \
$(srcdir)/VarElim.h \
$(srcdir)/WeightedBp.h
CPP_SOURCES = \
UTESTS_HEADERS = \
$(utestsdir)/Common.h
HEADERS = \
$(MAIN_HEADERS) \
$(UTESTS_HEADERS)
MAIN_SOURCES = \
$(srcdir)/BayesBall.cpp \
$(srcdir)/BayesBallGraph.cpp \
$(srcdir)/BeliefProp.cpp \
@ -83,9 +91,12 @@ CPP_SOURCES = \
$(srcdir)/ElimGraph.cpp \
$(srcdir)/Factor.cpp \
$(srcdir)/FactorGraph.cpp \
$(srcdir)/GenericFactor.cpp \
$(srcdir)/GroundSolver.cpp \
$(srcdir)/Histogram.cpp \
$(srcdir)/HorusCli.cpp \
$(srcdir)/HorusYap.cpp \
$(srcdir)/Indexer.cpp \
$(srcdir)/LiftedBp.cpp \
$(srcdir)/LiftedKc.cpp \
$(srcdir)/LiftedOperations.cpp \
@ -95,12 +106,23 @@ CPP_SOURCES = \
$(srcdir)/Parfactor.cpp \
$(srcdir)/ParfactorList.cpp \
$(srcdir)/ProbFormula.cpp \
$(srcdir)/GroundSolver.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/Var.cpp \
$(srcdir)/VarElim.cpp \
$(srcdir)/WeightedBp.cpp
UTESTS_SOURCES = \
$(utestsdir)/BeliefPropTest.cpp \
$(utestsdir)/Common.cpp \
$(utestsdir)/CountingBpTest.cpp \
$(utestsdir)/FactorTest.cpp \
$(utestsdir)/VarElimTest.cpp \
$(utestsdir)/UnitTesting.cpp
SOURCES = \
$(MAIN_SOURCES) \
$(UTESTS_SOURCES)
OBJS = \
BayesBall.o \
BayesBallGraph.o \
@ -110,8 +132,10 @@ OBJS = \
ElimGraph.o \
Factor.o \
FactorGraph.o \
GenericFactor.o \
GroundSolver.o \
Histogram.o \
HorusYap.o \
Indexer.o \
LiftedBp.o \
LiftedKc.o \
LiftedOperations.o \
@ -121,12 +145,15 @@ OBJS = \
ProbFormula.o \
Parfactor.o \
ParfactorList.o \
GroundSolver.o \
Util.o \
Var.o \
VarElim.o \
WeightedBp.o
LIB_OBJS = \
$(OBJS) \
HorusYap.o
HCLI_OBJS = \
BayesBall.o \
BayesBallGraph.o \
@ -135,51 +162,82 @@ HCLI_OBJS = \
ElimGraph.o \
Factor.o \
FactorGraph.o \
HorusCli.o \
GenericFactor.o \
GroundSolver.o \
HorusCli.o \
Indexer.o \
Util.o \
Var.o \
VarElim.o \
WeightedBp.o
SOBJS=horus.@SO@
UTESTS_OBJS = \
$(OBJS) \
$(utestsdir)/BeliefPropTest.o \
$(utestsdir)/Common.o \
$(utestsdir)/CountingBpTest.o \
$(utestsdir)/FactorTest.o \
$(utestsdir)/VarElimTest.o \
$(utestsdir)/UnitTesting.o
all: $(SOBJS) hcli
LIB = $(srcdir)/horus.@SO@
HCLI = $(srcdir)/hcli
UTESTING = $(srcdir)/run_tests
all: $(LIB) $(HCLI)
# Don't require $(UTESTING) by default as we
# don't want a hard dependency on CppUnit
with_tests: $(LIB) $(HCLI) $(UTESTING)
@DO_SECOND_LD@$(LIB): $(LIB_OBJS)
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o $@ $(LIB_OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
$(HCLI): $(HCLI_OBJS)
$(CXX) -o $@ $(HCLI_OBJS)
$(UTESTING): $(UTESTS_OBJS)
$(CXX) -o $@ $(UTESTS_OBJS) -lcppunit
# default rule
%.o : $(srcdir)/%.cpp
$(CXX) -c $(CXXFLAGS) $< -o $@
@DO_SECOND_LD@horus.@SO@: $(OBJS)
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
hcli: $(HCLI_OBJS)
$(CXX) -o $(HCLI) $(HCLI_OBJS)
$(CXX) -o $@ -c $(CXXFLAGS) $<
install: all
$(INSTALL_PROGRAM) $(SOBJS) $(DESTDIR)$(YAPLIBDIR)
$(INSTALL_PROGRAM) $(LIB) $(DESTDIR)$(YAPLIBDIR)
$(INSTALL_PROGRAM) $(HCLI) $(DESTDIR)$(BINDIR)
clean:
rm -f *.o *~ $(OBJS) $(SOBJS) $(HCLI) *.BAK
rm -f $(LIB) $(HCLI) $(UTESTING) *.o *~ $(utestsdir)/*.o $(utestsdir)/*~
erase_dots:
rm -f *.dot *.png
remove_dots:
rm -f *.dot *.png *.svg
depend: $(HEADERS) $(CPP_SOURCES)
depend: $(SOURCES) $(HEADERS)
-@if test "$(GCC)" = yes; then\
$(CC) -std=c++0x -MM -MG $(CFLAGS) -I$(srcdir) -I$(srcdir)/../../../../include -I$(srcdir)/../../../../H $(CPP_SOURCES) >> Makefile;\
for F in $(SOURCES); do \
D=`dirname $$F`; \
B=`basename $$F .cpp`; \
$(CXX) $(CXXFLAGS) -MM -MG -MT "$$D/$$B.o" -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include $$F >> Makefile; \
done; \
else\
makedepend -f - -- $(CFLAGS) -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include -- $(CPP_SOURCES) |\
sed 's|.*/\([^:]*\):|\1:|' >> Makefile ;\
makedepend -- $(CXXFLAGS) -- -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include $(SOURCES); \
fi
.PHONY: default all install clean remove_dots depend
# DO NOT DELETE THIS LINE -- make depend depends on it.

View File

@ -1,3 +1,8 @@
#include <cassert>
#include <iostream>
#include <sstream>
#include "Parfactor.h"
#include "Histogram.h"
#include "Indexer.h"
@ -5,6 +10,8 @@
#include "Horus.h"
namespace Horus {
Parfactor::Parfactor (
const ProbFormulas& formulas,
const Params& params,
@ -84,7 +91,7 @@ Parfactor::Parfactor (const Parfactor& g)
Parfactor::~Parfactor (void)
Parfactor::~Parfactor()
{
delete constr_;
}
@ -92,7 +99,7 @@ Parfactor::~Parfactor (void)
LogVarSet
Parfactor::countedLogVars (void) const
Parfactor::countedLogVars() const
{
LogVarSet set;
for (size_t i = 0; i < args_.size(); i++) {
@ -106,7 +113,7 @@ Parfactor::countedLogVars (void) const
LogVarSet
Parfactor::uncountedLogVars (void) const
Parfactor::uncountedLogVars() const
{
return constr_->logVarSet() - countedLogVars();
}
@ -114,7 +121,7 @@ Parfactor::uncountedLogVars (void) const
LogVarSet
Parfactor::elimLogVars (void) const
Parfactor::elimLogVars() const
{
LogVarSet requiredToElim = constr_->logVarSet();
requiredToElim -= constr_->singletons();
@ -149,7 +156,7 @@ Parfactor::sumOutIndex (size_t fIdx)
unsigned N = constr_->getConditionalCount (
args_[fIdx].countedLogVar());
unsigned R = args_[fIdx].range();
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
std::vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
Indexer indexer (ranges_, fIdx);
while (indexer.valid()) {
if (Globals::logDomain) {
@ -171,7 +178,7 @@ Parfactor::sumOutIndex (size_t fIdx)
}
constr_->remove (excl);
TFactor<ProbFormula>::sumOutIndex (fIdx);
GenericFactor<ProbFormula>::sumOutIndex (fIdx);
LogAware::pow (params_, exp);
}
@ -181,7 +188,7 @@ void
Parfactor::multiply (Parfactor& g)
{
alignAndExponentiate (this, &g);
TFactor<ProbFormula>::multiply (g);
GenericFactor<ProbFormula>::multiply (g);
constr_->join (g.constr(), true);
simplifyGrounds();
assert (constr_->isCartesianProduct (countedLogVars()));
@ -224,10 +231,10 @@ Parfactor::countConvert (LogVar X)
unsigned N = constr_->getConditionalCount (X);
unsigned R = ranges_[fIdx];
unsigned H = HistogramSet::nrHistograms (N, R);
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
std::vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
Indexer indexer (ranges_);
vector<Params> sumout (params_.size() / R);
std::vector<Params> sumout (params_.size() / R);
unsigned count = 0;
while (indexer.valid()) {
sumout[count].reserve (R);
@ -279,11 +286,11 @@ Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
unsigned H1 = HistogramSet::nrHistograms (N1, R);
unsigned H2 = HistogramSet::nrHistograms (N2, R);
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
std::vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
std::vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
std::vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
vector<unsigned> sumIndexes;
std::vector<unsigned> sumIndexes;
sumIndexes.reserve (H1 * H2);
for (unsigned i = 0; i < H1; i++) {
for (unsigned j = 0; j < H2; j++) {
@ -319,16 +326,16 @@ Parfactor::fullExpand (LogVar X)
unsigned N = constr_->getConditionalCount (X);
unsigned R = args_[fIdx].range();
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
std::vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
std::vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
assert (ranges_[fIdx] == originHists.size());
vector<unsigned> sumIndexes;
std::vector<unsigned> sumIndexes;
sumIndexes.reserve (N * R);
Ranges expandRanges (N, R);
Indexer indexer (expandRanges);
while (indexer.valid()) {
vector<unsigned> hist (R, 0);
std::vector<unsigned> hist (R, 0);
for (unsigned n = 0; n < N; n++) {
hist += expandHists[indexer[n]];
}
@ -384,14 +391,14 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
assert (args_[fIdx].isCounting() == false);
assert (constr_->isCountNormalized (excl));
LogAware::pow (params_, constr_->getConditionalCount (excl));
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
GenericFactor<ProbFormula>::absorveEvidence (formula, evidence);
constr_->remove (excl);
}
void
Parfactor::setNewGroups (void)
Parfactor::setNewGroups()
{
for (size_t i = 0; i < args_.size(); i++) {
args_[i].setGroup (ProbFormula::getNewGroup());
@ -494,7 +501,7 @@ Parfactor::containsGroup (PrvGroup group) const
bool
Parfactor::containsGroups (vector<PrvGroup> groups) const
Parfactor::containsGroups (std::vector<PrvGroup> groups) const
{
for (size_t i = 0; i < groups.size(); i++) {
if (containsGroup (groups[i]) == false) {
@ -565,10 +572,10 @@ Parfactor::nrFormulasWithGroup (PrvGroup group) const
vector<PrvGroup>
Parfactor::getAllGroups (void) const
std::vector<PrvGroup>
Parfactor::getAllGroups() const
{
vector<PrvGroup> groups (args_.size());
std::vector<PrvGroup> groups (args_.size());
for (size_t i = 0; i < args_.size(); i++) {
groups[i] = args_[i].group();
}
@ -577,10 +584,10 @@ Parfactor::getAllGroups (void) const
string
Parfactor::getLabel (void) const
std::string
Parfactor::getLabel() const
{
stringstream ss;
std::stringstream ss;
ss << "phi(" ;
for (size_t i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
@ -598,6 +605,8 @@ Parfactor::getLabel (void) const
void
Parfactor::print (bool printParams) const
{
using std::cout;
using std::endl;
cout << "Formulas: " ;
for (size_t i = 0; i < args_.size(); i++) {
if (i != 0) cout << ", " ;
@ -605,9 +614,10 @@ Parfactor::print (bool printParams) const
}
cout << endl;
if (args_[0].group() != Util::maxUnsigned()) {
vector<string> groups;
std::vector<std::string> groups;
for (size_t i = 0; i < args_.size(); i++) {
groups.push_back (string ("g") + Util::toString (args_[i].group()));
groups.push_back (std::string ("g")
+ Util::toString (args_[i].group()));
}
cout << "Groups: " << groups << endl;
}
@ -633,12 +643,12 @@ Parfactor::print (bool printParams) const
void
Parfactor::printParameters (void) const
Parfactor::printParameters() const
{
vector<string> jointStrings;
std::vector<std::string> jointStrings;
Indexer indexer (ranges_);
while (indexer.valid()) {
stringstream ss;
std::stringstream ss;
for (size_t i = 0; i < args_.size(); i++) {
if (i != 0) ss << ", " ;
if (args_[i].isCounting()) {
@ -659,22 +669,22 @@ Parfactor::printParameters (void) const
++ indexer;
}
for (size_t i = 0; i < params_.size(); i++) {
cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl;
std::cout << "f(" << jointStrings[i] << ")" ;
std::cout << " = " << params_[i] << std::endl;
}
}
void
Parfactor::printProjections (void) const
Parfactor::printProjections() const
{
ConstraintTree copy (*constr_);
LogVarSet Xs = copy.logVarSet();
for (size_t i = 0; i < Xs.size(); i++) {
cout << "-> projection of " << Xs[i] << ": " ;
cout << copy.tupleSet ({Xs[i]}) << endl;
std::cout << "-> projection of " << Xs[i] << ": " ;
std::cout << copy.tupleSet ({Xs[i]}) << std::endl;
}
}
@ -684,12 +694,12 @@ void
Parfactor::expandPotential (
size_t fIdx,
unsigned newRange,
const vector<unsigned>& sumIndexes)
const std::vector<unsigned>& sumIndexes)
{
ullong newSize = (params_.size() / ranges_[fIdx]) * newRange;
if (newSize > params_.max_size()) {
cerr << "Error: an overflow occurred when performing expansion." ;
cerr << endl;
std::cerr << "Error: an overflow occurred when performing expansion." ;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
@ -698,7 +708,7 @@ Parfactor::expandPotential (
params_.reserve (newSize);
size_t prod = 1;
vector<size_t> offsets (ranges_.size());
std::vector<size_t> offsets (ranges_.size());
for (size_t i = ranges_.size(); i-- > 0; ) {
offsets[i] = prod;
prod *= ranges_[i];
@ -706,7 +716,7 @@ Parfactor::expandPotential (
size_t index = 0;
ranges_[fIdx] = newRange;
vector<unsigned> indices (ranges_.size(), 0);
std::vector<unsigned> indices (ranges_.size(), 0);
for (size_t k = 0; k < newSize; k++) {
assert (index < backup.size());
params_.push_back (backup[index]);
@ -759,7 +769,7 @@ Parfactor::simplifyCountingFormulas (size_t fIdx)
void
Parfactor::simplifyGrounds (void)
Parfactor::simplifyGrounds()
{
if (args_.size() == 1) {
return;
@ -872,12 +882,12 @@ Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
const LogVars& alignLvs1 = res.first;
const LogVars& alignLvs2 = res.second;
// cout << "ALIGNING :::::::::::::::::" << endl;
// std::cout << "ALIGNING :::::::::::::::::" << std::endl;
// g1->print();
// cout << "AND" << endl;
// g2->print();
// cout << "-> align lvs1 = " << alignLvs1 << endl;
// cout << "-> align lvs2 = " << alignLvs2 << endl;
// std::cout << "-> align lvs1 = " << alignLvs1 << std::endl;
// std::cout << "-> align lvs2 = " << alignLvs2 << std::endl;
LogVar freeLogVar (0);
Substitution theta1, theta2;
for (size_t i = 0; i < alignLvs1.size(); i++) {
@ -933,9 +943,11 @@ Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
}
}
// cout << "theta1: " << theta1 << endl;
// cout << "theta2: " << theta2 << endl;
// std::cout << "theta1: " << theta1 << std::endl;
// std::cout << "theta2: " << theta2 << std::endl;
g1->applySubstitution (theta1);
g2->applySubstitution (theta2);
}
} // namespace Horus

View File

@ -1,20 +1,21 @@
#ifndef HORUS_PARFACTOR_H
#define HORUS_PARFACTOR_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
#define YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
#include "Factor.h"
#include <vector>
#include <string>
#include "GenericFactor.h"
#include "ProbFormula.h"
#include "ConstraintTree.h"
#include "LiftedUtils.h"
#include "Horus.h"
class Parfactor : public TFactor<ProbFormula>
{
namespace Horus {
class Parfactor : public GenericFactor<ProbFormula> {
public:
Parfactor (
const ProbFormulas&,
const Params&,
const Tuples&,
Parfactor (const ProbFormulas&, const Params&, const Tuples&,
unsigned distId);
Parfactor (const Parfactor*, const Tuple&);
@ -23,21 +24,21 @@ class Parfactor : public TFactor<ProbFormula>
Parfactor (const Parfactor&);
~Parfactor (void);
~Parfactor();
ConstraintTree* constr (void) { return constr_; }
ConstraintTree* constr() { return constr_; }
const ConstraintTree* constr (void) const { return constr_; }
const ConstraintTree* constr() const { return constr_; }
const LogVars& logVars (void) const { return constr_->logVars(); }
const LogVars& logVars() const { return constr_->logVars(); }
const LogVarSet& logVarSet (void) const { return constr_->logVarSet(); }
const LogVarSet& logVarSet() const { return constr_->logVarSet(); }
LogVarSet countedLogVars (void) const;
LogVarSet countedLogVars() const;
LogVarSet uncountedLogVars (void) const;
LogVarSet uncountedLogVars() const;
LogVarSet elimLogVars (void) const;
LogVarSet elimLogVars() const;
LogVarSet exclusiveLogVars (size_t fIdx) const;
@ -57,7 +58,7 @@ class Parfactor : public TFactor<ProbFormula>
void absorveEvidence (const ProbFormula&, unsigned);
void setNewGroups (void);
void setNewGroups();
void applySubstitution (const Substitution&);
@ -71,7 +72,7 @@ class Parfactor : public TFactor<ProbFormula>
bool containsGroup (PrvGroup) const;
bool containsGroups (vector<PrvGroup>) const;
bool containsGroups (std::vector<PrvGroup>) const;
unsigned nrFormulas (LogVar) const;
@ -81,17 +82,17 @@ class Parfactor : public TFactor<ProbFormula>
unsigned nrFormulasWithGroup (PrvGroup) const;
vector<PrvGroup> getAllGroups (void) const;
std::vector<PrvGroup> getAllGroups() const;
void print (bool = false) const;
void printParameters (void) const;
void printParameters() const;
void printProjections (void) const;
void printProjections() const;
string getLabel (void) const;
std::string getLabel() const;
void simplifyGrounds (void);
void simplifyGrounds();
static bool canMultiply (Parfactor*, Parfactor*);
@ -104,18 +105,20 @@ class Parfactor : public TFactor<ProbFormula>
Parfactor* g1, Parfactor* g2);
void expandPotential (size_t fIdx, unsigned newRange,
const vector<unsigned>& sumIndexes);
const std::vector<unsigned>& sumIndexes);
static void alignAndExponentiate (Parfactor*, Parfactor*);
static void alignLogicalVars (Parfactor*, Parfactor*);
ConstraintTree* constr_;
ConstraintTree* constr_;
DISALLOW_ASSIGN (Parfactor);
};
typedef vector<Parfactor*> Parfactors;
typedef std::vector<Parfactor*> Parfactors;
#endif // HORUS_PARFACTOR_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_

View File

@ -1,10 +1,14 @@
#include <cassert>
#include <queue>
#include <iostream>
#include <sstream>
#include "ParfactorList.h"
namespace Horus {
ParfactorList::ParfactorList (const ParfactorList& pfList)
{
ParfactorList::const_iterator it = pfList.begin();
@ -23,7 +27,7 @@ ParfactorList::ParfactorList (const Parfactors& pfs)
ParfactorList::~ParfactorList (void)
ParfactorList::~ParfactorList()
{
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
@ -64,27 +68,27 @@ ParfactorList::addShattered (Parfactor* pf)
list<Parfactor*>::iterator
std::list<Parfactor*>::iterator
ParfactorList::insertShattered (
list<Parfactor*>::iterator it,
std::list<Parfactor*>::iterator it,
Parfactor* pf)
{
return pfList_.insert (it, pf);
assert (isAllShattered());
return pfList_.insert (it, pf);
}
list<Parfactor*>::iterator
ParfactorList::remove (list<Parfactor*>::iterator it)
std::list<Parfactor*>::iterator
ParfactorList::remove (std::list<Parfactor*>::iterator it)
{
return pfList_.erase (it);
}
list<Parfactor*>::iterator
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
std::list<Parfactor*>::iterator
ParfactorList::removeAndDelete (std::list<Parfactor*>::iterator it)
{
delete *it;
return pfList_.erase (it);
@ -93,12 +97,12 @@ ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
bool
ParfactorList::isAllShattered (void) const
ParfactorList::isAllShattered() const
{
if (pfList_.size() <= 1) {
return true;
}
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
Parfactors pfs (pfList_.begin(), pfList_.end());
for (size_t i = 0; i < pfs.size(); i++) {
assert (isShattered (pfs[i]));
}
@ -115,13 +119,25 @@ ParfactorList::isAllShattered (void) const
void
ParfactorList::print (void) const
ParfactorList::print() const
{
struct sortByParams {
bool operator() (const Parfactor* pf1, const Parfactor* pf2)
{
if (pf1->params().size() < pf2->params().size()) {
return true;
} else if (pf1->params().size() == pf2->params().size() &&
pf1->params() < pf2->params()) {
return true;
}
return false;
}
};
Parfactors pfVec (pfList_.begin(), pfList_.end());
std::sort (pfVec.begin(), pfVec.end(), sortByParams());
// vsc std::sort (pfVec.begin(), pfVec.end(), sortByParams());
for (size_t i = 0; i < pfVec.size(); i++) {
pfVec[i]->print();
cout << endl;
std::cout << std::endl;
}
}
@ -163,8 +179,8 @@ ParfactorList::isShattered (const Parfactor* g) const
formulas[i], *(g->constr()),
formulas[j], *(g->constr())) == false) {
g->print();
cout << "-> not identical on positions " ;
cout << i << " and " << j << endl;
std::cout << "-> not identical on positions " ;
std::cout << i << " and " << j << std::endl;
return false;
}
} else {
@ -172,8 +188,8 @@ ParfactorList::isShattered (const Parfactor* g) const
formulas[i], *(g->constr()),
formulas[j], *(g->constr())) == false) {
g->print();
cout << "-> not disjoint on positions " ;
cout << i << " and " << j << endl;
std::cout << "-> not disjoint on positions " ;
std::cout << i << " and " << j << std::endl;
return false;
}
}
@ -200,9 +216,10 @@ ParfactorList::isShattered (
fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) {
g1->print();
cout << "^" << endl;
std::cout << "^" << std::endl;
g2->print();
cout << "-> not identical on group " << fms1[i].group() << endl;
std::cout << "-> not identical on group " ;
std::cout << fms1[i].group() << std::endl;
return false;
}
} else {
@ -210,10 +227,10 @@ ParfactorList::isShattered (
fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) {
g1->print();
cout << "^" << endl;
std::cout << "^" << std::endl;
g2->print();
cout << "-> not disjoint on groups " << fms1[i].group();
cout << " and " << fms2[j].group() << endl;
std::cout << "-> not disjoint on groups " << fms1[i].group();
std::cout << " and " << fms2[j].group() << std::endl;
return false;
}
}
@ -227,12 +244,12 @@ ParfactorList::isShattered (
void
ParfactorList::addToShatteredList (Parfactor* g)
{
queue<Parfactor*> residuals;
std::queue<Parfactor*> residuals;
residuals.push (g);
while (residuals.empty() == false) {
Parfactor* pf = residuals.front();
bool pfSplitted = false;
list<Parfactor*>::iterator pfIter;
std::list<Parfactor*>::iterator pfIter;
pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
std::pair<Parfactors, Parfactors> shattRes;
@ -269,7 +286,7 @@ Parfactors
ParfactorList::shatterAgainstMySelf (Parfactor* g)
{
Parfactors pfs;
queue<Parfactor*> residuals;
std::queue<Parfactor*> residuals;
residuals.push (g);
bool shattered = true;
while (residuals.empty() == false) {
@ -325,19 +342,22 @@ ParfactorList::shatterAgainstMySelf (
{
/*
Util::printDashedLine();
cout << "-> SHATTERING" << endl;
std::cout << "-> SHATTERING" << std::endl;
g->print();
cout << "-> ON: " << g->argument (fIdx1) << "|" ;
cout << g->constr()->tupleSet (g->argument (fIdx1).logVars()) << endl;
cout << "-> ON: " << g->argument (fIdx2) << "|" ;
cout << g->constr()->tupleSet (g->argument (fIdx2).logVars()) << endl;
std::cout << "-> ON: " << g->argument (fIdx1) << "|" ;
std::cout << g->constr()->tupleSet (g->argument (fIdx1).logVars());
std::cout << std::endl;
std::cout << "-> ON: " << g->argument (fIdx2) << "|" ;
std::cout << g->constr()->tupleSet (g->argument (fIdx2).logVars())
std::cout << std::endl;
Util::printDashedLine();
*/
ProbFormula& f1 = g->argument (fIdx1);
ProbFormula& f2 = g->argument (fIdx2);
if (f1.isAtom()) {
cerr << "Error: a ground occurs twice in the same parfactor." << endl;
cerr << endl;
std::cerr << "Error: a ground occurs twice in the same parfactor." ;
std::cerr << std::endl;
std::cerr << std::endl;
exit (EXIT_FAILURE);
}
assert (g->constr()->empty() == false);
@ -441,14 +461,14 @@ ParfactorList::shatter (
ProbFormula& f2 = g2->argument (fIdx2);
/*
Util::printDashedLine();
cout << "-> SHATTERING" << endl;
std::cout << "-> SHATTERING" << std::endl;
g1->print();
cout << "-> WITH" << endl;
std::cout << "-> WITH" << std::endl;
g2->print();
cout << "-> ON: " << f1 << "|" ;
cout << g1->constr()->tupleSet (f1.logVars()) << endl;
cout << "-> ON: " << f2 << "|" ;
cout << g2->constr()->tupleSet (f2.logVars()) << endl;
std::cout << "-> ON: " << f1 << "|" ;
std::cout << g1->constr()->tupleSet (f1.logVars()) << std::endl;
std::cout << "-> ON: " << f2 << "|" ;
std::cout << g2->constr()->tupleSet (f2.logVars()) << std::endl;
Util::printDashedLine();
*/
if (f1.isAtom()) {
@ -486,12 +506,12 @@ ParfactorList::shatter (
assert (commCt1->tupleSet (f1.logVars()) ==
commCt2->tupleSet (f2.logVars()));
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
// std::stringstream ss1; ss1 << "" << count << "_A.dot" ;
// std::stringstream ss2; ss2 << "" << count << "_B.dot" ;
// std::stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
// std::stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
// std::stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
// std::stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
@ -638,3 +658,5 @@ ParfactorList::disjoint (
return (ts1 & ts2).empty();
}
} // namespace Horus

View File

@ -1,5 +1,5 @@
#ifndef HORUS_PARFACTORLIST_H
#define HORUS_PARFACTORLIST_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
#define YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
#include <list>
@ -7,39 +7,38 @@
#include "ProbFormula.h"
using namespace std;
namespace Horus {
class Parfactor;
class ParfactorList
{
class ParfactorList {
public:
ParfactorList (void) { }
ParfactorList() { }
ParfactorList (const ParfactorList&);
ParfactorList (const Parfactors&);
~ParfactorList (void);
~ParfactorList();
const list<Parfactor*>& parfactors (void) const { return pfList_; }
const std::list<Parfactor*>& parfactors() const { return pfList_; }
void clear (void) { pfList_.clear(); }
void clear() { pfList_.clear(); }
size_t size (void) const { return pfList_.size(); }
size_t size() const { return pfList_.size(); }
typedef std::list<Parfactor*>::iterator iterator;
iterator begin (void) { return pfList_.begin(); }
iterator begin() { return pfList_.begin(); }
iterator end (void) { return pfList_.end(); }
iterator end() { return pfList_.end(); }
typedef std::list<Parfactor*>::const_iterator const_iterator;
const_iterator begin (void) const { return pfList_.begin(); }
const_iterator begin() const { return pfList_.begin(); }
const_iterator end (void) const { return pfList_.end(); }
const_iterator end() const { return pfList_.end(); }
void add (Parfactor* pf);
@ -47,16 +46,18 @@ class ParfactorList
void addShattered (Parfactor* pf);
list<Parfactor*>::iterator insertShattered (
list<Parfactor*>::iterator, Parfactor*);
std::list<Parfactor*>::iterator insertShattered (
std::list<Parfactor*>::iterator, Parfactor*);
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
std::list<Parfactor*>::iterator remove (
std::list<Parfactor*>::iterator);
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
std::list<Parfactor*>::iterator removeAndDelete (
std::list<Parfactor*>::iterator);
bool isAllShattered (void) const;
bool isAllShattered() const;
void print (void) const;
void print() const;
ParfactorList& operator= (const ParfactorList& pfList);
@ -101,22 +102,10 @@ class ParfactorList
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
struct sortByParams
{
inline bool operator() (const Parfactor* pf1, const Parfactor* pf2)
{
if (pf1->params().size() < pf2->params().size()) {
return true;
} else if (pf1->params().size() == pf2->params().size() &&
pf1->params() < pf2->params()) {
return true;
}
return false;
}
};
list<Parfactor*> pfList_;
std::list<Parfactor*> pfList_;
};
#endif // HORUS_PARFACTORLIST_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_

View File

@ -1,6 +1,13 @@
#include <cassert>
#include <iostream>
#include "ProbFormula.h"
namespace Horus {
PrvGroup ProbFormula::freeGroup_ = 0;
@ -38,7 +45,7 @@ ProbFormula::indexOf (LogVar X) const
bool
ProbFormula::isAtom (void) const
ProbFormula::isAtom() const
{
return logVars_.empty();
}
@ -46,7 +53,7 @@ ProbFormula::isAtom (void) const
bool
ProbFormula::isCounting (void) const
ProbFormula::isCounting() const
{
return countedLogVar_.valid();
}
@ -54,7 +61,7 @@ ProbFormula::isCounting (void) const
LogVar
ProbFormula::countedLogVar (void) const
ProbFormula::countedLogVar() const
{
assert (isCounting());
return countedLogVar_;
@ -71,7 +78,7 @@ ProbFormula::setCountedLogVar (LogVar lv)
void
ProbFormula::clearCountedLogVar (void)
ProbFormula::clearCountedLogVar()
{
countedLogVar_ = LogVar();
}
@ -93,15 +100,8 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
{
return f1.group_ == f2.group_ &&
f1.logVars_ == f2.logVars_;
}
std::ostream& operator<< (ostream &os, const ProbFormula& f)
std::ostream&
operator<< (std::ostream& os, const ProbFormula& f)
{
os << f.functor_;
if (f.isAtom() == false) {
@ -122,7 +122,7 @@ std::ostream& operator<< (ostream &os, const ProbFormula& f)
PrvGroup
ProbFormula::getNewGroup (void)
ProbFormula::getNewGroup()
{
freeGroup_ ++;
assert (freeGroup_ != std::numeric_limits<PrvGroup>::max());
@ -131,7 +131,24 @@ ProbFormula::getNewGroup (void)
ostream& operator<< (ostream &os, const ObservedFormula& of)
ObservedFormula::ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(a)
{
}
ObservedFormula::ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
{
constr_.addTuple (tuple);
}
std::ostream&
operator<< (std::ostream& os, const ObservedFormula& of)
{
os << of.functor_ << "/" << of.arity_;
os << "|" << of.constr_.tupleSet();
@ -139,3 +156,5 @@ ostream& operator<< (ostream &os, const ObservedFormula& of)
return os;
}
} // namespace Horus

View File

@ -1,16 +1,20 @@
#ifndef HORUS_PROBFORMULA_H
#define HORUS_PROBFORMULA_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
#define YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
#include <vector>
#include <ostream>
#include <limits>
#include "ConstraintTree.h"
#include "LiftedUtils.h"
#include "Horus.h"
namespace Horus {
typedef unsigned long PrvGroup;
class ProbFormula
{
class ProbFormula {
public:
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
: functor_(f), logVars_(lvs), range_(range),
@ -20,19 +24,19 @@ class ProbFormula
: functor_(f), range_(r),
group_(std::numeric_limits<PrvGroup>::max()) { }
Symbol functor (void) const { return functor_; }
Symbol functor() const { return functor_; }
unsigned arity (void) const { return logVars_.size(); }
unsigned arity() const { return logVars_.size(); }
unsigned range (void) const { return range_; }
unsigned range() const { return range_; }
LogVars& logVars (void) { return logVars_; }
LogVars& logVars() { return logVars_; }
const LogVars& logVars (void) const { return logVars_; }
const LogVars& logVars() const { return logVars_; }
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
LogVarSet logVarSet() const { return LogVarSet (logVars_); }
PrvGroup group (void) const { return group_; }
PrvGroup group() const { return group_; }
void setGroup (PrvGroup g) { group_ = g; }
@ -44,25 +48,28 @@ class ProbFormula
size_t indexOf (LogVar) const;
bool isAtom (void) const;
bool isAtom() const;
bool isCounting (void) const;
bool isCounting() const;
LogVar countedLogVar (void) const;
LogVar countedLogVar() const;
void setCountedLogVar (LogVar);
void clearCountedLogVar (void);
void clearCountedLogVar();
void rename (LogVar, LogVar);
static PrvGroup getNewGroup (void);
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
static PrvGroup getNewGroup();
private:
friend bool operator== (
const ProbFormula& f1, const ProbFormula& f2);
friend std::ostream& operator<< (
std::ostream&, const ProbFormula&);
Symbol functor_;
LogVars logVars_;
unsigned range_;
@ -71,45 +78,50 @@ class ProbFormula
static PrvGroup freeGroup_;
};
typedef vector<ProbFormula> ProbFormulas;
typedef std::vector<ProbFormula> ProbFormulas;
class ObservedFormula
inline bool
operator== (const ProbFormula& f1, const ProbFormula& f2)
{
return f1.group_ == f2.group_ && f1.logVars_ == f2.logVars_;
}
class ObservedFormula {
public:
ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
ObservedFormula (Symbol f, unsigned a, unsigned ev);
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
{
constr_.addTuple (tuple);
}
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple);
Symbol functor (void) const { return functor_; }
Symbol functor() const { return functor_; }
unsigned arity (void) const { return arity_; }
unsigned arity() const { return arity_; }
unsigned evidence (void) const { return evidence_; }
unsigned evidence() const { return evidence_; }
void setEvidence (unsigned ev) { evidence_ = ev; }
ConstraintTree& constr (void) { return constr_; }
ConstraintTree& constr() { return constr_; }
bool isAtom (void) const { return arity_ == 0; }
bool isAtom() const { return arity_ == 0; }
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
private:
friend std::ostream& operator<< (
std::ostream&, const ObservedFormula&);
Symbol functor_;
unsigned arity_;
unsigned evidence_;
ConstraintTree constr_;
};
typedef vector<ObservedFormula> ObservedFormulas;
typedef std::vector<ObservedFormula> ObservedFormulas;
#endif // HORUS_PROBFORMULA_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_

View File

@ -1,20 +1,18 @@
#ifndef HORUS_TINYSET_H
#define HORUS_TINYSET_H
#include <algorithm>
#ifndef YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
#define YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
#include <vector>
#include <algorithm>
#include <ostream>
using namespace std;
namespace Horus {
template <typename T, typename Compare = std::less<T>>
class TinySet
{
class TinySet {
public:
typedef typename vector<T>::iterator iterator;
typedef typename vector<T>::const_iterator const_iterator;
typedef typename std::vector<T>::iterator iterator;
typedef typename std::vector<T>::const_iterator const_iterator;
TinySet (const TinySet& s)
: vec_(s.vec_), cmp_(s.cmp_) { }
@ -25,190 +23,72 @@ class TinySet
TinySet (const T& t, const Compare& cmp = Compare())
: vec_(1, t), cmp_(cmp) { }
TinySet (const vector<T>& elements, const Compare& cmp = Compare())
: vec_(elements), cmp_(cmp)
{
std::sort (begin(), end(), cmp_);
iterator it = unique_cmp (begin(), end());
vec_.resize (it - begin());
}
TinySet (const std::vector<T>& elements, const Compare& cmp = Compare());
iterator insert (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it == end() || cmp_(t, *it)) {
vec_.insert (it, t);
}
return it;
}
iterator insert (const T& t);
void insert_sorted (const T& t)
{
vec_.push_back (t);
assert (consistent());
}
void insert_sorted (const T& t);
void remove (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it != end()) {
vec_.erase (it);
}
}
void remove (const T& t);
const_iterator find (const T& t) const
{
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
iterator find (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
const_iterator find (const T& t) const;
iterator find (const T& t);
/* set union */
TinySet operator| (const TinySet& s) const
{
TinySet res;
std::set_union (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
TinySet operator| (const TinySet& s) const;
/* set intersection */
TinySet operator& (const TinySet& s) const
{
TinySet res;
std::set_intersection (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
TinySet operator& (const TinySet& s) const;
/* set difference */
TinySet operator- (const TinySet& s) const
{
TinySet res;
std::set_difference (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
TinySet operator- (const TinySet& s) const;
TinySet& operator|= (const TinySet& s)
{
return *this = (*this | s);
}
TinySet& operator|= (const TinySet& s);
TinySet& operator&= (const TinySet& s)
{
return *this = (*this & s);
}
TinySet& operator&= (const TinySet& s);
TinySet& operator-= (const TinySet& s)
{
return *this = (*this - s);
}
TinySet& operator-= (const TinySet& s);
bool contains (const T& t) const
{
return std::binary_search (
vec_.begin(), vec_.end(), t, cmp_);
}
bool contains (const T& t) const;
bool contains (const TinySet& s) const
{
return std::includes (
vec_.begin(),
vec_.end(),
s.vec_.begin(),
s.vec_.end(),
cmp_);
}
bool contains (const TinySet& s) const;
bool in (const TinySet& s) const
{
return std::includes (
s.vec_.begin(),
s.vec_.end(),
vec_.begin(),
vec_.end(),
cmp_);
}
bool in (const TinySet& s) const;
bool intersects (const TinySet& s) const
{
return (*this & s).size() > 0;
}
bool intersects (const TinySet& s) const;
const T& operator[] (typename vector<T>::size_type i) const
{
return vec_[i];
}
const T& operator[] (typename std::vector<T>::size_type i) const;
T& operator[] (typename vector<T>::size_type i)
{
return vec_[i];
}
T& operator[] (typename std::vector<T>::size_type i);
T front (void) const
{
return vec_.front();
}
T front() const;
T& front (void)
{
return vec_.front();
}
T& front();
T back (void) const
{
return vec_.back();
}
T back() const;
T& back (void)
{
return vec_.back();
}
T& back();
const vector<T>& elements (void) const
{
return vec_;
}
const std::vector<T>& elements() const;
bool empty (void) const
{
return vec_.empty();
}
bool empty() const;
typename vector<T>::size_type size (void) const
{
return vec_.size();
}
typename std::vector<T>::size_type size() const;
void clear (void)
{
vec_.clear();
}
void clear();
void reserve (typename vector<T>::size_type size)
{
vec_.reserve (size);
}
void reserve (typename std::vector<T>::size_type size);
iterator begin (void) { return vec_.begin(); }
iterator end (void) { return vec_.end(); }
const_iterator begin (void) const { return vec_.begin(); }
const_iterator end (void) const { return vec_.end(); }
iterator begin() { return vec_.begin(); }
iterator end () { return vec_.end(); }
const_iterator begin() const { return vec_.begin(); }
const_iterator end () const { return vec_.end(); }
private:
iterator unique_cmp (iterator first, iterator last);
bool consistent() const;
friend bool operator== (const TinySet& s1, const TinySet& s2)
{
@ -223,7 +103,7 @@ class TinySet
friend std::ostream& operator<< (std::ostream& out, const TinySet& s)
{
out << "{" ;
typename vector<T>::size_type i;
typename std::vector<T>::size_type i;
for (i = 0; i < s.size(); i++) {
out << ((i != 0) ? "," : "") << s.vec_[i];
}
@ -231,35 +111,299 @@ class TinySet
return out;
}
private:
iterator unique_cmp (iterator first, iterator last)
{
if (first == last) {
return last;
}
iterator result = first;
while (++first != last) {
if (cmp_(*result, *first)) {
*(++result) = *first;
}
}
return ++result;
}
bool consistent (void) const
{
typename vector<T>::size_type i;
for (i = 0; i < vec_.size() - 1; i++) {
if ( ! cmp_(vec_[i], vec_[i + 1])) {
return false;
}
}
return true;
}
vector<T> vec_;
Compare cmp_;
std::vector<T> vec_;
Compare cmp_;
};
#endif // HORUS_TINYSET_H
template <typename T, typename C> inline
TinySet<T,C>::TinySet (const std::vector<T>& elements, const C& cmp)
: vec_(elements), cmp_(cmp)
{
std::sort (begin(), end(), cmp_);
iterator it = unique_cmp (begin(), end());
vec_.resize (it - begin());
}
template <typename T, typename C> inline typename TinySet<T,C>::iterator
TinySet<T,C>::insert (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it == end() || cmp_(t, *it)) {
vec_.insert (it, t);
}
return it;
}
template <typename T, typename C> inline void
TinySet<T,C>::insert_sorted (const T& t)
{
vec_.push_back (t);
assert (consistent());
}
template <typename T, typename C> inline void
TinySet<T,C>::remove (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
if (it != end()) {
vec_.erase (it);
}
}
template <typename T, typename C> inline typename TinySet<T,C>::const_iterator
TinySet<T,C>::find (const T& t) const
{
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
template <typename T, typename C> inline typename TinySet<T,C>::iterator
TinySet<T,C>::find (const T& t)
{
iterator it = std::lower_bound (begin(), end(), t, cmp_);
return it == end() || cmp_(t, *it) ? end() : it;
}
/* set union */
template <typename T, typename C> inline TinySet<T,C>
TinySet<T,C>::operator| (const TinySet& s) const
{
TinySet res;
std::set_union (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set intersection */
template <typename T, typename C> inline TinySet<T,C>
TinySet<T,C>::operator& (const TinySet& s) const
{
TinySet res;
std::set_intersection (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
/* set difference */
template <typename T, typename C> inline TinySet<T,C>
TinySet<T,C>::operator- (const TinySet& s) const
{
TinySet res;
std::set_difference (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
std::back_inserter (res.vec_),
cmp_);
return res;
}
template <typename T, typename C> inline TinySet<T,C>&
TinySet<T,C>::operator|= (const TinySet& s)
{
return *this = (*this | s);
}
template <typename T, typename C> inline TinySet<T,C>&
TinySet<T,C>::operator&= (const TinySet& s)
{
return *this = (*this & s);
}
template <typename T, typename C> inline TinySet<T,C>&
TinySet<T,C>::operator-= (const TinySet& s)
{
return *this = (*this - s);
}
template <typename T, typename C> inline bool
TinySet<T,C>::contains (const T& t) const
{
return std::binary_search (
vec_.begin(), vec_.end(), t, cmp_);
}
template <typename T, typename C> inline bool
TinySet<T,C>::contains (const TinySet& s) const
{
return std::includes (
vec_.begin(), vec_.end(),
s.vec_.begin(), s.vec_.end(),
cmp_);
}
template <typename T, typename C> inline bool
TinySet<T,C>::in (const TinySet& s) const
{
return std::includes (
s.vec_.begin(), s.vec_.end(),
vec_.begin(), vec_.end(),
cmp_);
}
template <typename T, typename C> inline bool
TinySet<T,C>::intersects (const TinySet& s) const
{
return (*this & s).size() > 0;
}
template <typename T, typename C> inline const T&
TinySet<T,C>::operator[] (typename std::vector<T>::size_type i) const
{
return vec_[i];
}
template <typename T, typename C> inline T&
TinySet<T,C>::operator[] (typename std::vector<T>::size_type i)
{
return vec_[i];
}
template <typename T, typename C> inline T
TinySet<T,C>::front() const
{
return vec_.front();
}
template <typename T, typename C> inline T&
TinySet<T,C>::front()
{
return vec_.front();
}
template <typename T, typename C> inline T
TinySet<T,C>::back() const
{
return vec_.back();
}
template <typename T, typename C> inline T&
TinySet<T,C>::back()
{
return vec_.back();
}
template <typename T, typename C> inline const std::vector<T>&
TinySet<T,C>::elements() const
{
return vec_;
}
template <typename T, typename C> inline bool
TinySet<T,C>::empty() const
{
return vec_.empty();
}
template <typename T, typename C> inline typename std::vector<T>::size_type
TinySet<T,C>::size() const
{
return vec_.size();
}
template <typename T, typename C> inline void
TinySet<T,C>::clear()
{
vec_.clear();
}
template <typename T, typename C> inline void
TinySet<T,C>::reserve (typename std::vector<T>::size_type size)
{
vec_.reserve (size);
}
template <typename T, typename C> typename TinySet<T,C>::iterator
TinySet<T,C>::unique_cmp (iterator first, iterator last)
{
if (first == last) {
return last;
}
iterator result = first;
while (++first != last) {
if (cmp_(*result, *first)) {
*(++result) = *first;
}
}
return ++result;
}
template <typename T, typename C> inline bool
TinySet<T,C>::consistent() const
{
typename std::vector<T>::size_type i;
for (i = 0; i < vec_.size() - 1; i++) {
if ( ! cmp_(vec_[i], vec_[i + 1])) {
return false;
}
}
return true;
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_

View File

@ -1,27 +1,27 @@
#include <fstream>
#include "Util.h"
#include "Indexer.h"
#include "ElimGraph.h"
#include "BeliefProp.h"
namespace Horus {
namespace Globals {
bool logDomain = false;
unsigned verbosity = 0;
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
LiftedSolverType liftedSolver = LiftedSolverType::lveSolver;
GroundSolverType groundSolver = GroundSolverType::VE;
GroundSolverType groundSolver = GroundSolverType::veSolver;
};
}
namespace Util {
template <> std::string
toString (const bool& b)
{
@ -33,14 +33,14 @@ toString (const bool& b)
unsigned
stringToUnsigned (string str)
stringToUnsigned (std::string str)
{
int val;
stringstream ss;
std::stringstream ss;
ss << str;
ss >> val;
if (val < 0) {
cerr << "Error: the number readed is negative." << endl;
std::cerr << "Error: the number readed is negative." << std::endl;
exit (EXIT_FAILURE);
}
return static_cast<unsigned> (val);
@ -49,10 +49,10 @@ stringToUnsigned (string str)
double
stringToDouble (string str)
stringToDouble (std::string str)
{
double val;
stringstream ss;
std::stringstream ss;
ss << str;
ss >> val;
return val;
@ -117,7 +117,7 @@ size_t
sizeExpected (const Ranges& ranges)
{
return std::accumulate (ranges.begin(),
ranges.end(), 1, multiplies<unsigned>());
ranges.end(), 1, std::multiplies<unsigned>());
}
@ -136,10 +136,10 @@ nrDigits (int num)
bool
isInteger (const string& s)
isInteger (const std::string& s)
{
stringstream ss1 (s);
stringstream ss2;
std::stringstream ss1 (s);
std::stringstream ss2;
int integer;
ss1 >> integer;
ss2 << integer;
@ -148,10 +148,10 @@ isInteger (const string& s)
string
std::string
parametersToString (const Params& v, unsigned precision)
{
stringstream ss;
std::stringstream ss;
ss.precision (precision);
ss << "[" ;
for (size_t i = 0; i < v.size(); i++) {
@ -164,7 +164,7 @@ parametersToString (const Params& v, unsigned precision)
vector<string>
std::vector<std::string>
getStateLines (const Vars& vars)
{
Ranges ranges;
@ -172,9 +172,9 @@ getStateLines (const Vars& vars)
ranges.push_back (vars[i]->range());
}
Indexer indexer (ranges);
vector<string> jointStrings;
std::vector<std::string> jointStrings;
while (indexer.valid()) {
stringstream ss;
std::stringstream ss;
for (size_t i = 0; i < vars.size(); i++) {
if (i != 0) ss << ", " ;
ss << vars[i]->label() << "=" ;
@ -188,34 +188,42 @@ getStateLines (const Vars& vars)
bool invalidValue (string option, string value)
bool invalidValue (std::string option, std::string value)
{
cerr << "Warning: invalid value `" << value << "' " ;
cerr << "for `" << option << "'." ;
cerr << endl;
std::cerr << "Warning: invalid value `" << value << "' " ;
std::cerr << "for `" << option << "'." ;
std::cerr << std::endl;
return false;
}
bool
setHorusFlag (string option, string value)
setHorusFlag (std::string option, std::string value)
{
bool returnVal = true;
if (option == "lifted_solver") {
if (value == "lve") Globals::liftedSolver = LiftedSolverType::LVE;
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::LBP;
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::LKC;
else returnVal = invalidValue (option, value);
if (value == "lve")
Globals::liftedSolver = LiftedSolverType::lveSolver;
else if (value == "lbp")
Globals::liftedSolver = LiftedSolverType::lbpSolver;
else if (value == "lkc")
Globals::liftedSolver = LiftedSolverType::lkcSolver;
else
returnVal = invalidValue (option, value);
} else if (option == "ground_solver" || option == "solver") {
if (value == "hve") Globals::groundSolver = GroundSolverType::VE;
else if (value == "bp") Globals::groundSolver = GroundSolverType::BP;
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CBP;
else returnVal = invalidValue (option, value);
if (value == "hve")
Globals::groundSolver = GroundSolverType::veSolver;
else if (value == "bp")
Globals::groundSolver = GroundSolverType::bpSolver;
else if (value == "cbp")
Globals::groundSolver = GroundSolverType::CbpSolver;
else
returnVal = invalidValue (option, value);
} else if (option == "verbosity") {
stringstream ss;
std::stringstream ss;
ss << value;
ss >> Globals::verbosity;
@ -225,40 +233,42 @@ setHorusFlag (string option, string value)
else returnVal = invalidValue (option, value);
} else if (option == "hve_elim_heuristic") {
typedef ElimGraph::ElimHeuristic ElimHeuristic;
if (value == "sequential")
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
else if (value == "min_neighbors")
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS);
ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
else if (value == "min_weight")
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT);
ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
else if (value == "min_fill")
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL);
ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
else if (value == "weighted_min_fill")
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
else
returnVal = invalidValue (option, value);
} else if (option == "bp_msg_schedule") {
typedef BeliefProp::MsgSchedule MsgSchedule;
if (value == "seq_fixed")
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
else if (value == "seq_random")
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM);
BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
else if (value == "parallel")
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL);
BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
else if (value == "max_residual")
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL);
BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
else
returnVal = invalidValue (option, value);
} else if (option == "bp_accuracy") {
stringstream ss;
std::stringstream ss;
double acc;
ss << value;
ss >> acc;
BeliefProp::setAccuracy (acc);
} else if (option == "bp_max_iter") {
stringstream ss;
std::stringstream ss;
unsigned mi;
ss << value;
ss >> mi;
@ -285,7 +295,7 @@ setHorusFlag (string option, string value)
else returnVal = invalidValue (option, value);
} else {
cerr << "Warning: invalid option `" << option << "'" << endl;
std::cerr << "Warning: invalid option `" << option << "'" << std::endl;
returnVal = false;
}
return returnVal;
@ -294,20 +304,20 @@ setHorusFlag (string option, string value)
void
printHeader (string header, std::ostream& os)
printHeader (std::string header, std::ostream& os)
{
printAsteriskLine (os);
os << header << endl;
os << header << std::endl;
printAsteriskLine (os);
}
void
printSubHeader (string header, std::ostream& os)
printSubHeader (std::string header, std::ostream& os)
{
printDashedLine (os);
os << header << endl;
os << header << std::endl;
printDashedLine (os);
}
@ -318,7 +328,7 @@ printAsteriskLine (std::ostream& os)
{
os << "********************************" ;
os << "********************************" ;
os << endl;
os << std::endl;
}
@ -328,11 +338,10 @@ printDashedLine (std::ostream& os)
{
os << "--------------------------------" ;
os << "--------------------------------" ;
os << endl;
os << std::endl;
}
}
} // namespace Util
@ -362,10 +371,10 @@ getL1Distance (const Params& v1, const Params& v2)
double dist = 0.0;
if (Globals::logDomain) {
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
std::plus<double>(), FuncObject::abs_diff_exp<double>());
std::plus<double>(), FuncObj::abs_diff_exp<double>());
} else {
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
std::plus<double>(), FuncObject::abs_diff<double>());
std::plus<double>(), FuncObj::abs_diff<double>());
}
return dist;
}
@ -379,10 +388,10 @@ getMaxNorm (const Params& v1, const Params& v2)
double max = 0.0;
if (Globals::logDomain) {
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
FuncObj::max<double>(), FuncObj::abs_diff_exp<double>());
} else {
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
FuncObject::max<double>(), FuncObject::abs_diff<double>());
FuncObj::max<double>(), FuncObj::abs_diff<double>());
}
return max;
}
@ -428,5 +437,8 @@ pow (Params& v, double exp)
Globals::logDomain ? v *= exp : v ^= exp;
}
}
} // namespace LogAware
} // namespace Horus

View File

@ -1,69 +1,78 @@
#ifndef HORUS_UTIL_H
#define HORUS_UTIL_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
#define YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
#include <cmath>
#include <cassert>
#include <algorithm>
#include <limits>
#include <vector>
#include <queue>
#include <set>
#include <unordered_map>
#include <algorithm>
#include <limits>
#include <string>
#include <iostream>
#include <sstream>
#include "Horus.h"
using namespace std;
namespace Horus {
namespace {
const double NEG_INF = -std::numeric_limits<double>::infinity();
};
}
namespace Util {
template <typename T> void addToVector (vector<T>&, const vector<T>&);
template <typename T> void
addToVector (std::vector<T>&, const std::vector<T>&);
template <typename T> void addToSet (set<T>&, const vector<T>&);
template <typename T> void
addToSet (std::set<T>&, const std::vector<T>&);
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
template <typename T> void
addToQueue (std::queue<T>&, const std::vector<T>&);
template <typename T> bool contains (const vector<T>&, const T&);
template <typename T> bool
contains (const std::vector<T>&, const T&);
template <typename T> bool contains (const set<T>&, const T&);
template <typename T> bool contains
(const std::set<T>&, const T&);
template <typename K, typename V> bool contains (
const unordered_map<K, V>&, const K&);
template <typename K, typename V> bool
contains (const std::unordered_map<K, V>&, const K&);
template <typename T> size_t indexOf (const vector<T>&, const T&);
template <typename T> size_t
indexOf (const std::vector<T>&, const T&);
template <class Operation>
void apply_n_times (Params& v1, const Params& v2,
unsigned repetitions, Operation);
template <class Operation> void
apply_n_times (Params& v1, const Params& v2, unsigned reps, Operation);
template <typename T> void log (vector<T>&);
template <typename T> void
log (std::vector<T>&);
template <typename T> void exp (vector<T>&);
template <typename T> void
exp (std::vector<T>&);
template <typename T> string elementsToString (
const vector<T>& v, string sep = " ");
template <typename T> std::string
elementsToString (const std::vector<T>& v, std::string sep = " ");
template <typename T> std::string toString (const T&);
template <typename T> std::string
toString (const T&);
template <> std::string toString (const bool&);
template <> std::string
toString (const bool&);
double logSum (double, double);
unsigned maxUnsigned (void);
unsigned maxUnsigned();
unsigned stringToUnsigned (string);
unsigned stringToUnsigned (std::string);
double stringToDouble (string);
double stringToDouble (std::string);
double factorial (unsigned);
@ -75,28 +84,29 @@ size_t sizeExpected (const Ranges&);
unsigned nrDigits (int);
bool isInteger (const string&);
bool isInteger (const std::string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION);
std::string parametersToString (
const Params&, unsigned = Constants::precision);
vector<string> getStateLines (const Vars&);
std::vector<std::string> getStateLines (const Vars&);
bool setHorusFlag (string option, string value);
bool setHorusFlag (std::string option, std::string value);
void printHeader (string, std::ostream& os = std::cout);
void printHeader (std::string, std::ostream& os = std::cout);
void printSubHeader (string, std::ostream& os = std::cout);
void printSubHeader (std::string, std::ostream& os = std::cout);
void printAsteriskLine (std::ostream& os = std::cout);
void printDashedLine (std::ostream& os = std::cout);
};
} // namespace Util
template <typename T> void
Util::addToVector (vector<T>& v, const vector<T>& elements)
Util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
{
v.insert (v.end(), elements.begin(), elements.end());
}
@ -104,7 +114,7 @@ Util::addToVector (vector<T>& v, const vector<T>& elements)
template <typename T> void
Util::addToSet (set<T>& s, const vector<T>& elements)
Util::addToSet (std::set<T>& s, const std::vector<T>& elements)
{
s.insert (elements.begin(), elements.end());
}
@ -112,7 +122,7 @@ Util::addToSet (set<T>& s, const vector<T>& elements)
template <typename T> void
Util::addToQueue (queue<T>& q, const vector<T>& elements)
Util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
{
for (size_t i = 0; i < elements.size(); i++) {
q.push (elements[i]);
@ -122,7 +132,7 @@ Util::addToQueue (queue<T>& q, const vector<T>& elements)
template <typename T> bool
Util::contains (const vector<T>& v, const T& e)
Util::contains (const std::vector<T>& v, const T& e)
{
return std::find (v.begin(), v.end(), e) != v.end();
}
@ -130,7 +140,7 @@ Util::contains (const vector<T>& v, const T& e)
template <typename T> bool
Util::contains (const set<T>& s, const T& e)
Util::contains (const std::set<T>& s, const T& e)
{
return s.find (e) != s.end();
}
@ -138,7 +148,7 @@ Util::contains (const set<T>& s, const T& e)
template <typename K, typename V> bool
Util::contains (const unordered_map<K, V>& m, const K& k)
Util::contains (const std::unordered_map<K, V>& m, const K& k)
{
return m.find (k) != m.end();
}
@ -146,7 +156,7 @@ Util::contains (const unordered_map<K, V>& m, const K& k)
template <typename T> size_t
Util::indexOf (const vector<T>& v, const T& e)
Util::indexOf (const std::vector<T>& v, const T& e)
{
return std::distance (v.begin(),
std::find (v.begin(), v.end(), e));
@ -155,7 +165,10 @@ Util::indexOf (const vector<T>& v, const T& e)
template <class Operation> void
Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
Util::apply_n_times (
Params& v1,
const Params& v2,
unsigned repetitions,
Operation unary_op)
{
Params::iterator first = v1.begin();
@ -174,7 +187,7 @@ Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
template <typename T> void
Util::log (vector<T>& v)
Util::log (std::vector<T>& v)
{
std::transform (v.begin(), v.end(), v.begin(), ::log);
}
@ -182,17 +195,17 @@ Util::log (vector<T>& v)
template <typename T> void
Util::exp (vector<T>& v)
Util::exp (std::vector<T>& v)
{
std::transform (v.begin(), v.end(), v.begin(), ::exp);
}
template <typename T> string
Util::elementsToString (const vector<T>& v, string sep)
template <typename T> std::string
Util::elementsToString (const std::vector<T>& v, std::string sep)
{
stringstream ss;
std::stringstream ss;
for (size_t i = 0; i < v.size(); i++) {
ss << ((i != 0) ? sep : "") << v[i];
}
@ -245,7 +258,7 @@ Util::logSum (double x, double y)
inline unsigned
Util::maxUnsigned (void)
Util::maxUnsigned()
{
return std::numeric_limits<unsigned>::max();
}
@ -277,106 +290,106 @@ void pow (Params&, unsigned);
void pow (Params&, double);
};
} // namespace LogAware
template <typename T>
void operator+=(std::vector<T>& v, double val)
template <typename T> void
operator+=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (plus<double>(), val));
std::bind2nd (std::plus<double>(), val));
}
template <typename T>
void operator-=(std::vector<T>& v, double val)
template <typename T> void
operator-=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (minus<double>(), val));
std::bind2nd (std::minus<double>(), val));
}
template <typename T>
void operator*=(std::vector<T>& v, double val)
template <typename T> void
operator*=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (multiplies<double>(), val));
std::bind2nd (std::multiplies<double>(), val));
}
template <typename T>
void operator/=(std::vector<T>& v, double val)
template <typename T> void
operator/=(std::vector<T>& v, double val)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (divides<double>(), val));
std::bind2nd (std::divides<double>(), val));
}
template <typename T>
void operator+=(std::vector<T>& a, const std::vector<T>& b)
template <typename T> void
operator+=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
plus<double>());
std::plus<double>());
}
template <typename T>
void operator-=(std::vector<T>& a, const std::vector<T>& b)
template <typename T> void
operator-=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
minus<double>());
std::minus<double>());
}
template <typename T>
void operator*=(std::vector<T>& a, const std::vector<T>& b)
template <typename T> void
operator*=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
multiplies<double>());
std::multiplies<double>());
}
template <typename T>
void operator/=(std::vector<T>& a, const std::vector<T>& b)
template <typename T> void
operator/=(std::vector<T>& a, const std::vector<T>& b)
{
assert (a.size() == b.size());
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
divides<double>());
std::divides<double>());
}
template <typename T>
void operator^=(std::vector<T>& v, double exp)
template <typename T> void
operator^=(std::vector<T>& v, double exp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
std::bind2nd (std::ptr_fun<double, double, double> (std::pow), exp));
}
template <typename T>
void operator^=(std::vector<T>& v, int iexp)
template <typename T> void
operator^=(std::vector<T>& v, int iexp)
{
std::transform (v.begin(), v.end(), v.begin(),
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
std::bind2nd (std::ptr_fun<double, int, double> (std::pow), iexp));
}
template <typename T>
std::ostream& operator<< (std::ostream& os, const vector<T>& v)
template <typename T> std::ostream&
operator<< (std::ostream& os, const std::vector<T>& v)
{
os << "[" ;
os << Util::elementsToString (v, ", ");
@ -385,40 +398,33 @@ std::ostream& operator<< (std::ostream& os, const vector<T>& v)
}
namespace FuncObject {
namespace FuncObj {
template<typename T>
struct max : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
struct max : public std::binary_function<T, T, T> {
T operator() (const T& x, const T& y) const {
return x < y ? y : x;
}
};
}};
template <typename T>
struct abs_diff : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
struct abs_diff : public std::binary_function<T, T, T> {
T operator() (const T& x, const T& y) const {
return std::abs (x - y);
}
};
}};
template <typename T>
struct abs_diff_exp : public std::binary_function<T, T, T>
{
T operator() (const T& x, const T& y) const
{
struct abs_diff_exp : public std::binary_function<T, T, T> {
T operator() (const T& x, const T& y) const {
return std::abs (std::exp (x) - std::exp (y));
}
};
}};
}
} // namespace FuncObj
#endif // HORUS_UTIL_H
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_UTIL_H_

View File

@ -3,7 +3,9 @@
#include "Var.h"
unordered_map<VarId, VarInfo> Var::varsInfo_;
namespace Horus {
std::unordered_map<VarId, Var::VarInfo> Var::varsInfo_;
Var::Var (const Var* v)
@ -45,13 +47,14 @@ Var::setEvidence (int evidence)
string
Var::label (void) const
std::string
Var::label() const
{
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).label;
assert (Util::contains (varsInfo_, varId_));
return varsInfo_.find (varId_)->second.first;
}
stringstream ss;
std::stringstream ss;
ss << "x" << varId_;
return ss.str();
}
@ -59,17 +62,46 @@ Var::label (void) const
States
Var::states (void) const
Var::states() const
{
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).states;
assert (Util::contains (varsInfo_, varId_));
return varsInfo_.find (varId_)->second.second;
}
States states;
for (unsigned i = 0; i < range_; i++) {
stringstream ss;
std::stringstream ss;
ss << i ;
states.push_back (ss.str());
}
return states;
}
void
Var::addVarInfo (
VarId vid, std::string label, const States& states)
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (std::make_pair (vid, VarInfo (label, states)));
}
bool
Var::varsHaveInfo()
{
return varsInfo_.empty() == false;
}
void
Var::clearVarsInfo()
{
varsInfo_.clear();
}
} // namespace Horus

View File

@ -1,102 +1,105 @@
#ifndef HORUS_VAR_H
#define HORUS_VAR_H
#ifndef YAP_PACKAGES_CLPBN_HORUS_VAR_H_
#define YAP_PACKAGES_CLPBN_HORUS_VAR_H_
#include <cassert>
#include <unordered_map>
#include <string>
#include "Util.h"
#include "Horus.h"
using namespace std;
namespace Horus {
struct VarInfo
{
VarInfo (string l, const States& sts)
: label(l), states(sts) { }
string label;
States states;
};
class Var
{
class Var {
public:
Var (const Var*);
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
Var (VarId, unsigned range, int evidence = Constants::unobserved);
virtual ~Var (void) { };
virtual ~Var() { };
VarId varId (void) const { return varId_; }
VarId varId() const { return varId_; }
unsigned range (void) const { return range_; }
unsigned range() const { return range_; }
int getEvidence (void) const { return evidence_; }
int getEvidence() const { return evidence_; }
size_t getIndex (void) const { return index_; }
size_t getIndex() const { return index_; }
void setIndex (size_t idx) { index_ = idx; }
bool hasEvidence (void) const
{
return evidence_ != Constants::NO_EVIDENCE;
}
bool hasEvidence() const;
operator size_t (void) const { return index_; }
operator size_t() const;
bool operator== (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ == var.varId();
}
bool operator== (const Var& var) const;
bool operator!= (const Var& var) const
{
return !(*this == var);
}
bool operator!= (const Var& var) const;
bool isValidState (int);
void setEvidence (int);
string label (void) const;
std::string label() const;
States states (void) const;
States states() const;
static void addVarInfo (
VarId vid, string label, const States& states)
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
VarId vid, std::string label, const States& states);
static VarInfo getVarInfo (VarId vid)
{
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
static bool varsHaveInfo();
static bool varsHaveInfo (void)
{
return varsInfo_.empty() == false;
}
static void clearVarsInfo (void)
{
varsInfo_.clear();
}
static void clearVarsInfo();
private:
typedef std::pair<std::string, States> VarInfo;
VarId varId_;
unsigned range_;
int evidence_;
size_t index_;
static unordered_map<VarId, VarInfo> varsInfo_;
static std::unordered_map<VarId, VarInfo> varsInfo_;
DISALLOW_COPY_AND_ASSIGN(Var);
};
#endif // HORUS_VAR_H
inline bool
Var::hasEvidence() const
{
return evidence_ != Constants::unobserved;
}
inline
Var::operator size_t() const
{
return index_;
}
inline bool
Var::operator== (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ == var.varId();
}
inline bool
Var::operator!= (const Var& var) const
{
return !(*this == var);
}
} // namespace Horus
#endif // YAP_PACKAGES_CLPBN_HORUS_VAR_H_

View File

@ -1,4 +1,6 @@
#include <algorithm>
#include <iostream>
#include <sstream>
#include "VarElim.h"
#include "ElimGraph.h"
@ -6,16 +8,18 @@
#include "Util.h"
namespace Horus {
Params
VarElim::solveQuery (VarIds queryVids)
{
if (Globals::verbosity > 1) {
cout << "Solving query on " ;
std::cout << "Solving query on " ;
for (size_t i = 0; i < queryVids.size(); i++) {
if (i != 0) cout << ", " ;
cout << fg.getVarNode (queryVids[i])->label();
if (i != 0) std::cout << ", " ;
std::cout << fg.getVarNode (queryVids[i])->label();
}
cout << endl;
std::cout << std::endl;
}
totalFactorSize_ = 0;
largestFactorSize_ = 0;
@ -33,27 +37,28 @@ VarElim::solveQuery (VarIds queryVids)
void
VarElim::printSolverFlags (void) const
VarElim::printSolverFlags() const
{
stringstream ss;
std::stringstream ss;
ss << "variable elimination [" ;
ss << "elim_heuristic=" ;
typedef ElimGraph::ElimHeuristic ElimHeuristic;
switch (ElimGraph::elimHeuristic()) {
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break;
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break;
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break;
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
case ElimHeuristic::sequentialEh: ss << "sequential"; break;
case ElimHeuristic::minNeighborsEh: ss << "min_neighbors"; break;
case ElimHeuristic::minWeightEh: ss << "min_weight"; break;
case ElimHeuristic::minFillEh: ss << "min_fill"; break;
case ElimHeuristic::weightedMinFillEh: ss << "weighted_min_fill"; break;
}
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
std::cout << ss.str() << std::endl;
}
void
VarElim::createFactorList (void)
VarElim::createFactorList()
{
const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2);
@ -61,7 +66,7 @@ VarElim::createFactorList (void)
factorList_.push_back (new Factor (facNodes[i]->factor()));
const VarIds& args = facNodes[i]->factor().arguments();
for (size_t j = 0; j < args.size(); j++) {
unordered_map<VarId, vector<size_t>>::iterator it;
std::unordered_map<VarId, std::vector<size_t>>::iterator it;
it = varMap_.find (args[j]);
if (it != varMap_.end()) {
it->second.push_back (i);
@ -75,22 +80,22 @@ VarElim::createFactorList (void)
void
VarElim::absorveEvidence (void)
VarElim::absorveEvidence()
{
if (Globals::verbosity > 2) {
Util::printDashedLine();
cout << "(initial factor list)" << endl;
std::cout << "(initial factor list)" << std::endl;
printActiveFactors();
}
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) {
if (Globals::verbosity > 1) {
cout << "-> aborving evidence on ";
cout << varNodes[i]->label() << " = " ;
cout << varNodes[i]->getEvidence() << endl;
std::cout << "-> aborving evidence on ";
std::cout << varNodes[i]->label() << " = " ;
std::cout << varNodes[i]->getEvidence() << std::endl;
}
const vector<size_t>& indices = varMap_[varNodes[i]->varId()];
const std::vector<size_t>& indices = varMap_[varNodes[i]->varId()];
for (size_t j = 0; j < indices.size(); j++) {
size_t idx = indices[j];
if (factorList_[idx]->nrArguments() > 1) {
@ -118,8 +123,8 @@ VarElim::processFactorList (const VarIds& queryVids)
Util::printDashedLine();
printActiveFactors();
}
cout << "-> summing out " ;
cout << fg.getVarNode (elimOrder[i])->label() << endl;
std::cout << "-> summing out " ;
std::cout << fg.getVarNode (elimOrder[i])->label() << std::endl;
}
eliminate (elimOrder[i]);
}
@ -143,9 +148,9 @@ VarElim::processFactorList (const VarIds& queryVids)
result.reorderArguments (unobservedVids);
result.normalize();
if (Globals::verbosity > 0) {
cout << "total factor size: " << totalFactorSize_ << endl;
cout << "largest factor size: " << largestFactorSize_ << endl;
cout << endl;
std::cout << "total factor size: " << totalFactorSize_ << std::endl;
std::cout << "largest factor size: " << largestFactorSize_ << std::endl;
std::cout << std::endl;
}
return result.params();
}
@ -156,7 +161,7 @@ void
VarElim::eliminate (VarId vid)
{
Factor* result = new Factor();
const vector<size_t>& indices = varMap_[vid];
const std::vector<size_t>& indices = varMap_[vid];
for (size_t i = 0; i < indices.size(); i++) {
size_t idx = indices[i];
if (factorList_[idx]) {
@ -173,7 +178,7 @@ VarElim::eliminate (VarId vid)
result->sumOut (vid);
const VarIds& args = result->arguments();
for (size_t i = 0; i < args.size(); i++) {
vector<size_t>& indices2 = varMap_[args[i]];
std::vector<size_t>& indices2 = varMap_[args[i]];
indices2.push_back (factorList_.size());
}
factorList_.push_back (result);
@ -185,14 +190,16 @@ VarElim::eliminate (VarId vid)
void
VarElim::printActiveFactors (void)
VarElim::printActiveFactors()
{
for (size_t i = 0; i < factorList_.size(); i++) {
if (factorList_[i]) {
cout << factorList_[i]->getLabel() << " " ;
cout << factorList_[i]->params();
cout << endl;
std::cout << factorList_[i]->getLabel() << " " ;
std::cout << factorList_[i]->params();
std::cout << std::endl;
}
}
}
} // namespace Horus

Some files were not shown because too many files have changed in this diff Show More