Merge branch 'master' of git.dcc.fc.up.pt:yap-6.3
This commit is contained in:
commit
8e33cebd4d
105
C/absmi.c
105
C/absmi.c
@ -9376,7 +9376,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_plus(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_plus(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9421,7 +9421,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_plus(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_plus(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9462,7 +9462,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_plus(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_plus(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9510,7 +9510,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_plus(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_plus(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9554,7 +9554,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_minus(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_minus(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9599,7 +9599,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_minus(MkIntegerTerm(d1),Yap_Eval(d0));
|
||||
d0 = p_minus(MkIntegerTerm(d1),Yap_Eval(d0) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9640,7 +9640,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_minus(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_minus(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9688,7 +9688,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_minus(MkIntegerTerm(d1), Yap_Eval(d0));
|
||||
d0 = p_minus(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9728,11 +9728,11 @@ Yap_absmi(int inp)
|
||||
times_vv_nvar_nvar:
|
||||
/* d0 and d1 are where I want them */
|
||||
if (IsIntTerm(d0) && IsIntTerm(d1)) {
|
||||
d0 = times_int(IntOfTerm(d0), IntOfTerm(d1));
|
||||
d0 = times_int(IntOfTerm(d0), IntOfTerm(d1) PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_times(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_times(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9773,11 +9773,11 @@ Yap_absmi(int inp)
|
||||
{
|
||||
Int d1 = PREG->u.xxn.c;
|
||||
if (IsIntTerm(d0)) {
|
||||
d0 = times_int(IntOfTerm(d0), d1);
|
||||
d0 = times_int(IntOfTerm(d0), d1 PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_times(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_times(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9814,11 +9814,11 @@ Yap_absmi(int inp)
|
||||
times_y_vv_nvar_nvar:
|
||||
/* d0 and d1 are where I want them */
|
||||
if (IsIntTerm(d0) && IsIntTerm(d1)) {
|
||||
d0 = times_int(IntOfTerm(d0), IntOfTerm(d1));
|
||||
d0 = times_int(IntOfTerm(d0), IntOfTerm(d1) PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_times(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_times(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9862,11 +9862,11 @@ Yap_absmi(int inp)
|
||||
{
|
||||
Int d1 = PREG->u.yxn.c;
|
||||
if (IsIntTerm(d0)) {
|
||||
d0 = times_int(IntOfTerm(d0), d1);
|
||||
d0 = times_int(IntOfTerm(d0), d1 PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_times(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_times(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9917,7 +9917,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_div(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_div(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -9962,7 +9962,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_div(Yap_Eval(d0),MkIntegerTerm(d1));
|
||||
d0 = p_div(Yap_Eval(d0),MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10006,7 +10006,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_div(MkIntegerTerm(d1),Yap_Eval(d0));
|
||||
d0 = p_div(MkIntegerTerm(d1),Yap_Eval(d0) PASS_REGS);
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
Yap_Error(LOCAL_Error_TYPE, LOCAL_Error_Term, LOCAL_ErrorMessage);
|
||||
@ -10053,7 +10053,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_div(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_div(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10101,7 +10101,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_div(Yap_Eval(d0),MkIntegerTerm(d1));
|
||||
d0 = p_div(Yap_Eval(d0),MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10148,7 +10148,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_div(MkIntegerTerm(d1), Yap_Eval(d0));
|
||||
d0 = p_div(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10193,7 +10193,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_and(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_and(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10238,7 +10238,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_and(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_and(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10279,7 +10279,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_and(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_and(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10327,7 +10327,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_and(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_and(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10372,7 +10372,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_or(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_or(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10417,7 +10417,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_or(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_or(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
Yap_Error(LOCAL_Error_TYPE, LOCAL_Error_Term, LOCAL_ErrorMessage);
|
||||
@ -10457,7 +10457,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_or(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_or(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10505,7 +10505,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_or(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_or(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10549,11 +10549,11 @@ Yap_absmi(int inp)
|
||||
if (i2 < 0)
|
||||
d0 = MkIntegerTerm(SLR(IntOfTerm(d0), -i2));
|
||||
else
|
||||
d0 = do_sll(IntOfTerm(d0),i2);
|
||||
d0 = do_sll(IntOfTerm(d0),i2 PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_sll(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_sll(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
if (d0 == 0L) {
|
||||
@ -10594,11 +10594,11 @@ Yap_absmi(int inp)
|
||||
{
|
||||
Int d1 = PREG->u.xxn.c;
|
||||
if (IsIntTerm(d0)) {
|
||||
d0 = do_sll(IntOfTerm(d0), (Int)d1);
|
||||
d0 = do_sll(IntOfTerm(d0), (Int)d1 PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_sll(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_sll(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
}
|
||||
@ -10635,11 +10635,11 @@ Yap_absmi(int inp)
|
||||
if (i2 < 0)
|
||||
d0 = MkIntegerTerm(SLR(d1, -i2));
|
||||
else
|
||||
d0 = do_sll(d1,i2);
|
||||
d0 = do_sll(d1,i2 PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_sll(MkIntegerTerm(d1), Yap_Eval(d0));
|
||||
d0 = p_sll(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
}
|
||||
@ -10680,11 +10680,11 @@ Yap_absmi(int inp)
|
||||
if (i2 < 0)
|
||||
d0 = MkIntegerTerm(SLR(IntOfTerm(d0), -i2));
|
||||
else
|
||||
d0 = do_sll(IntOfTerm(d0),i2);
|
||||
d0 = do_sll(IntOfTerm(d0),i2 PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_sll(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_sll(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
if (d0 == 0L) {
|
||||
@ -10728,11 +10728,11 @@ Yap_absmi(int inp)
|
||||
{
|
||||
Int d1 = PREG->u.yxn.c;
|
||||
if (IsIntTerm(d0)) {
|
||||
d0 = do_sll(IntOfTerm(d0), Yap_Eval(d1));
|
||||
d0 = do_sll(IntOfTerm(d0), Yap_Eval(d1) PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_sll(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_sll(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
}
|
||||
@ -10773,11 +10773,11 @@ Yap_absmi(int inp)
|
||||
if (i2 < 0)
|
||||
d0 = MkIntegerTerm(SLR(d1, -i2));
|
||||
else
|
||||
d0 = do_sll(d1,i2);
|
||||
d0 = do_sll(d1,i2 PASS_REGS);
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_sll(MkIntegerTerm(d1), Yap_Eval(0));
|
||||
d0 = p_sll(MkIntegerTerm(d1), Yap_Eval(0) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
}
|
||||
@ -10819,13 +10819,13 @@ Yap_absmi(int inp)
|
||||
if (IsIntTerm(d0) && IsIntTerm(d1)) {
|
||||
Int i2 = IntOfTerm(d1);
|
||||
if (i2 < 0)
|
||||
d0 = do_sll(IntOfTerm(d0), -i2);
|
||||
d0 = do_sll(IntOfTerm(d0), -i2 PASS_REGS);
|
||||
else
|
||||
d0 = MkIntTerm(SLR(IntOfTerm(d0), i2));
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_slr(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_slr(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
if (d0 == 0L) {
|
||||
@ -10870,7 +10870,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_slr(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_slr(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -10905,13 +10905,13 @@ Yap_absmi(int inp)
|
||||
if (IsIntTerm(d0)) {
|
||||
Int i2 = IntOfTerm(d0);
|
||||
if (i2 < 0)
|
||||
d0 = do_sll(d1, -i2);
|
||||
d0 = do_sll(d1, -i2 PASS_REGS);
|
||||
else
|
||||
d0 = MkIntegerTerm(SLR(d1, i2));
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_slr(MkIntegerTerm(d1), Yap_Eval(d0));
|
||||
d0 = p_slr(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
}
|
||||
@ -10950,13 +10950,13 @@ Yap_absmi(int inp)
|
||||
if (IsIntTerm(d0) && IsIntTerm(d1)) {
|
||||
Int i2 = IntOfTerm(d1);
|
||||
if (i2 < 0)
|
||||
d0 = do_sll(IntOfTerm(d0), -i2);
|
||||
d0 = do_sll(IntOfTerm(d0), -i2 PASS_REGS);
|
||||
else
|
||||
d0 = MkIntTerm(SLR(IntOfTerm(d0), i2));
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_slr(Yap_Eval(d0), Yap_Eval(d1));
|
||||
d0 = p_slr(Yap_Eval(d0), Yap_Eval(d1) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
BEGP(pt0);
|
||||
@ -11004,7 +11004,7 @@ Yap_absmi(int inp)
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_slr(Yap_Eval(d0), MkIntegerTerm(d1));
|
||||
d0 = p_slr(Yap_Eval(d0), MkIntegerTerm(d1) PASS_REGS);
|
||||
setregs();
|
||||
if (d0 == 0L) {
|
||||
saveregs();
|
||||
@ -11041,13 +11041,13 @@ Yap_absmi(int inp)
|
||||
if (IsIntTerm(d0)) {
|
||||
Int i2 = IntOfTerm(d0);
|
||||
if (i2 < 0)
|
||||
d0 = do_sll(d1, -i2);
|
||||
d0 = do_sll(d1, -i2 PASS_REGS);
|
||||
else
|
||||
d0 = MkIntegerTerm(SLR(d1, i2));
|
||||
}
|
||||
else {
|
||||
saveregs();
|
||||
d0 = p_slr(MkIntegerTerm(d1), Yap_Eval(d0));
|
||||
d0 = p_slr(MkIntegerTerm(d1), Yap_Eval(d0) PASS_REGS);
|
||||
setregs();
|
||||
}
|
||||
}
|
||||
@ -13432,7 +13432,6 @@ Yap_absmi(int inp)
|
||||
}
|
||||
PP = NULL;
|
||||
SREG = (CELL *) pen;
|
||||
fprintf(stderr,"Here I was\n");
|
||||
ASP = ENV_YREG;
|
||||
if (ASP > (CELL *)PROTECT_FROZEN_B(B))
|
||||
ASP = (CELL *)PROTECT_FROZEN_B(B);
|
||||
|
@ -852,7 +852,6 @@ Yap_NewPredPropByFunctor(FunctorEntry *fe, Term cur_mod)
|
||||
p->FunctorOfPred = fe;
|
||||
WRITE_UNLOCK(fe->FRWLock);
|
||||
{
|
||||
CACHE_REGS
|
||||
Yap_inform_profiler_of_clause(&(p->OpcodeOfPred), &(p->OpcodeOfPred)+1, p, GPROF_NEW_PRED_FUNC);
|
||||
if (!(p->PredFlags & (CPredFlag|AsmPredFlag))) {
|
||||
Yap_inform_profiler_of_clause(&(p->cs.p_code.ExpandCode), &(p->cs.p_code.ExpandCode)+1, p, GPROF_NEW_PRED_FUNC);
|
||||
@ -966,7 +965,6 @@ Yap_NewPredPropByAtom(AtomEntry *ae, Term cur_mod)
|
||||
p->FunctorOfPred = (Functor)AbsAtom(ae);
|
||||
WRITE_UNLOCK(ae->ARWLock);
|
||||
{
|
||||
CACHE_REGS
|
||||
Yap_inform_profiler_of_clause(&(p->OpcodeOfPred), &(p->OpcodeOfPred)+1, p, GPROF_NEW_PRED_ATOM);
|
||||
if (!(p->PredFlags & (CPredFlag|AsmPredFlag))) {
|
||||
Yap_inform_profiler_of_clause(&(p->cs.p_code.ExpandCode), &(p->cs.p_code.ExpandCode)+1, p, GPROF_NEW_PRED_ATOM);
|
||||
@ -1057,8 +1055,10 @@ Yap_GetValue(Atom a)
|
||||
if (IsApplTerm(out)) {
|
||||
Functor f = FunctorOfTerm(out);
|
||||
if (f == FunctorDouble) {
|
||||
CACHE_REGS
|
||||
out = MkFloatTerm(FloatOfTerm(out));
|
||||
} else if (f == FunctorLongInt) {
|
||||
CACHE_REGS
|
||||
out = MkLongIntTerm(LongIntOfTerm(out));
|
||||
}
|
||||
#ifdef USE_GMP
|
||||
|
@ -2056,10 +2056,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
|
||||
save_machine_regs();
|
||||
siglongjmp(cip->CompilerBotch,2);
|
||||
}
|
||||
{
|
||||
CACHE_REGS
|
||||
Yap_inform_profiler_of_clause(newcp, (char *)(newcp)+size, ap, GPROF_INDEX);
|
||||
}
|
||||
Yap_inform_profiler_of_clause(newcp, (char *)(newcp)+size, ap, GPROF_INDEX);
|
||||
Yap_LUIndexSpace_CP += size;
|
||||
#ifdef DEBUG
|
||||
Yap_NewCps++;
|
||||
|
25
C/arith1.c
25
C/arith1.c
@ -29,7 +29,7 @@ static char SccsId[] = "%W% %G%";
|
||||
#include "eval.h"
|
||||
|
||||
static Term
|
||||
float_to_int(Float v)
|
||||
float_to_int(Float v USES_REGS)
|
||||
{
|
||||
#if USE_GMP
|
||||
Int i = (Int)v;
|
||||
@ -44,7 +44,7 @@ float_to_int(Float v)
|
||||
#endif
|
||||
}
|
||||
|
||||
#define RBIG_FL(v) return(float_to_int(v))
|
||||
#define RBIG_FL(v) return(float_to_int(v PASS_REGS))
|
||||
|
||||
typedef struct init_un_eval {
|
||||
char *OpName;
|
||||
@ -118,7 +118,7 @@ double my_rint(double x)
|
||||
#endif
|
||||
|
||||
static Int
|
||||
msb(Int inp) /* calculate the most significant bit for an integer */
|
||||
msb(Int inp USES_REGS) /* calculate the most significant bit for an integer */
|
||||
{
|
||||
/* the obvious solution: do it by using binary search */
|
||||
Int out = 0;
|
||||
@ -141,7 +141,7 @@ msb(Int inp) /* calculate the most significant bit for an integer */
|
||||
}
|
||||
|
||||
static Int
|
||||
lsb(Int inp) /* calculate the least significant bit for an integer */
|
||||
lsb(Int inp USES_REGS) /* calculate the least significant bit for an integer */
|
||||
{
|
||||
/* the obvious solution: do it by using binary search */
|
||||
Int out = 0;
|
||||
@ -165,7 +165,7 @@ lsb(Int inp) /* calculate the least significant bit for an integer */
|
||||
}
|
||||
|
||||
static Int
|
||||
popcount(Int inp) /* calculate the least significant bit for an integer */
|
||||
popcount(Int inp USES_REGS) /* calculate the least significant bit for an integer */
|
||||
{
|
||||
/* the obvious solution: do it by using binary search */
|
||||
Int c = 0, j = 0, m = ((CELL)1);
|
||||
@ -185,7 +185,7 @@ popcount(Int inp) /* calculate the least significant bit for an integer */
|
||||
}
|
||||
|
||||
static Term
|
||||
eval1(Int fi, Term t) {
|
||||
eval1(Int fi, Term t USES_REGS) {
|
||||
arith1_op f = fi;
|
||||
switch (f) {
|
||||
case op_uplus:
|
||||
@ -586,7 +586,7 @@ eval1(Int fi, Term t) {
|
||||
case op_msb:
|
||||
switch (ETypeOfTerm(t)) {
|
||||
case long_int_e:
|
||||
RINT(msb(IntegerOfTerm(t)));
|
||||
RINT(msb(IntegerOfTerm(t) PASS_REGS));
|
||||
case double_e:
|
||||
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "msb(%f)", FloatOfTerm(t));
|
||||
case big_int_e:
|
||||
@ -599,7 +599,7 @@ eval1(Int fi, Term t) {
|
||||
case op_lsb:
|
||||
switch (ETypeOfTerm(t)) {
|
||||
case long_int_e:
|
||||
RINT(lsb(IntegerOfTerm(t)));
|
||||
RINT(lsb(IntegerOfTerm(t) PASS_REGS));
|
||||
case double_e:
|
||||
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "lsb(%f)", FloatOfTerm(t));
|
||||
case big_int_e:
|
||||
@ -612,7 +612,7 @@ eval1(Int fi, Term t) {
|
||||
case op_popcount:
|
||||
switch (ETypeOfTerm(t)) {
|
||||
case long_int_e:
|
||||
RINT(popcount(IntegerOfTerm(t)));
|
||||
RINT(popcount(IntegerOfTerm(t) PASS_REGS));
|
||||
case double_e:
|
||||
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "popcount(%f)", FloatOfTerm(t));
|
||||
case big_int_e:
|
||||
@ -699,7 +699,8 @@ eval1(Int fi, Term t) {
|
||||
|
||||
Term Yap_eval_unary(Int f, Term t)
|
||||
{
|
||||
return eval1(f,t);
|
||||
CACHE_REGS
|
||||
return eval1(f,t PASS_REGS);
|
||||
}
|
||||
|
||||
static InitUnEntry InitUnTab[] = {
|
||||
@ -758,7 +759,7 @@ p_unary_is( USES_REGS1 )
|
||||
return FALSE;
|
||||
}
|
||||
if (IsIntTerm(t)) {
|
||||
Term tout = Yap_FoundArithError(eval1(IntegerOfTerm(t), top), Deref(ARG3));
|
||||
Term tout = Yap_FoundArithError(eval1(IntegerOfTerm(t), top PASS_REGS), Deref(ARG3));
|
||||
if (!tout)
|
||||
return FALSE;
|
||||
return Yap_unify_constant(ARG1,tout);
|
||||
@ -781,7 +782,7 @@ p_unary_is( USES_REGS1 )
|
||||
P = FAILCODE;
|
||||
return(FALSE);
|
||||
}
|
||||
if (!(out=Yap_FoundArithError(eval1(p->FOfEE, top),Deref(ARG3))))
|
||||
if (!(out=Yap_FoundArithError(eval1(p->FOfEE, top PASS_REGS),Deref(ARG3))))
|
||||
return FALSE;
|
||||
return Yap_unify_constant(ARG1,out);
|
||||
}
|
||||
|
71
C/arith2.c
71
C/arith2.c
@ -37,7 +37,7 @@ typedef struct init_un_eval {
|
||||
|
||||
|
||||
static Term
|
||||
p_mod(Term t1, Term t2) {
|
||||
p_mod(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case (CELL)long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
@ -97,7 +97,7 @@ p_mod(Term t1, Term t2) {
|
||||
}
|
||||
|
||||
static Term
|
||||
p_div2(Term t1, Term t2) {
|
||||
p_div2(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case (CELL)long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
@ -163,7 +163,7 @@ p_div2(Term t1, Term t2) {
|
||||
}
|
||||
|
||||
static Term
|
||||
p_rem(Term t1, Term t2) {
|
||||
p_rem(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case (CELL)long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
@ -215,7 +215,7 @@ p_rem(Term t1, Term t2) {
|
||||
|
||||
|
||||
static Term
|
||||
p_rdiv(Term t1, Term t2) {
|
||||
p_rdiv(Term t1, Term t2 USES_REGS) {
|
||||
#ifdef USE_GMP
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case (CELL)double_e:
|
||||
@ -266,7 +266,7 @@ p_rdiv(Term t1, Term t2) {
|
||||
Floating point division: /
|
||||
*/
|
||||
static Term
|
||||
p_fdiv(Term t1, Term t2)
|
||||
p_fdiv(Term t1, Term t2 USES_REGS)
|
||||
{
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
@ -338,7 +338,7 @@ p_fdiv(Term t1, Term t2)
|
||||
xor #
|
||||
*/
|
||||
static Term
|
||||
p_xor(Term t1, Term t2)
|
||||
p_xor(Term t1, Term t2 USES_REGS)
|
||||
{
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
@ -382,7 +382,7 @@ p_xor(Term t1, Term t2)
|
||||
atan2: arc tangent x/y
|
||||
*/
|
||||
static Term
|
||||
p_atan2(Term t1, Term t2)
|
||||
p_atan2(Term t1, Term t2 USES_REGS)
|
||||
{
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
@ -461,7 +461,7 @@ p_atan2(Term t1, Term t2)
|
||||
power: x^y
|
||||
*/
|
||||
static Term
|
||||
p_power(Term t1, Term t2)
|
||||
p_power(Term t1, Term t2 USES_REGS)
|
||||
{
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
@ -577,7 +577,7 @@ ipow(Int x, Int p)
|
||||
power: x^y
|
||||
*/
|
||||
static Term
|
||||
p_exp(Term t1, Term t2)
|
||||
p_exp(Term t1, Term t2 USES_REGS)
|
||||
{
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
@ -669,7 +669,7 @@ p_exp(Term t1, Term t2)
|
||||
}
|
||||
|
||||
static Int
|
||||
gcd(Int m11,Int m21)
|
||||
gcd(Int m11,Int m21 USES_REGS)
|
||||
{
|
||||
/* Blankinship algorithm, provided by Miguel Filgueiras */
|
||||
Int m12=1, m22=0, k;
|
||||
@ -719,7 +719,7 @@ Int gcdmult(Int m11,Int m21,Int *pm11) /* *pm11 gets multiplier of m11 */
|
||||
module gcd
|
||||
*/
|
||||
static Term
|
||||
p_gcd(Term t1, Term t2)
|
||||
p_gcd(Term t1, Term t2 USES_REGS)
|
||||
{
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
@ -731,7 +731,7 @@ p_gcd(Term t1, Term t2)
|
||||
i1 = (i1 >= 0 ? i1 : -i1);
|
||||
i2 = (i2 >= 0 ? i2 : -i2);
|
||||
|
||||
RINT(gcd(i1,i2));
|
||||
RINT(gcd(i1,i2 PASS_REGS));
|
||||
}
|
||||
case double_e:
|
||||
return Yap_ArithError(TYPE_ERROR_INTEGER, t2, "gcd/2");
|
||||
@ -957,56 +957,57 @@ p_max(Term t1, Term t2)
|
||||
}
|
||||
|
||||
static Term
|
||||
eval2(Int fi, Term t1, Term t2) {
|
||||
eval2(Int fi, Term t1, Term t2 USES_REGS) {
|
||||
arith2_op f = fi;
|
||||
switch (f) {
|
||||
case op_plus:
|
||||
return p_plus(t1, t2);
|
||||
return p_plus(t1, t2 PASS_REGS);
|
||||
case op_minus:
|
||||
return p_minus(t1, t2);
|
||||
return p_minus(t1, t2 PASS_REGS);
|
||||
case op_times:
|
||||
return p_times(t1, t2);
|
||||
return p_times(t1, t2 PASS_REGS);
|
||||
case op_div:
|
||||
return p_div(t1, t2);
|
||||
return p_div(t1, t2 PASS_REGS);
|
||||
case op_idiv:
|
||||
return p_div2(t1, t2);
|
||||
return p_div2(t1, t2 PASS_REGS);
|
||||
case op_and:
|
||||
return p_and(t1, t2);
|
||||
return p_and(t1, t2 PASS_REGS);
|
||||
case op_or:
|
||||
return p_or(t1, t2);
|
||||
return p_or(t1, t2 PASS_REGS);
|
||||
case op_sll:
|
||||
return p_sll(t1, t2);
|
||||
return p_sll(t1, t2 PASS_REGS);
|
||||
case op_slr:
|
||||
return p_slr(t1, t2);
|
||||
return p_slr(t1, t2 PASS_REGS);
|
||||
case op_mod:
|
||||
return p_mod(t1, t2);
|
||||
return p_mod(t1, t2 PASS_REGS);
|
||||
case op_rem:
|
||||
return p_rem(t1, t2);
|
||||
return p_rem(t1, t2 PASS_REGS);
|
||||
case op_fdiv:
|
||||
return p_fdiv(t1, t2);
|
||||
return p_fdiv(t1, t2 PASS_REGS);
|
||||
case op_xor:
|
||||
return p_xor(t1, t2);
|
||||
return p_xor(t1, t2 PASS_REGS);
|
||||
case op_atan2:
|
||||
return p_atan2(t1, t2);
|
||||
return p_atan2(t1, t2 PASS_REGS);
|
||||
case op_power:
|
||||
return p_exp(t1, t2);
|
||||
return p_exp(t1, t2 PASS_REGS);
|
||||
case op_power2:
|
||||
return p_power(t1, t2);
|
||||
return p_power(t1, t2 PASS_REGS);
|
||||
case op_gcd:
|
||||
return p_gcd(t1, t2);
|
||||
return p_gcd(t1, t2 PASS_REGS);
|
||||
case op_min:
|
||||
return p_min(t1, t2);
|
||||
case op_max:
|
||||
return p_max(t1, t2);
|
||||
case op_rdiv:
|
||||
return p_rdiv(t1, t2);
|
||||
return p_rdiv(t1, t2 PASS_REGS);
|
||||
}
|
||||
RERROR();
|
||||
}
|
||||
|
||||
Term Yap_eval_binary(Int f, Term t1, Term t2)
|
||||
{
|
||||
return eval2(f,t1,t2);
|
||||
CACHE_REGS
|
||||
return eval2(f,t1,t2 PASS_REGS);
|
||||
}
|
||||
|
||||
static InitBinEntry InitBinTab[] = {
|
||||
@ -1058,7 +1059,7 @@ p_binary_is( USES_REGS1 )
|
||||
return FALSE;
|
||||
}
|
||||
if (IsIntTerm(t)) {
|
||||
Term tout = Yap_FoundArithError(eval2(IntOfTerm(t), t1, t2), 0L);
|
||||
Term tout = Yap_FoundArithError(eval2(IntOfTerm(t), t1, t2 PASS_REGS), 0L);
|
||||
if (!tout)
|
||||
return FALSE;
|
||||
return Yap_unify_constant(ARG1,tout);
|
||||
@ -1081,7 +1082,7 @@ p_binary_is( USES_REGS1 )
|
||||
P = FAILCODE;
|
||||
return(FALSE);
|
||||
}
|
||||
if (!(out=Yap_FoundArithError(eval2(p->FOfEE, t1, t2), 0L)))
|
||||
if (!(out=Yap_FoundArithError(eval2(p->FOfEE, t1, t2 PASS_REGS), 0L)))
|
||||
return FALSE;
|
||||
return Yap_unify_constant(ARG1,out);
|
||||
}
|
||||
@ -1105,7 +1106,7 @@ do_arith23(arith2_op op USES_REGS)
|
||||
t2 = Yap_Eval(Deref(ARG2));
|
||||
if (t2 == 0L)
|
||||
return FALSE;
|
||||
if (!(out=Yap_FoundArithError(eval2(op, t1, t2), 0L)))
|
||||
if (!(out=Yap_FoundArithError(eval2(op, t1, t2 PASS_REGS), 0L)))
|
||||
return FALSE;
|
||||
return Yap_unify_constant(ARG3,out);
|
||||
}
|
||||
|
34
C/bignum.c
34
C/bignum.c
@ -320,6 +320,7 @@ Yap_MkULLIntTerm(YAP_ULONG_LONG n)
|
||||
/* try to scan it as a bignum */
|
||||
mpz_init_set_str (new, tmp, 10);
|
||||
if (mpz_fits_slong_p(new)) {
|
||||
CACHE_REGS
|
||||
return MkIntegerTerm(mpz_get_si(new));
|
||||
}
|
||||
t = Yap_MkBigIntTerm(new);
|
||||
@ -346,6 +347,38 @@ p_is_bignum( USES_REGS1 )
|
||||
#endif
|
||||
}
|
||||
|
||||
static Int
|
||||
p_nb_set_bit( USES_REGS1 )
|
||||
{
|
||||
#ifdef USE_GMP
|
||||
Term t = Deref(ARG1);
|
||||
Term ti = Deref(ARG2);
|
||||
Int i;
|
||||
|
||||
if (!(
|
||||
IsNonVarTerm(t) &&
|
||||
IsApplTerm(t) &&
|
||||
FunctorOfTerm(t) == FunctorBigInt &&
|
||||
RepAppl(t)[1] == BIG_INT
|
||||
))
|
||||
return FALSE;
|
||||
if (!IsIntegerTerm(ti)) {
|
||||
return FALSE;
|
||||
}
|
||||
if (!IsIntegerTerm(ti)) {
|
||||
return FALSE;
|
||||
}
|
||||
i = IntegerOfTerm(ti);
|
||||
if (i < 0) {
|
||||
return FALSE;
|
||||
}
|
||||
Yap_gmp_set_bit(i, t);
|
||||
return TRUE;
|
||||
#else
|
||||
return FALSE;
|
||||
#endif
|
||||
}
|
||||
|
||||
static Int
|
||||
p_has_bignums( USES_REGS1 )
|
||||
{
|
||||
@ -560,4 +593,5 @@ Yap_InitBigNums(void)
|
||||
Yap_InitCPred("$bignum", 1, p_is_bignum, SafePredFlag);
|
||||
Yap_InitCPred("rational", 3, p_rational, 0);
|
||||
Yap_InitCPred("rational", 1, p_is_rational, SafePredFlag);
|
||||
Yap_InitCPred("nb_set_bit", 2, p_nb_set_bit, SafePredFlag);
|
||||
}
|
||||
|
@ -738,6 +738,7 @@ YAP_IsCompoundTerm(Term t)
|
||||
X_API Term
|
||||
YAP_MkIntTerm(Int n)
|
||||
{
|
||||
CACHE_REGS
|
||||
Term I;
|
||||
BACKUP_H();
|
||||
|
||||
@ -854,6 +855,7 @@ YAP_BlobOfTerm(Term t)
|
||||
X_API Term
|
||||
YAP_MkFloatTerm(double n)
|
||||
{
|
||||
CACHE_REGS
|
||||
Term t;
|
||||
BACKUP_H();
|
||||
|
||||
@ -3734,6 +3736,7 @@ YAP_CloseList(Term t0, Term tail)
|
||||
X_API int
|
||||
YAP_IsAttVar(Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
t = Deref(t);
|
||||
if (!IsVarTerm(t))
|
||||
return FALSE;
|
||||
@ -3743,6 +3746,7 @@ YAP_IsAttVar(Term t)
|
||||
X_API Term
|
||||
YAP_AttsOfVar(Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
attvar_record *attv;
|
||||
|
||||
t = Deref(t);
|
||||
@ -4023,6 +4027,7 @@ YAP_TagOfTerm(Term t)
|
||||
if (IsVarTerm(t)) {
|
||||
CELL *pt = VarOfTerm(t);
|
||||
if (IsUnboundVar(pt)) {
|
||||
CACHE_REGS
|
||||
if (IsAttVar(pt))
|
||||
return YAP_TAG_ATT;
|
||||
return YAP_TAG_UNBOUND;
|
||||
|
@ -4910,6 +4910,7 @@ replace_integer(Term orig, UInt new)
|
||||
return MkIntTerm(new);
|
||||
/* should create an old integer */
|
||||
if (!IsApplTerm(orig)) {
|
||||
CACHE_REGS
|
||||
Yap_Error(SYSTEM_ERROR,orig,"%uld-->%uld where it should increase",(unsigned long int)IntegerOfTerm(orig),(unsigned long int)new);
|
||||
return MkIntegerTerm(new);
|
||||
}
|
||||
|
@ -471,6 +471,7 @@ ShowOp (char *f, struct PSEUDO *cpc)
|
||||
case 'b':
|
||||
/* write a variable bitmap for a call */
|
||||
{
|
||||
CACHE_REGS
|
||||
int max = arg/(8*sizeof(CELL)), i;
|
||||
CELL *ptr = cptr;
|
||||
for (i = 0; i <= max; i++) {
|
||||
@ -490,7 +491,10 @@ ShowOp (char *f, struct PSEUDO *cpc)
|
||||
}
|
||||
break;
|
||||
case 'd':
|
||||
Yap_DebugPlWrite (MkIntegerTerm (arg));
|
||||
{
|
||||
CACHE_REGS
|
||||
Yap_DebugPlWrite (MkIntegerTerm (arg));
|
||||
}
|
||||
break;
|
||||
case 'z':
|
||||
Yap_DebugPlWrite (MkIntTerm (cpc->rnd3));
|
||||
|
@ -2381,6 +2381,7 @@ GetDBLUKey(PredEntry *ap)
|
||||
{
|
||||
PELOCK(63,ap);
|
||||
if (ap->PredFlags & NumberDBPredFlag) {
|
||||
CACHE_REGS
|
||||
Int id = ap->src.IndxId;
|
||||
UNLOCK(ap->PELock);
|
||||
return MkIntegerTerm(id);
|
||||
@ -2430,6 +2431,7 @@ UnifyDBKey(DBRef DBSP, PropFlags flags, Term t)
|
||||
static int
|
||||
UnifyDBNumber(DBRef DBSP, Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
DBProp p = DBSP->Parent;
|
||||
DBRef ref;
|
||||
Int i = 1;
|
||||
|
211
C/exo.c
211
C/exo.c
@ -40,50 +40,67 @@
|
||||
|
||||
#define MAX_ARITY 256
|
||||
|
||||
#define FNV32_PRIME 16777619
|
||||
#define FNV64_PRIME ((UInt)1099511628211)
|
||||
|
||||
#define FNV32_OFFSET 2166136261
|
||||
#define FNV64_OFFSET ((UInt)14695981039346656037)
|
||||
|
||||
|
||||
/* Simple hash function:
|
||||
first component is the base key.
|
||||
hash0 spreads extensions coming from different elements.
|
||||
spread over j quadrants.
|
||||
*/
|
||||
static UInt
|
||||
HASH(UInt hash0, UInt j, CELL *cl, struct index_t *it)
|
||||
static BITS32
|
||||
HASH(UInt arity, CELL *cl, UInt bnds[], UInt sz)
|
||||
{
|
||||
Term t = cl[j];
|
||||
UInt sz = it->hsize;
|
||||
if (IsIntTerm(t))
|
||||
return (17*(IntOfTerm(t) + (hash0+1)*j ) ) % sz;
|
||||
return (17*(((UInt)AtomOfTerm(t)>>5) + (hash0+1)*j ) ) % sz;
|
||||
UInt hash;
|
||||
UInt j=0;
|
||||
|
||||
hash = FNV32_OFFSET;
|
||||
while (j < arity) {
|
||||
if (bnds[j]) {
|
||||
unsigned char *i=(unsigned char*)(cl+j);
|
||||
unsigned char *m=(unsigned char*)(cl+(j+1));
|
||||
|
||||
while (i < m) {
|
||||
hash = hash ^ i[0];
|
||||
hash = hash * FNV32_PRIME;
|
||||
i++;
|
||||
}
|
||||
}
|
||||
j++;
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
static UInt
|
||||
NEXT(UInt hash, Term t, UInt j, struct index_t *it)
|
||||
static BITS32
|
||||
NEXT(UInt hash)
|
||||
{
|
||||
return (hash+(j+1)*997) % (it->hsize);
|
||||
return (hash*997);
|
||||
}
|
||||
|
||||
/* search for matching elements */
|
||||
static int
|
||||
MATCH(CELL *clp, CELL *kvp, UInt j, struct index_t *it, UInt bnds[])
|
||||
MATCH(CELL *clp, CELL *kvp, UInt arity, UInt bnds[])
|
||||
{
|
||||
if ((kvp - it->cls)%it->arity != j)
|
||||
return FALSE;
|
||||
do {
|
||||
if ( bnds[j] && *clp != *kvp)
|
||||
UInt j = 0;
|
||||
while (j< arity) {
|
||||
if ( bnds[j] && clp[j] != kvp[j])
|
||||
return FALSE;
|
||||
clp--;
|
||||
kvp--;
|
||||
} while (j-- != 0);
|
||||
j++;
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static void
|
||||
ADD_TO_TRY_CHAIN(CELL *kvp, CELL *cl, struct index_t *it)
|
||||
{
|
||||
UInt old = (kvp-it->cls)/it->arity;
|
||||
UInt new = (cl-it->cls)/it->arity;
|
||||
UInt *links = it->links;
|
||||
UInt tmp = links[old]; /* points to the end of the chain */
|
||||
BITS32 old = (kvp-it->cls)/it->arity;
|
||||
BITS32 new = (cl-it->cls)/it->arity;
|
||||
BITS32 *links = it->links;
|
||||
BITS32 tmp = links[old]; /* points to the end of the chain */
|
||||
|
||||
if (!tmp) {
|
||||
links[old] = links[new] = new;
|
||||
@ -111,50 +128,33 @@ ADD_TO_TRY_CHAIN(CELL *kvp, CELL *cl, struct index_t *it)
|
||||
* match ci..j ck..j -> find j = minarg(cij \= c2j)
|
||||
* else
|
||||
*/
|
||||
static void
|
||||
INSERT(CELL *cl, struct index_t *it, UInt arity, UInt base, UInt hash0, UInt bnds[])
|
||||
static int
|
||||
INSERT(CELL *cl, struct index_t *it, UInt arity, UInt base, UInt bnds[])
|
||||
{
|
||||
UInt j = base;
|
||||
CELL *kvp;
|
||||
UInt hash;
|
||||
BITS32 hash;
|
||||
int coll_count = 0;
|
||||
|
||||
/* skip over argument */
|
||||
while (!bnds[j]) {
|
||||
j++;
|
||||
}
|
||||
/* j is the firs bound element */
|
||||
/* check if we match */
|
||||
hash = hash0 = HASH(hash0, j, cl, it);
|
||||
//if (exo_write) printf("h=%ld j=%ld %lx\n", hash, j, cl[j]);
|
||||
|
||||
hash = HASH(arity, cl, bnds, it->hsize);
|
||||
next:
|
||||
/* loop to insert element */
|
||||
kvp = it->key[hash];
|
||||
kvp = EXO_OFFSET_TO_ADDRESS(it, it->key [hash % it->hsize]);
|
||||
if (kvp == NULL) {
|
||||
/* simple case, new entry */
|
||||
it->nentries++;
|
||||
it->key[hash] = cl+j;
|
||||
return;
|
||||
} else if (MATCH(cl+j, kvp, j, it, bnds)) {
|
||||
/* collision */
|
||||
UInt k;
|
||||
CELL *target;
|
||||
|
||||
for (k =j+1, target = kvp+1; k < arity; k++,target++ ) {
|
||||
if (bnds[k]) {
|
||||
if (*target != cl[k]) {
|
||||
/* found a new forking point */
|
||||
// printf("j=%ld hash0=%ld cl[j]=%lx\n", j, hash0, cl[j]);
|
||||
INSERT(cl, it, arity, k, hash0, bnds);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
it->key[hash % it->hsize ] = EXO_ADDRESS_TO_OFFSET(it, cl);
|
||||
return TRUE;
|
||||
} else if (MATCH(kvp, cl, arity, bnds)) {
|
||||
it->ntrys++;
|
||||
ADD_TO_TRY_CHAIN(kvp, cl, it);
|
||||
return;
|
||||
return TRUE;
|
||||
} else {
|
||||
coll_count++;
|
||||
if (coll_count == 32)
|
||||
return FALSE;
|
||||
it->ncollisions++;
|
||||
hash = NEXT(hash, cl[j], j, it);
|
||||
// printf("#");
|
||||
hash = NEXT(hash);
|
||||
//if (exo_write) printf("N=%ld\n", hash);
|
||||
goto next;
|
||||
}
|
||||
@ -165,45 +165,31 @@ LOOKUP(struct index_t *it, UInt arity, UInt j, UInt bnds[])
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *kvp;
|
||||
UInt hash, hash0 = 0;
|
||||
BITS32 hash;
|
||||
|
||||
/* j is the firs bound element */
|
||||
/* check if we match */
|
||||
hash:
|
||||
hash = hash0 = HASH(hash0, j, XREGS+1, it);
|
||||
hash = HASH(arity, XREGS+1, bnds, it->hsize);
|
||||
next:
|
||||
/* loop to insert element */
|
||||
kvp = it->key[hash];
|
||||
kvp = EXO_OFFSET_TO_ADDRESS(it, it->key[hash % it->hsize]);
|
||||
if (kvp == NULL) {
|
||||
/* simple case, no element */
|
||||
return FAILCODE;
|
||||
} else if (MATCH(XREGS+(j+1), kvp, j, it, bnds)) {
|
||||
/* found element */
|
||||
UInt k;
|
||||
CELL *target;
|
||||
|
||||
for (k =j+1, target = kvp+1; k < arity; k++ ) {
|
||||
if (bnds[k]) {
|
||||
if (*target != XREGS[k+1]) {
|
||||
j = k;
|
||||
goto hash;
|
||||
}
|
||||
}
|
||||
target++;
|
||||
}
|
||||
S = target-arity;
|
||||
} else if (MATCH(kvp, XREGS+1, arity, bnds)) {
|
||||
S = kvp;
|
||||
if (!it->is_key && it->links[(S-it->cls)/arity])
|
||||
return it->code;
|
||||
else
|
||||
return NEXTOP(NEXTOP(it->code,lp),lp);
|
||||
} else {
|
||||
/* collision */
|
||||
hash = NEXT(hash, XREGS[j+1], j, it);
|
||||
hash = NEXT(hash);
|
||||
goto next;
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
static int
|
||||
fill_hash(UInt bmap, struct index_t *it, UInt bnds[])
|
||||
{
|
||||
UInt i;
|
||||
@ -211,12 +197,13 @@ fill_hash(UInt bmap, struct index_t *it, UInt bnds[])
|
||||
CELL *cl = it->cls;
|
||||
|
||||
for (i=0; i < it->nels; i++) {
|
||||
INSERT(cl, it, arity, 0, 0, bnds);
|
||||
if (!INSERT(cl, it, arity, 0, bnds))
|
||||
return FALSE;
|
||||
cl += arity;
|
||||
}
|
||||
for (i=0; i < it->hsize; i++) {
|
||||
if (it->key[i]) {
|
||||
UInt offset = (it->key[i]-it->cls)/arity;
|
||||
UInt offset = it->key[i]/arity;
|
||||
UInt last = it->links[offset];
|
||||
if (last) {
|
||||
/* the chain used to point straight to the last, and the last back to the origibal first */
|
||||
@ -225,6 +212,7 @@ fill_hash(UInt bmap, struct index_t *it, UInt bnds[])
|
||||
}
|
||||
}
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static struct index_t *
|
||||
@ -246,6 +234,7 @@ add_index(struct index_t **ip, UInt bmap, PredEntry *ap, UInt count, UInt bnds[]
|
||||
Yap_Error(OUT_OF_HEAP_ERROR, TermNil, LOCAL_ErrorMessage);
|
||||
return NULL;
|
||||
}
|
||||
i->is_key = FALSE;
|
||||
i->next = *ip;
|
||||
i->prev = NULL;
|
||||
i->nels = ncls;
|
||||
@ -255,7 +244,7 @@ add_index(struct index_t **ip, UInt bmap, PredEntry *ap, UInt count, UInt bnds[]
|
||||
i->is_key = FALSE;
|
||||
i->hsize = 2*ncls;
|
||||
if (count) {
|
||||
if (!(base = (CELL *)Yap_AllocCodeSpace(sizeof(CELL)*(ncls+i->hsize)))) {
|
||||
if (!(base = (CELL *)Yap_AllocCodeSpace(sizeof(BITS32)*(ncls+i->hsize)))) {
|
||||
CACHE_REGS
|
||||
save_machine_regs();
|
||||
LOCAL_Error_Size = sizeof(CELL)*(ncls+i->hsize);
|
||||
@ -267,18 +256,51 @@ add_index(struct index_t **ip, UInt bmap, PredEntry *ap, UInt count, UInt bnds[]
|
||||
bzero(base, sizeof(CELL)*(ncls+i->hsize));
|
||||
}
|
||||
i->size = sizeof(CELL)*(ncls+i->hsize)+sz+sizeof(struct index_t);
|
||||
i->key = (CELL **)base;
|
||||
i->key = (CELL *)base;
|
||||
i->links = (CELL *)(base+i->hsize);
|
||||
i->ncollisions = i->nentries = i->ntrys = 0;
|
||||
i->cls = (CELL *)((ADDR)ap->cs.p_code.FirstClause+2*sizeof(struct index_t *));
|
||||
*ip = i;
|
||||
if (count) {
|
||||
fill_hash(bmap, i, bnds);
|
||||
printf("entries=%ld collisions=%ld trys=%ld\n", i->nentries, i->ncollisions, i->ntrys);
|
||||
if (!i->ntrys) {
|
||||
i->is_key = TRUE;
|
||||
if (base != realloc(base, i->hsize*sizeof(CELL)))
|
||||
while (count) {
|
||||
if (!fill_hash(bmap, i, bnds)) {
|
||||
size_t sz;
|
||||
i->hsize += ncls;
|
||||
if (i->is_key) {
|
||||
sz = i->hsize*sizeof(BITS32);
|
||||
} else {
|
||||
sz = (ncls+i->hsize)*sizeof(BITS32);
|
||||
}
|
||||
if (base != realloc(base, sz))
|
||||
return FALSE;
|
||||
bzero(base, sz);
|
||||
i->key = (CELL *)base;
|
||||
i->links = (CELL *)(base+i->hsize);
|
||||
i->ncollisions = i->nentries = i->ntrys = 0;
|
||||
continue;
|
||||
}
|
||||
fprintf(stderr, "entries=%ld collisions=%ld trys=%ld\n", i->nentries, i->ncollisions, i->ntrys);
|
||||
if (!i->ntrys && !i->is_key) {
|
||||
i->is_key = TRUE;
|
||||
if (base != realloc(base, i->hsize*sizeof(BITS32)))
|
||||
return FALSE;
|
||||
}
|
||||
/* our hash table is just too large */
|
||||
if (( i->nentries+i->ncollisions )*10 < i->hsize) {
|
||||
size_t sz;
|
||||
i->hsize = ( i->nentries+i->ncollisions )*10;
|
||||
if (i->is_key) {
|
||||
sz = i->hsize*sizeof(BITS32);
|
||||
} else {
|
||||
sz = (ncls+i->hsize)*sizeof(BITS32);
|
||||
}
|
||||
if (base != realloc(base, sz))
|
||||
return FALSE;
|
||||
bzero(base, sz);
|
||||
i->key = (CELL *)base;
|
||||
i->links = (CELL *)(base+i->hsize);
|
||||
i->ncollisions = i->nentries = i->ntrys = 0;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
ptr = (yamop *)(i+1);
|
||||
@ -337,14 +359,11 @@ Yap_ExoLookup(PredEntry *ap USES_REGS)
|
||||
}
|
||||
|
||||
while (i) {
|
||||
if (i->is_key) {
|
||||
if ((i->bmap & bmap) == i->bmap) {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
if (i->bmap == bmap) {
|
||||
break;
|
||||
}
|
||||
// if (i->is_key && (i->bmap & bmap) == i->bmap) {
|
||||
// break;
|
||||
// }
|
||||
if (i->bmap == bmap) {
|
||||
break;
|
||||
}
|
||||
ip = &i->next;
|
||||
i = i->next;
|
||||
@ -362,9 +381,9 @@ CELL
|
||||
Yap_NextExo(choiceptr cptr, struct index_t *it)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL offset = EXO_ADDRESS_TO_OFFSET(it,(CELL *)((CELL *)(B+1))[it->arity]);
|
||||
CELL offset = ADDRESS_TO_LINK(it,(CELL *)((CELL *)(B+1))[it->arity]);
|
||||
CELL next = it->links[offset];
|
||||
((CELL *)(B+1))[it->arity] = (CELL)EXO_OFFSET_TO_ADDRESS(it, next);
|
||||
((CELL *)(B+1))[it->arity] = (CELL)LINK_TO_ADDRESS(it, next);
|
||||
S = it->cls+it->arity*offset;
|
||||
return next;
|
||||
}
|
||||
|
@ -2121,6 +2121,7 @@ p_nb_beam_close( USES_REGS1 )
|
||||
static void
|
||||
PushBeam(CELL *pt, CELL *npt, UInt hsize, Term key, Term to)
|
||||
{
|
||||
CACHE_REGS
|
||||
UInt off = hsize, off2 = hsize;
|
||||
Term toff, toff2;
|
||||
|
||||
@ -2166,6 +2167,7 @@ PushBeam(CELL *pt, CELL *npt, UInt hsize, Term key, Term to)
|
||||
static void
|
||||
DelBeamMax(CELL *pt, CELL *pt2, UInt sz)
|
||||
{
|
||||
CACHE_REGS
|
||||
UInt off = IntegerOfTerm(pt2[1]);
|
||||
UInt indx = 0;
|
||||
Term tk, ti, tv;
|
||||
@ -2240,6 +2242,7 @@ DelBeamMax(CELL *pt, CELL *pt2, UInt sz)
|
||||
static Term
|
||||
DelBeamMin(CELL *pt, CELL *pt2, UInt sz)
|
||||
{
|
||||
CACHE_REGS
|
||||
UInt off2 = IntegerOfTerm(pt[1]);
|
||||
Term ov = pt2[3*off2+2]; /* return value */
|
||||
UInt indx = 0;
|
||||
|
@ -132,6 +132,14 @@ Yap_gmp_add_int_big(Int i, Term t)
|
||||
}
|
||||
}
|
||||
|
||||
/* add i + b using temporary bigint new */
|
||||
void
|
||||
Yap_gmp_set_bit(Int i, Term t)
|
||||
{
|
||||
MP_INT *b = Yap_BigIntOfTerm(t);
|
||||
mpz_setbit(b, i);
|
||||
}
|
||||
|
||||
/* sub i - b using temporary bigint new */
|
||||
Term
|
||||
Yap_gmp_sub_int_big(Int i, Term t)
|
||||
@ -384,6 +392,7 @@ Yap_gmp_sll_big_int(Term t, Int i)
|
||||
} else {
|
||||
mpz_init(&new);
|
||||
if (i == Int_MIN) {
|
||||
CACHE_REGS
|
||||
return Yap_ArithError(RESOURCE_ERROR_HUGE_INT, MkIntegerTerm(i), "<</2");
|
||||
}
|
||||
mpz_fdiv_q_2exp(&new, b, -i);
|
||||
@ -706,6 +715,7 @@ Yap_gmp_mod_big_int(Term t, Int i2)
|
||||
Term
|
||||
Yap_gmp_mod_int_big(Int i1, Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] != BIG_INT) {
|
||||
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "mod/2");
|
||||
@ -782,6 +792,7 @@ Yap_gmp_rem_big_int(Term t, Int i2)
|
||||
Term
|
||||
Yap_gmp_rem_int_big(Int i1, Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] != BIG_INT) {
|
||||
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "mod/2");
|
||||
@ -815,6 +826,7 @@ Yap_gmp_gcd_big_big(Term t1, Term t2)
|
||||
Term
|
||||
Yap_gmp_gcd_int_big(Int i, Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] != BIG_INT) {
|
||||
return Yap_ArithError(TYPE_ERROR_INTEGER, t, "mod/2");
|
||||
@ -855,6 +867,7 @@ Yap_gmp_to_float(Term t)
|
||||
Term
|
||||
Yap_gmp_add_float_big(Float d, Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
MP_INT *b = Yap_BigIntOfTerm(t);
|
||||
@ -868,6 +881,7 @@ Yap_gmp_add_float_big(Float d, Term t)
|
||||
Term
|
||||
Yap_gmp_sub_float_big(Float d, Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
MP_INT *b = Yap_BigIntOfTerm(t);
|
||||
@ -881,6 +895,7 @@ Yap_gmp_sub_float_big(Float d, Term t)
|
||||
Term
|
||||
Yap_gmp_sub_big_float(Term t, Float d)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
MP_INT *b = Yap_BigIntOfTerm(t);
|
||||
@ -894,6 +909,7 @@ Yap_gmp_sub_big_float(Term t, Float d)
|
||||
Term
|
||||
Yap_gmp_mul_float_big(Float d, Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
MP_INT *b = Yap_BigIntOfTerm(t);
|
||||
@ -907,6 +923,7 @@ Yap_gmp_mul_float_big(Float d, Term t)
|
||||
Term
|
||||
Yap_gmp_fdiv_float_big(Float d, Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
MP_INT *b = Yap_BigIntOfTerm(t);
|
||||
@ -920,6 +937,7 @@ Yap_gmp_fdiv_float_big(Float d, Term t)
|
||||
Term
|
||||
Yap_gmp_fdiv_big_float(Term t, Float d)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
MP_INT *b = Yap_BigIntOfTerm(t);
|
||||
@ -943,6 +961,7 @@ Yap_gmp_exp_int_int(Int i1, Int i2)
|
||||
Term
|
||||
Yap_gmp_exp_big_int(Term t, Int i)
|
||||
{
|
||||
CACHE_REGS
|
||||
MP_INT new;
|
||||
|
||||
CELL *pt = RepAppl(t);
|
||||
@ -969,6 +988,7 @@ Yap_gmp_exp_big_int(Term t, Int i)
|
||||
Term
|
||||
Yap_gmp_exp_int_big(Int i, Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
return Yap_ArithError(RESOURCE_ERROR_HUGE_INT, t, "^/2");
|
||||
@ -982,6 +1002,7 @@ Yap_gmp_exp_int_big(Int i, Term t)
|
||||
Term
|
||||
Yap_gmp_exp_big_big(Term t1, Term t2)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt1 = RepAppl(t1);
|
||||
CELL *pt2 = RepAppl(t2);
|
||||
Float dbl1, dbl2;
|
||||
@ -1116,6 +1137,7 @@ Yap_gmq_rdiv_big_big(Term t1, Term t2)
|
||||
Term
|
||||
Yap_gmp_fdiv_int_big(Int i1, Term t2)
|
||||
{
|
||||
CACHE_REGS
|
||||
MP_RAT new;
|
||||
MP_RAT *b1, *b2;
|
||||
MP_RAT bb1, bb2;
|
||||
@ -1142,6 +1164,7 @@ Yap_gmp_fdiv_int_big(Int i1, Term t2)
|
||||
Term
|
||||
Yap_gmp_fdiv_big_int(Term t2, Int i1)
|
||||
{
|
||||
CACHE_REGS
|
||||
MP_RAT new;
|
||||
MP_RAT *b1, *b2;
|
||||
MP_RAT bb1, bb2;
|
||||
@ -1168,6 +1191,7 @@ Yap_gmp_fdiv_big_int(Term t2, Int i1)
|
||||
Term
|
||||
Yap_gmp_fdiv_big_big(Term t1, Term t2)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt1 = RepAppl(t1);
|
||||
CELL *pt2 = RepAppl(t2);
|
||||
MP_RAT new;
|
||||
@ -1602,6 +1626,7 @@ Yap_gmp_float_integer_part(Term t)
|
||||
Term
|
||||
Yap_gmp_sign(Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
return MkIntegerTerm(mpz_sgn(Yap_BigIntOfTerm(t)));
|
||||
@ -1613,6 +1638,7 @@ Yap_gmp_sign(Term t)
|
||||
Term
|
||||
Yap_gmp_lsb(Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
MP_INT *big = Yap_BigIntOfTerm(t);
|
||||
@ -1629,6 +1655,7 @@ Yap_gmp_lsb(Term t)
|
||||
Term
|
||||
Yap_gmp_msb(Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
MP_INT *big = Yap_BigIntOfTerm(t);
|
||||
@ -1645,6 +1672,7 @@ Yap_gmp_msb(Term t)
|
||||
Term
|
||||
Yap_gmp_popcount(Term t)
|
||||
{
|
||||
CACHE_REGS
|
||||
CELL *pt = RepAppl(t);
|
||||
if (pt[1] == BIG_INT) {
|
||||
MP_INT *big = Yap_BigIntOfTerm(t);
|
||||
|
@ -1923,10 +1923,7 @@ suspend_indexing(ClauseDef *min, ClauseDef *max, PredEntry *ap, struct intermedi
|
||||
} else {
|
||||
Yap_IndexSpace_EXT += sz;
|
||||
}
|
||||
{
|
||||
CACHE_REGS
|
||||
Yap_inform_profiler_of_clause(ncode, (CODEADDR)ncode+sz, ap, GPROF_NEW_EXPAND_BLOCK);
|
||||
}
|
||||
Yap_inform_profiler_of_clause(ncode, (CODEADDR)ncode+sz, ap, GPROF_NEW_EXPAND_BLOCK);
|
||||
/* create an expand_block */
|
||||
ncode->opc = Yap_opcode(_expand_clauses);
|
||||
ncode->u.sssllp.p = ap;
|
||||
|
6
C/init.c
6
C/init.c
@ -882,6 +882,8 @@ Yap_InitCPredBack(char *Name, unsigned long int Arity,
|
||||
static void
|
||||
InitStdPreds(void)
|
||||
{
|
||||
void initIO(void);
|
||||
|
||||
Yap_InitCPreds();
|
||||
Yap_InitBackCPreds();
|
||||
BACKUP_MACHINE_REGS();
|
||||
@ -1288,17 +1290,21 @@ Yap_InitWorkspace(UInt Heap, UInt Stack, UInt Trail, UInt Atts, UInt max_table_s
|
||||
if (Heap < MinHeapSpace)
|
||||
Heap = MinHeapSpace;
|
||||
Heap = AdjustPageSize(Heap * K);
|
||||
Heap /= (K);
|
||||
/* sanity checking for data areas */
|
||||
if (Trail < MinTrailSpace)
|
||||
Trail = MinTrailSpace;
|
||||
Trail = AdjustPageSize(Trail * K);
|
||||
Trail /= (K);
|
||||
if (Stack < MinStackSpace)
|
||||
Stack = MinStackSpace;
|
||||
Stack = AdjustPageSize(Stack * K);
|
||||
Stack /= (K);
|
||||
if (!Atts)
|
||||
Atts = 2048*sizeof(CELL);
|
||||
else
|
||||
Atts = AdjustPageSize(Atts * K);
|
||||
Atts /= (K);
|
||||
#if defined(YAPOR) || defined(THREADS)
|
||||
worker_id = 0;
|
||||
#endif /* YAPOR || THREADS */
|
||||
|
8
C/save.c
8
C/save.c
@ -261,7 +261,7 @@ open_file(char *my_file, int flag)
|
||||
#endif /* O_BINARY */
|
||||
#endif /* M_WILLIAMS */
|
||||
{
|
||||
splfild = 0; /* We do not have an open file */
|
||||
splfild = -1; /* We do not have an open file */
|
||||
return -1;
|
||||
}
|
||||
#ifdef undf0
|
||||
@ -1466,7 +1466,7 @@ OpenRestore(char *inpf, char *YapLibDir, CELL *Astate, CELL *ATrail, CELL *AStac
|
||||
} else {
|
||||
strncat(LOCAL_FileNameBuf, inpf, YAP_FILENAME_MAX-1);
|
||||
}
|
||||
if (inpf != NULL && (splfild = open_file(inpf, O_RDONLY)) > 0) {
|
||||
if (inpf != NULL && !((splfild = open_file(inpf, O_RDONLY)) < 0)) {
|
||||
if ((mode = try_open(inpf,Astate,ATrail,AStack,AHeap,save_buffer,streamp)) != FAIL_RESTORE) {
|
||||
return mode;
|
||||
}
|
||||
@ -1499,7 +1499,7 @@ OpenRestore(char *inpf, char *YapLibDir, CELL *Astate, CELL *ATrail, CELL *AStac
|
||||
#endif
|
||||
if (YAP_LIBDIR != NULL) {
|
||||
cat_file_name(LOCAL_FileNameBuf, YAP_LIBDIR, inpf, YAP_FILENAME_MAX);
|
||||
if ((splfild = open_file(LOCAL_FileNameBuf, O_RDONLY)) > 0) {
|
||||
if (!((splfild = open_file(LOCAL_FileNameBuf, O_RDONLY)) < 0)) {
|
||||
if ((mode = try_open(LOCAL_FileNameBuf,Astate,ATrail,AStack,AHeap,save_buffer,streamp)) != FAIL_RESTORE) {
|
||||
return mode;
|
||||
}
|
||||
@ -1508,7 +1508,7 @@ OpenRestore(char *inpf, char *YapLibDir, CELL *Astate, CELL *ATrail, CELL *AStac
|
||||
}
|
||||
#if _MSC_VER || defined(__MINGW32__)
|
||||
if ((inpf = Yap_RegistryGetString("startup"))) {
|
||||
if ((splfild = open_file(inpf, O_RDONLY)) > 0) {
|
||||
if (!((splfild = open_file(inpf, O_RDONLY)) < 0)) {
|
||||
if ((mode = try_open(inpf,Astate,ATrail,AStack,AHeap,save_buffer,streamp)) != FAIL_RESTORE) {
|
||||
return mode;
|
||||
}
|
||||
|
@ -230,6 +230,7 @@ extern double atof(const char *);
|
||||
static Term
|
||||
float_send(char *s, int sign)
|
||||
{
|
||||
CACHE_REGS
|
||||
Float f = (Float)atof(s);
|
||||
#if HAVE_FINITE
|
||||
if (yap_flags[LANGUAGE_MODE_FLAG] == 1) { /* iso */
|
||||
@ -512,6 +513,7 @@ num_send_error_message(char s[])
|
||||
static Term
|
||||
get_num(int *chp, int *chbuffp, IOSTREAM *inp_stream, char *s, UInt max_size, int sign)
|
||||
{
|
||||
CACHE_REGS
|
||||
char *sp = s;
|
||||
int ch = *chp;
|
||||
Int val = 0L, base = ch - '0';
|
||||
|
@ -74,9 +74,6 @@ p_creep( USES_REGS1 )
|
||||
static Int
|
||||
p_stop_creeping( USES_REGS1 )
|
||||
{
|
||||
Atom at;
|
||||
PredEntry *pred;
|
||||
|
||||
LOCK(LOCAL_SignalLock);
|
||||
LOCAL_ActiveSignals &= ~(YAP_CREEP_SIGNAL|YAP_DELAY_CREEP_SIGNAL);
|
||||
if (!LOCAL_ActiveSignals) {
|
||||
|
16
C/stdpreds.c
16
C/stdpreds.c
@ -292,7 +292,7 @@ STD_PROTO(static Int p_values, ( USES_REGS1 ));
|
||||
STD_PROTO(static CODEADDR *FindAtom, (CODEADDR, int *));
|
||||
#endif /* undefined */
|
||||
STD_PROTO(static Int p_opdec, ( USES_REGS1 ));
|
||||
STD_PROTO(static Term get_num, (char *));
|
||||
STD_PROTO(static Term get_num, (char * USES_REGS));
|
||||
STD_PROTO(static Int p_name, ( USES_REGS1 ));
|
||||
STD_PROTO(static Int p_atom_chars, ( USES_REGS1 ));
|
||||
STD_PROTO(static Int p_atom_codes, ( USES_REGS1 ));
|
||||
@ -537,7 +537,7 @@ strtod(s, pe)
|
||||
#endif
|
||||
|
||||
static Term
|
||||
get_num(char *t)
|
||||
get_num(char *t USES_REGS)
|
||||
{
|
||||
Term out;
|
||||
IOSTREAM *smem = Sopenmem(&t, NULL, "r");
|
||||
@ -832,7 +832,7 @@ p_name( USES_REGS1 )
|
||||
return(FALSE);
|
||||
}
|
||||
if (IsAtomTerm(t) && AtomOfTerm(t) == AtomNil) {
|
||||
if ((NewT = get_num(String)) == TermNil) {
|
||||
if ((NewT = get_num(String PASS_REGS)) == TermNil) {
|
||||
Atom at;
|
||||
while ((at = Yap_LookupAtom(String)) == NIL) {
|
||||
if (!Yap_growheap(FALSE, 0, NULL)) {
|
||||
@ -1375,7 +1375,7 @@ p_atom_concat( USES_REGS1 )
|
||||
if (wide_mode) {
|
||||
wchar_t *cptr = (wchar_t *)(((AtomEntry *)Yap_PreAllocCodeSpace())->StrOfAE), *cpt0;
|
||||
wchar_t *top = (wchar_t *)AuxSp;
|
||||
unsigned char *atom_str;
|
||||
unsigned char *atom_str = NULL;
|
||||
Atom ahead;
|
||||
UInt sz;
|
||||
|
||||
@ -2227,7 +2227,7 @@ p_number_chars( USES_REGS1 )
|
||||
}
|
||||
}
|
||||
*s++ = '\0';
|
||||
if ((NewT = get_num(String)) == TermNil) {
|
||||
if ((NewT = get_num(String PASS_REGS)) == TermNil) {
|
||||
Yap_Error(SYNTAX_ERROR, gen_syntax_error(Yap_LookupAtom(String), "number_chars"), "while scanning %s", String);
|
||||
return (FALSE);
|
||||
}
|
||||
@ -2294,7 +2294,7 @@ p_number_atom( USES_REGS1 )
|
||||
return(FALSE);
|
||||
}
|
||||
s = RepAtom(AtomOfTerm(t))->StrOfAE;
|
||||
if ((NewT = get_num(s)) == TermNil) {
|
||||
if ((NewT = get_num(s PASS_REGS)) == TermNil) {
|
||||
Yap_Error(SYNTAX_ERROR, gen_syntax_error(Yap_LookupAtom(String), "number_atom"), "while scanning %s", s);
|
||||
return (FALSE);
|
||||
}
|
||||
@ -2387,7 +2387,7 @@ p_number_codes( USES_REGS1 )
|
||||
}
|
||||
}
|
||||
*s++ = '\0';
|
||||
if ((NewT = get_num(String)) == TermNil) {
|
||||
if ((NewT = get_num(String PASS_REGS)) == TermNil) {
|
||||
Yap_Error(SYNTAX_ERROR, gen_syntax_error(Yap_LookupAtom(String), "number_codes"), "while scanning %s", String);
|
||||
return (FALSE);
|
||||
}
|
||||
@ -2452,7 +2452,7 @@ p_atom_number( USES_REGS1 )
|
||||
return FALSE;
|
||||
}
|
||||
s = RepAtom(at)->StrOfAE; /* alloc temp space on Trail */
|
||||
if ((NewT = get_num(s)) == TermNil) {
|
||||
if ((NewT = get_num(s PASS_REGS)) == TermNil) {
|
||||
Yap_Error(SYNTAX_ERROR, gen_syntax_error(at, "atom_number"), "while scanning %s", s);
|
||||
return FALSE;
|
||||
}
|
||||
|
@ -1520,9 +1520,18 @@ ExportTerm(Term inp, char * buf, size_t len, UInt arity, int newattvs USES_REGS)
|
||||
Term t = Deref(inp);
|
||||
tr_fr_ptr TR0 = TR;
|
||||
size_t res = 0;
|
||||
CELL *Hi;
|
||||
CELL *Hi = H;
|
||||
|
||||
do {
|
||||
if (IsVarTerm(t) || IsIntTerm(t)) {
|
||||
return export_term_to_buffer(t, buf, buf+ 3*sizeof(CELL), &inp, &inp, len);
|
||||
}
|
||||
if (IsAtomTerm(t)) {
|
||||
Atom at = AtomOfTerm(t);
|
||||
char *b = buf+3*sizeof(CELL);
|
||||
export_atom(at, &b, b, len-3*sizeof(CELL));
|
||||
return export_term_to_buffer(t, buf, b, &inp, &inp, len);
|
||||
}
|
||||
if ((Int)res < 0) {
|
||||
H = Hi;
|
||||
TR = TR0;
|
||||
@ -1634,16 +1643,14 @@ Yap_ImportTerm(char * buf) {
|
||||
CELL *bc = (CELL *)buf;
|
||||
size_t sz = bc[1];
|
||||
Term tinp, tret;
|
||||
|
||||
tinp = bc[2];
|
||||
if (IsVarTerm(tinp))
|
||||
return MkVarTerm();
|
||||
if (IsAtomOrIntTerm(tinp)) {
|
||||
if (IsAtomTerm(tinp)) {
|
||||
char *pt = (char *)AdjustSize(bc+3, buf);
|
||||
return MkAtomTerm(Yap_LookupAtom(pt));
|
||||
} else
|
||||
return tinp;
|
||||
else if (IsIntTerm(tinp))
|
||||
return tinp;
|
||||
else if (IsAtomTerm(tinp)) {
|
||||
tret = MkAtomTerm(AddAtom(NULL,(char *)(bc+3)));
|
||||
return tret;
|
||||
}
|
||||
if (H + sz > ASP)
|
||||
return (Term)0;
|
||||
@ -1654,7 +1661,7 @@ Yap_ImportTerm(char * buf) {
|
||||
} else {
|
||||
tret = AbsPair(H);
|
||||
import_pair(H, (char *)H, buf, H);
|
||||
}
|
||||
}
|
||||
H += sz;
|
||||
return tret;
|
||||
}
|
||||
@ -4921,7 +4928,7 @@ numbervar_singleton(USES_REGS1)
|
||||
}
|
||||
|
||||
static void
|
||||
renumbervar(Term t, Int id)
|
||||
renumbervar(Term t, Int id USES_REGS)
|
||||
{
|
||||
Term *ts = RepAppl(t);
|
||||
ts[1] = MkIntegerTerm(id);
|
||||
@ -4975,7 +4982,7 @@ static Int numbervars_in_complex_term(register CELL *pt0, register CELL *pt0_end
|
||||
continue;
|
||||
}
|
||||
if (singles && ap2 >= InitialH && ap2 < H) {
|
||||
renumbervar(d0, numbv++);
|
||||
renumbervar(d0, numbv++ PASS_REGS);
|
||||
continue;
|
||||
}
|
||||
/* store the terms to visit */
|
||||
|
34
H/TermExt.h
34
H/TermExt.h
@ -60,13 +60,14 @@ blob_type;
|
||||
|
||||
#include "inline-only.h"
|
||||
|
||||
INLINE_ONLY inline EXTERN int IsAttVar (CELL *pt);
|
||||
#define IsAttVar(pt) __IsAttVar((pt) PASS_REGS)
|
||||
|
||||
INLINE_ONLY inline EXTERN int __IsAttVar (CELL *pt USES_REGS);
|
||||
|
||||
INLINE_ONLY inline EXTERN int
|
||||
IsAttVar (CELL *pt)
|
||||
__IsAttVar (CELL *pt USES_REGS)
|
||||
{
|
||||
#ifdef YAP_H
|
||||
CACHE_REGS
|
||||
return (pt)[-1] == (CELL)attvar_e
|
||||
&& pt < H;
|
||||
#else
|
||||
@ -182,13 +183,13 @@ INLINE_ONLY inline EXTERN Float CpFloatUnaligned(CELL *ptr);
|
||||
|
||||
#if SIZEOF_DOUBLE == SIZEOF_LONG_INT
|
||||
|
||||
#define MkFloatTerm(fl) __MkFloatTerm((fl) PASS_REGS)
|
||||
|
||||
INLINE_ONLY inline EXTERN Term MkFloatTerm (Float);
|
||||
INLINE_ONLY inline EXTERN Term __MkFloatTerm (Float USES_REGS);
|
||||
|
||||
INLINE_ONLY inline EXTERN Term
|
||||
MkFloatTerm (Float dbl)
|
||||
__MkFloatTerm (Float dbl USES_REGS)
|
||||
{
|
||||
CACHE_REGS
|
||||
return (Term) ((H[0] = (CELL) FunctorDouble, *(Float *) (H + 1) =
|
||||
dbl, H[2] = EndSpecials, H +=
|
||||
3, AbsAppl (H - 3)));
|
||||
@ -303,12 +304,13 @@ IsFloatTerm (Term t)
|
||||
|
||||
/* extern Functor FunctorLongInt; */
|
||||
|
||||
INLINE_ONLY inline EXTERN Term MkLongIntTerm (Int);
|
||||
#define MkLongIntTerm(i) __MkLongIntTerm((i) PASS_REGS)
|
||||
|
||||
INLINE_ONLY inline EXTERN Term __MkLongIntTerm (Int USES_REGS);
|
||||
|
||||
INLINE_ONLY inline EXTERN Term
|
||||
MkLongIntTerm (Int i)
|
||||
__MkLongIntTerm (Int i USES_REGS)
|
||||
{
|
||||
CACHE_REGS
|
||||
H[0] = (CELL) FunctorLongInt;
|
||||
H[1] = (CELL) (i);
|
||||
H[2] = EndSpecials;
|
||||
@ -546,11 +548,12 @@ IsAttachFunc (Functor f)
|
||||
|
||||
|
||||
|
||||
#define IsAttachedTerm(t) __IsAttachedTerm(t PASS_REGS)
|
||||
|
||||
INLINE_ONLY inline EXTERN Int IsAttachedTerm (Term);
|
||||
INLINE_ONLY inline EXTERN Int __IsAttachedTerm (Term USES_REGS);
|
||||
|
||||
INLINE_ONLY inline EXTERN Int
|
||||
IsAttachedTerm (Term t)
|
||||
__IsAttachedTerm (Term t USES_REGS)
|
||||
{
|
||||
return (Int) ((IsVarTerm (t) && IsAttVar(VarOfTerm(t))));
|
||||
}
|
||||
@ -563,17 +566,16 @@ GlobalIsAttachedTerm (Term t)
|
||||
return (Int) ((IsVarTerm (t) && GlobalIsAttVar(VarOfTerm(t))));
|
||||
}
|
||||
|
||||
INLINE_ONLY inline EXTERN Int SafeIsAttachedTerm (Term);
|
||||
#define SafeIsAttachedTerm(t) __SafeIsAttachedTerm((t) PASS_REGS)
|
||||
|
||||
INLINE_ONLY inline EXTERN Int __SafeIsAttachedTerm (Term USES_REGS);
|
||||
|
||||
INLINE_ONLY inline EXTERN Int
|
||||
SafeIsAttachedTerm (Term t)
|
||||
__SafeIsAttachedTerm (Term t USES_REGS)
|
||||
{
|
||||
return (Int) (IsVarTerm (t) && IsAttVar(VarOfTerm(t)));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
INLINE_ONLY inline EXTERN exts ExtFromCell (CELL *);
|
||||
|
||||
INLINE_ONLY inline EXTERN exts
|
||||
|
@ -364,10 +364,13 @@ MkPairTerm__ (Term head, Term tail USES_REGS)
|
||||
#define IsAccessFunc(func) ((func) == FunctorAccess)
|
||||
|
||||
#ifdef YAP_H
|
||||
INLINE_ONLY inline EXTERN Term MkIntegerTerm (Int);
|
||||
|
||||
#define MkIntegerTerm(i) __MkIntegerTerm(i PASS_REGS)
|
||||
|
||||
INLINE_ONLY inline EXTERN Term __MkIntegerTerm (Int USES_REGS);
|
||||
|
||||
INLINE_ONLY inline EXTERN Term
|
||||
MkIntegerTerm (Int n)
|
||||
__MkIntegerTerm (Int n USES_REGS)
|
||||
{
|
||||
return (Term) (IntInBnd (n) ? MkIntTerm (n) : MkLongIntTerm (n));
|
||||
}
|
||||
|
@ -341,7 +341,7 @@ void STD_PROTO(Yap_InitSavePreds,(void));
|
||||
/* signals.c */
|
||||
void STD_PROTO(Yap_signal,(yap_signals));
|
||||
void STD_PROTO(Yap_undo_signal,(yap_signals));
|
||||
void STD_PROTO(Yap_InitSignalPreds,(void));
|
||||
void STD_PROTO(Yap_InitSignalCPreds,(void));
|
||||
|
||||
/* sort.c */
|
||||
void STD_PROTO(Yap_InitSortPreds,(void));
|
||||
|
34
H/arith2.h
34
H/arith2.h
@ -26,7 +26,7 @@ add_overflow(Int x, Int i, Int j)
|
||||
}
|
||||
|
||||
inline static Term
|
||||
add_int(Int i, Int j)
|
||||
add_int(Int i, Int j USES_REGS)
|
||||
{
|
||||
Int x = i+j;
|
||||
#if USE_GMP
|
||||
@ -51,7 +51,7 @@ sub_overflow(Int x, Int i, Int j)
|
||||
}
|
||||
|
||||
inline static Term
|
||||
sub_int(Int i, Int j)
|
||||
sub_int(Int i, Int j USES_REGS)
|
||||
{
|
||||
Int x = i-j;
|
||||
#if USE_GMP
|
||||
@ -105,7 +105,7 @@ mul_overflow(Int z, Int i1, Int i2)
|
||||
#endif
|
||||
|
||||
inline static Term
|
||||
times_int(Int i1, Int i2) {
|
||||
times_int(Int i1, Int i2 USES_REGS) {
|
||||
#ifdef USE_GMP
|
||||
Int z;
|
||||
DO_MULTI();
|
||||
@ -151,7 +151,7 @@ clrsb(Int i)
|
||||
#endif
|
||||
|
||||
inline static Term
|
||||
do_sll(Int i, Int j) /* j > 0 */
|
||||
do_sll(Int i, Int j USES_REGS) /* j > 0 */
|
||||
{
|
||||
#ifdef USE_GMP
|
||||
if (
|
||||
@ -174,13 +174,13 @@ do_sll(Int i, Int j) /* j > 0 */
|
||||
|
||||
|
||||
static inline Term
|
||||
p_plus(Term t1, Term t2) {
|
||||
p_plus(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
case long_int_e:
|
||||
/* two integers */
|
||||
return add_int(IntegerOfTerm(t1),IntegerOfTerm(t2));
|
||||
return add_int(IntegerOfTerm(t1),IntegerOfTerm(t2) PASS_REGS);
|
||||
case double_e:
|
||||
{
|
||||
/* integer, double */
|
||||
@ -230,13 +230,13 @@ p_plus(Term t1, Term t2) {
|
||||
}
|
||||
|
||||
static Term
|
||||
p_minus(Term t1, Term t2) {
|
||||
p_minus(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
case long_int_e:
|
||||
/* two integers */
|
||||
return sub_int(IntegerOfTerm(t1), IntegerOfTerm(t2));
|
||||
return sub_int(IntegerOfTerm(t1), IntegerOfTerm(t2) PASS_REGS);
|
||||
case double_e:
|
||||
{
|
||||
/* integer, double */
|
||||
@ -290,13 +290,13 @@ p_minus(Term t1, Term t2) {
|
||||
|
||||
|
||||
static Term
|
||||
p_times(Term t1, Term t2) {
|
||||
p_times(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
case long_int_e:
|
||||
/* two integers */
|
||||
return(times_int(IntegerOfTerm(t1),IntegerOfTerm(t2)));
|
||||
return(times_int(IntegerOfTerm(t1),IntegerOfTerm(t2) PASS_REGS));
|
||||
case double_e:
|
||||
{
|
||||
/* integer, double */
|
||||
@ -348,7 +348,7 @@ p_times(Term t1, Term t2) {
|
||||
}
|
||||
|
||||
static Term
|
||||
p_div(Term t1, Term t2) {
|
||||
p_div(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
@ -405,7 +405,7 @@ p_div(Term t1, Term t2) {
|
||||
}
|
||||
|
||||
static Term
|
||||
p_and(Term t1, Term t2) {
|
||||
p_and(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
@ -446,7 +446,7 @@ p_and(Term t1, Term t2) {
|
||||
}
|
||||
|
||||
static Term
|
||||
p_or(Term t1, Term t2) {
|
||||
p_or(Term t1, Term t2 USES_REGS) {
|
||||
switch(ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
@ -487,7 +487,7 @@ p_or(Term t1, Term t2) {
|
||||
}
|
||||
|
||||
static Term
|
||||
p_sll(Term t1, Term t2) {
|
||||
p_sll(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
@ -501,7 +501,7 @@ p_sll(Term t1, Term t2) {
|
||||
}
|
||||
RINT(SLR(IntegerOfTerm(t1), -i2));
|
||||
}
|
||||
return do_sll(IntegerOfTerm(t1),i2);
|
||||
return do_sll(IntegerOfTerm(t1),i2 PASS_REGS);
|
||||
}
|
||||
case double_e:
|
||||
return Yap_ArithError(TYPE_ERROR_INTEGER, t2, "<</2");
|
||||
@ -535,7 +535,7 @@ p_sll(Term t1, Term t2) {
|
||||
}
|
||||
|
||||
static Term
|
||||
p_slr(Term t1, Term t2) {
|
||||
p_slr(Term t1, Term t2 USES_REGS) {
|
||||
switch (ETypeOfTerm(t1)) {
|
||||
case long_int_e:
|
||||
switch (ETypeOfTerm(t2)) {
|
||||
@ -547,7 +547,7 @@ p_slr(Term t1, Term t2) {
|
||||
if (i2 == Int_MIN) {
|
||||
return Yap_ArithError(RESOURCE_ERROR_HUGE_INT, t2, ">>/2");
|
||||
}
|
||||
return do_sll(IntegerOfTerm(t1), -i2);
|
||||
return do_sll(IntegerOfTerm(t1), -i2 PASS_REGS);
|
||||
}
|
||||
RINT(SLR(IntegerOfTerm(t1), i2));
|
||||
}
|
||||
|
38
H/clause.h
38
H/clause.h
@ -170,25 +170,43 @@ typedef struct index_t {
|
||||
UInt ntrys;
|
||||
UInt nentries;
|
||||
UInt hsize;
|
||||
CELL **key;
|
||||
BITS32 *key;
|
||||
CELL *cls;
|
||||
CELL *links;
|
||||
BITS32 *links;
|
||||
size_t size;
|
||||
yamop *code;
|
||||
} Index_t;
|
||||
|
||||
INLINE_ONLY EXTERN inline UInt EXO_ADDRESS_TO_OFFSET(struct index_t *it, CELL *ptr);
|
||||
INLINE_ONLY EXTERN inline BITS32 EXO_ADDRESS_TO_OFFSET(struct index_t *it, CELL *ptr);
|
||||
|
||||
INLINE_ONLY EXTERN inline UInt
|
||||
INLINE_ONLY EXTERN inline BITS32
|
||||
EXO_ADDRESS_TO_OFFSET(struct index_t *it, CELL* ptr)
|
||||
{
|
||||
return ptr-it->links;
|
||||
return 1+(ptr-it->cls);
|
||||
}
|
||||
|
||||
INLINE_ONLY EXTERN inline CELL *EXO_OFFSET_TO_ADDRESS(struct index_t *it, UInt off);
|
||||
|
||||
INLINE_ONLY EXTERN inline CELL *
|
||||
EXO_OFFSET_TO_ADDRESS(struct index_t *it, UInt off)
|
||||
EXO_OFFSET_TO_ADDRESS(struct index_t *it, BITS32 off)
|
||||
{
|
||||
if (off == 0L)
|
||||
return NULL;
|
||||
return (it->cls-1)+off;
|
||||
}
|
||||
|
||||
INLINE_ONLY EXTERN inline BITS32 ADDRESS_TO_LINK(struct index_t *it, CELL *ptr);
|
||||
|
||||
INLINE_ONLY EXTERN inline BITS32
|
||||
ADDRESS_TO_LINK(struct index_t *it, CELL* ptr)
|
||||
{
|
||||
return ptr-it->links;
|
||||
}
|
||||
|
||||
INLINE_ONLY EXTERN inline CELL *LINK_TO_ADDRESS(struct index_t *it, BITS32 off);
|
||||
|
||||
INLINE_ONLY EXTERN inline CELL *
|
||||
LINK_TO_ADDRESS(struct index_t *it, BITS32 off)
|
||||
{
|
||||
return it->links+off;
|
||||
}
|
||||
@ -323,8 +341,10 @@ same_lu_block(yamop **paddr, yamop *p)
|
||||
}
|
||||
#endif
|
||||
|
||||
#define Yap_MkStaticRefTerm(cp) __Yap_MkStaticRefTerm((cp) PASS_REGS)
|
||||
|
||||
static inline Term
|
||||
Yap_MkStaticRefTerm(StaticClause *cp)
|
||||
__Yap_MkStaticRefTerm(StaticClause *cp USES_REGS)
|
||||
{
|
||||
Term t[1];
|
||||
t[0] = MkIntegerTerm((Int)cp);
|
||||
@ -337,8 +357,10 @@ Yap_ClauseFromTerm(Term t)
|
||||
return (StaticClause *)IntegerOfTerm(ArgOfTerm(1,t));
|
||||
}
|
||||
|
||||
#define Yap_MkMegaRefTerm(ap, ipc) __Yap_MkMegaRefTerm((ap), (ipc) PASS_REGS)
|
||||
|
||||
static inline Term
|
||||
Yap_MkMegaRefTerm(PredEntry *ap,yamop *ipc)
|
||||
__Yap_MkMegaRefTerm(PredEntry *ap,yamop *ipc USES_REGS)
|
||||
{
|
||||
Term t[2];
|
||||
t[0] = MkIntegerTerm((Int)ap);
|
||||
|
8
H/eval.h
8
H/eval.h
@ -314,12 +314,16 @@ size_t STD_PROTO(Yap_gmp_to_size,(Term, int));
|
||||
|
||||
int STD_PROTO(Yap_term_to_existing_big,(Term, MP_INT *));
|
||||
int STD_PROTO(Yap_term_to_existing_rat,(Term, MP_RAT *));
|
||||
|
||||
void Yap_gmp_set_bit(Int i, Term t);
|
||||
#endif
|
||||
|
||||
INLINE_ONLY inline EXTERN Term Yap_Mk64IntegerTerm(YAP_LONG_LONG);
|
||||
#define Yap_Mk64IntegerTerm(i) __Yap_Mk64IntegerTerm((i) PASS_REGS)
|
||||
|
||||
INLINE_ONLY inline EXTERN Term __Yap_Mk64IntegerTerm(YAP_LONG_LONG USES_REGS);
|
||||
|
||||
INLINE_ONLY inline EXTERN Term
|
||||
Yap_Mk64IntegerTerm(YAP_LONG_LONG i)
|
||||
__Yap_Mk64IntegerTerm(YAP_LONG_LONG i USES_REGS)
|
||||
{
|
||||
if (i <= Int_MAX && i >= Int_MIN) {
|
||||
return MkIntegerTerm((Int)i);
|
||||
|
@ -10,6 +10,7 @@ INIT_SEQ_STRING(size_t n)
|
||||
|
||||
static inline Word
|
||||
EXTEND_SEQ_CODES(Word ptr, int c) {
|
||||
CACHE_REGS
|
||||
ptr[0] = MkIntegerTerm(c);
|
||||
ptr[1] = AbsPair(ptr+2);
|
||||
|
||||
|
@ -13,7 +13,7 @@ static void readswap8(double *buf);
|
||||
static byte get_hostbyteorder(void);
|
||||
static byte get_inbyteorder(void);
|
||||
static uint32 get_wkbType(void);
|
||||
static Term get_point(char *functor);
|
||||
static Term get_point(char *functor USES_REGS);
|
||||
static Term get_linestring(char *functor);
|
||||
static Term get_polygon(char *functor);
|
||||
static Term get_geometry(uint32 type);
|
||||
@ -150,7 +150,7 @@ static void readswap8(double *buf) {
|
||||
cursor += 8;
|
||||
}
|
||||
|
||||
static Term get_point(char *func){
|
||||
static Term get_point(char *func USES_REGS){
|
||||
Term args[2];
|
||||
Functor functor;
|
||||
double d;
|
||||
@ -188,7 +188,7 @@ static Term get_linestring(char *func){
|
||||
c_list = (Term *) calloc(sizeof(Term),n);
|
||||
|
||||
for ( i = 0; i < n; i++) {
|
||||
c_list[i] = get_point(NULL);
|
||||
c_list[i] = get_point(NULL PASS_REGS);
|
||||
}
|
||||
|
||||
list = MkAtomTerm(Yap_LookupAtom("[]"));
|
||||
@ -241,15 +241,14 @@ static Term get_geometry(uint32 type){
|
||||
|
||||
switch(type) {
|
||||
case WKBPOINT:
|
||||
return get_point("point");
|
||||
return get_point("point" PASS_REGS);
|
||||
case WKBLINESTRING:
|
||||
return get_linestring("linestring");
|
||||
case WKBPOLYGON:
|
||||
return get_polygon("polygon");
|
||||
case WKBMULTIPOINT:
|
||||
{
|
||||
byte b;
|
||||
uint32 n, u;
|
||||
uint32 n;
|
||||
int i;
|
||||
Functor functor;
|
||||
Term *c_list;
|
||||
@ -264,10 +263,10 @@ static Term get_geometry(uint32 type){
|
||||
|
||||
for ( i = 0; i < n; i++ ) {
|
||||
/* read (and ignore) the byteorder and type */
|
||||
b = get_inbyteorder();
|
||||
u = get_wkbType();
|
||||
get_inbyteorder();
|
||||
get_wkbType();
|
||||
|
||||
c_list[i] = get_point(NULL);
|
||||
c_list[i] = get_point(NULL PASS_REGS);
|
||||
}
|
||||
|
||||
list = MkAtomTerm(Yap_LookupAtom("[]"));
|
||||
@ -282,8 +281,7 @@ static Term get_geometry(uint32 type){
|
||||
}
|
||||
case WKBMULTILINESTRING:
|
||||
{
|
||||
byte b;
|
||||
uint32 n, u;
|
||||
uint32 n;
|
||||
int i;
|
||||
Functor functor;
|
||||
Term *c_list;
|
||||
@ -298,8 +296,8 @@ static Term get_geometry(uint32 type){
|
||||
|
||||
for ( i = 0; i < n; i++ ) {
|
||||
/* read (and ignore) the byteorder and type */
|
||||
b = get_inbyteorder();
|
||||
u = get_wkbType();
|
||||
get_inbyteorder();
|
||||
get_wkbType();
|
||||
|
||||
c_list[i] = get_linestring(NULL);
|
||||
}
|
||||
@ -316,8 +314,7 @@ static Term get_geometry(uint32 type){
|
||||
}
|
||||
case WKBMULTIPOLYGON:
|
||||
{
|
||||
byte b;
|
||||
uint32 n, u;
|
||||
uint32 n;
|
||||
int i;
|
||||
Functor functor;
|
||||
Term *c_list;
|
||||
@ -332,8 +329,8 @@ static Term get_geometry(uint32 type){
|
||||
|
||||
for ( i = 0; i < n; i++ ) {
|
||||
/* read (and ignore) the byteorder and type */
|
||||
b = get_inbyteorder();
|
||||
u = get_wkbType();
|
||||
get_inbyteorder();
|
||||
get_wkbType();
|
||||
|
||||
c_list[i] = get_polygon(NULL);
|
||||
}
|
||||
@ -350,7 +347,6 @@ static Term get_geometry(uint32 type){
|
||||
}
|
||||
case WKBGEOMETRYCOLLECTION:
|
||||
{
|
||||
byte b;
|
||||
uint32 n;
|
||||
int i;
|
||||
Functor functor;
|
||||
@ -365,7 +361,7 @@ static Term get_geometry(uint32 type){
|
||||
|
||||
|
||||
for ( i = 0; i < n; i++ ) {
|
||||
b = get_inbyteorder();
|
||||
get_inbyteorder();
|
||||
c_list[i] = get_geometry(get_wkbType());
|
||||
}
|
||||
|
||||
|
@ -22,39 +22,39 @@
|
||||
#include "opt.mavar.h"
|
||||
|
||||
#ifdef THREADS
|
||||
static inline void **get_insert_thread_bucket(void **, lockvar *);
|
||||
static inline void **get_thread_bucket(void **);
|
||||
static inline void **__get_insert_thread_bucket(void **, lockvar * USES_REGS);
|
||||
static inline void **__get_thread_bucket(void ** USES_REGS);
|
||||
static inline void abolish_thread_buckets(void **);
|
||||
#endif /* THREADS */
|
||||
static inline sg_node_ptr get_insert_subgoal_trie(tab_ent_ptr USES_REGS);
|
||||
static inline sg_node_ptr get_subgoal_trie(tab_ent_ptr);
|
||||
static inline sg_node_ptr __get_subgoal_trie(tab_ent_ptr USES_REGS);
|
||||
static inline sg_node_ptr get_subgoal_trie_for_abolish(tab_ent_ptr USES_REGS);
|
||||
static inline sg_fr_ptr *get_insert_subgoal_frame_addr(sg_node_ptr USES_REGS);
|
||||
static inline sg_fr_ptr get_subgoal_frame(sg_node_ptr);
|
||||
static inline sg_fr_ptr get_subgoal_frame_for_abolish(sg_node_ptr USES_REGS);
|
||||
#ifdef THREADS_FULL_SHARING
|
||||
static inline void SgFr_batched_cached_answers_check_insert(sg_fr_ptr, ans_node_ptr);
|
||||
static inline void __SgFr_batched_cached_answers_check_insert(sg_fr_ptr, ans_node_ptr USES_REGS);
|
||||
static inline int SgFr_batched_cached_answers_check_remove(sg_fr_ptr, ans_node_ptr);
|
||||
#endif /* THREADS_FULL_SHARING */
|
||||
#ifdef THREADS_CONSUMER_SHARING
|
||||
static inline void add_to_tdv(int, int);
|
||||
static inline void check_for_deadlock(sg_fr_ptr);
|
||||
static inline sg_fr_ptr deadlock_detection(sg_fr_ptr);
|
||||
static inline void __add_to_tdv(int, int USES_REGS);
|
||||
static inline void __check_for_deadlock(sg_fr_ptr USES_REGS);
|
||||
static inline sg_fr_ptr __deadlock_detection(sg_fr_ptr USES_REGS);
|
||||
#endif /* THREADS_CONSUMER_SHARING */
|
||||
static inline Int freeze_current_cp(void);
|
||||
static inline void wake_frozen_cp(Int);
|
||||
static inline void abolish_frozen_cps_until(Int);
|
||||
static inline void abolish_frozen_cps_all(void);
|
||||
static inline void adjust_freeze_registers(void);
|
||||
static inline void mark_as_completed(sg_fr_ptr);
|
||||
static inline void unbind_variables(tr_fr_ptr, tr_fr_ptr);
|
||||
static inline void rebind_variables(tr_fr_ptr, tr_fr_ptr);
|
||||
static inline void restore_bindings(tr_fr_ptr, tr_fr_ptr);
|
||||
static inline CELL *expand_auxiliary_stack(CELL *);
|
||||
static inline void abolish_incomplete_subgoals(choiceptr);
|
||||
static inline Int __freeze_current_cp( USES_REGS1 );
|
||||
static inline void __wake_frozen_cp(Int USES_REGS);
|
||||
static inline void __abolish_frozen_cps_until(Int USES_REGS);
|
||||
static inline void __abolish_frozen_cps_all( USES_REGS1 );
|
||||
static inline void __adjust_freeze_registers( USES_REGS1 );
|
||||
static inline void __mark_as_completed(sg_fr_ptr USES_REGS);
|
||||
static inline void __unbind_variables(tr_fr_ptr, tr_fr_ptr USES_REGS);
|
||||
static inline void __rebind_variables(tr_fr_ptr, tr_fr_ptr USES_REGS);
|
||||
static inline void __restore_bindings(tr_fr_ptr, tr_fr_ptr USES_REGS);
|
||||
static inline CELL *__expand_auxiliary_stack(CELL * USES_REGS);
|
||||
static inline void __abolish_incomplete_subgoals(choiceptr USES_REGS);
|
||||
#ifdef YAPOR
|
||||
static inline void pruning_over_tabling_data_structures(void);
|
||||
static inline void collect_suspension_frames(or_fr_ptr);
|
||||
static inline void __collect_suspension_frames(or_fr_ptr USES_REGS);
|
||||
#ifdef TIMESTAMP_CHECK
|
||||
static inline susp_fr_ptr suspension_frame_to_resume(or_fr_ptr, long);
|
||||
#else
|
||||
@ -658,8 +658,9 @@ static inline tg_sol_fr_ptr CUT_prune_tg_solution_frames(tg_sol_fr_ptr, int);
|
||||
******************************/
|
||||
|
||||
#ifdef THREADS
|
||||
static inline void **get_insert_thread_bucket(void **buckets, lockvar *buckets_lock) {
|
||||
CACHE_REGS
|
||||
#define get_insert_thread_bucket(b, bl) __get_insert_thread_bucket((b), (bl) PASS_REGS)
|
||||
|
||||
static inline void **__get_insert_thread_bucket(void **buckets, lockvar *buckets_lock USES_REGS) {
|
||||
|
||||
/* direct bucket */
|
||||
if (worker_id < THREADS_DIRECT_BUCKETS)
|
||||
@ -678,9 +679,9 @@ static inline void **get_insert_thread_bucket(void **buckets, lockvar *buckets_l
|
||||
return *buckets + (worker_id - THREADS_DIRECT_BUCKETS) % THREADS_DIRECT_BUCKETS;
|
||||
}
|
||||
|
||||
#define get_thread_bucket(b) __get_thread_bucket((b) PASS_REGS)
|
||||
|
||||
static inline void **get_thread_bucket(void **buckets) {
|
||||
CACHE_REGS
|
||||
static inline void **__get_thread_bucket(void **buckets USES_REGS) {
|
||||
|
||||
/* direct bucket */
|
||||
if (worker_id < THREADS_DIRECT_BUCKETS)
|
||||
@ -729,8 +730,9 @@ static inline sg_node_ptr get_insert_subgoal_trie(tab_ent_ptr tab_ent USES_REGS)
|
||||
#endif /* THREADS_NO_SHARING */
|
||||
}
|
||||
|
||||
#define get_subgoal_trie(te) __get_subgoal_trie((te) PASS_REGS)
|
||||
|
||||
static inline sg_node_ptr get_subgoal_trie(tab_ent_ptr tab_ent) {
|
||||
static inline sg_node_ptr __get_subgoal_trie(tab_ent_ptr tab_ent USES_REGS) {
|
||||
#ifdef THREADS_NO_SHARING
|
||||
sg_node_ptr *sg_node_addr = (sg_node_ptr *) get_thread_bucket((void **) &TabEnt_subgoal_trie(tab_ent));
|
||||
return *sg_node_addr;
|
||||
@ -825,8 +827,8 @@ static inline sg_fr_ptr get_subgoal_frame_for_abolish(sg_node_ptr sg_node USES_R
|
||||
|
||||
|
||||
#ifdef THREADS_FULL_SHARING
|
||||
static inline void SgFr_batched_cached_answers_check_insert(sg_fr_ptr sg_fr, ans_node_ptr ans_node) {
|
||||
CACHE_REGS
|
||||
#define SgFr_batched_cached_answers_check_insert(s, a) __SgFr_batched_cached_answers_check_insert((s), (a) PASS_REGS)
|
||||
static inline void SgFr_batched_cached_answers_check_insert(sg_fr_ptr sg_fr, ans_node_ptr ans_node USES_REGS) {
|
||||
|
||||
if (SgFr_batched_last_answer(sg_fr) == NULL)
|
||||
SgFr_batched_last_answer(sg_fr) = SgFr_first_answer(sg_fr);
|
||||
@ -854,8 +856,9 @@ static inline void SgFr_batched_cached_answers_check_insert(sg_fr_ptr sg_fr, ans
|
||||
return;
|
||||
}
|
||||
|
||||
static inline int SgFr_batched_cached_answers_check_remove(sg_fr_ptr sg_fr, ans_node_ptr ans_node) {
|
||||
CACHE_REGS
|
||||
#define SgFr_batched_cached_answers_check_remove(s, a) __SgFr_batched_cached_answers_check_remove((s), (a) PASS_REGS)
|
||||
|
||||
static inline int __SgFr_batched_cached_answers_check_remove(sg_fr_ptr sg_fr, ans_node_ptr ans_node USES_REgS) {
|
||||
struct answer_ref_node *local_uncons_ans;
|
||||
|
||||
local_uncons_ans = SgFr_batched_cached_answers(sg_fr) ;
|
||||
@ -884,10 +887,10 @@ static inline int SgFr_batched_cached_answers_check_remove(sg_fr_ptr sg_fr, ans_
|
||||
|
||||
|
||||
#ifdef THREADS_CONSUMER_SHARING
|
||||
static inline void add_to_tdv(int wid, int wid_dep) {
|
||||
#ifdef OUTPUT_THREADS_TABLING
|
||||
CACHE_REGS
|
||||
#endif /* OUTPUT_THREADS_TABLING */
|
||||
|
||||
#define add_to_tdv(w, wd) __add_to_tdv((w), (wd) PASS_REGS)
|
||||
|
||||
static inline void __add_to_tdv(int wid, int wid_dep USES_REGS) {
|
||||
// thread wid next of thread wid_dep
|
||||
/* check before insert */
|
||||
int c_wid = ThDepFr_next(GLOBAL_th_dep_fr(wid));
|
||||
@ -927,9 +930,9 @@ static inline void add_to_tdv(int wid, int wid_dep) {
|
||||
return;
|
||||
}
|
||||
|
||||
#define check_for_deadlock(s) __check_for_deadlock((s) PASS_REGS)
|
||||
|
||||
static inline void check_for_deadlock(sg_fr_ptr sg_fr) {
|
||||
CACHE_REGS
|
||||
static inline void __check_for_deadlock(sg_fr_ptr sg_fr USES_REGS) {
|
||||
sg_fr_ptr local_sg_fr = deadlock_detection(sg_fr);
|
||||
|
||||
if (local_sg_fr){
|
||||
@ -942,9 +945,9 @@ static inline void check_for_deadlock(sg_fr_ptr sg_fr) {
|
||||
return;
|
||||
}
|
||||
|
||||
#define deadlock_detection(s) __deadlock_detection((s) PASS_REGS)
|
||||
|
||||
static inline sg_fr_ptr deadlock_detection(sg_fr_ptr sg_fr) {
|
||||
CACHE_REGS
|
||||
static inline sg_fr_ptr __deadlock_detection(sg_fr_ptr sg_fr USES_REGS) {
|
||||
sg_fr_ptr remote_sg_fr = REMOTE_top_sg_fr(SgFr_gen_worker(sg_fr));
|
||||
|
||||
while( SgFr_sg_ent(remote_sg_fr) != SgFr_sg_ent(sg_fr)){
|
||||
@ -977,9 +980,9 @@ static inline sg_fr_ptr deadlock_detection(sg_fr_ptr sg_fr) {
|
||||
}
|
||||
#endif /* THREADS_CONSUMER_SHARING */
|
||||
|
||||
#define freeze_current_cp() __freeze_current_cp( PASS_REGS1 )
|
||||
|
||||
static inline Int freeze_current_cp(void) {
|
||||
CACHE_REGS
|
||||
static inline Int __freeze_current_cp(USES_REGS1) {
|
||||
choiceptr freeze_cp = B;
|
||||
|
||||
B_FZ = freeze_cp;
|
||||
@ -991,8 +994,11 @@ static inline Int freeze_current_cp(void) {
|
||||
}
|
||||
|
||||
|
||||
static inline void wake_frozen_cp(Int frozen_offset) {
|
||||
CACHE_REGS
|
||||
#define wake_frozen_cp(f) __wake_frozen_cp((f) PASS_REGS)
|
||||
|
||||
#define restore_bindings(u, r) __restore_bindings((u), (r) PASS_REGS)
|
||||
|
||||
static inline void __wake_frozen_cp(Int frozen_offset USES_REGS) {
|
||||
choiceptr frozen_cp = (choiceptr)(LOCAL_LocalBase - frozen_offset);
|
||||
|
||||
restore_bindings(TR, frozen_cp->cp_tr);
|
||||
@ -1003,8 +1009,9 @@ static inline void wake_frozen_cp(Int frozen_offset) {
|
||||
}
|
||||
|
||||
|
||||
static inline void abolish_frozen_cps_until(Int frozen_offset) {
|
||||
CACHE_REGS
|
||||
#define abolish_frozen_cps_until(f) __abolish_frozen_cps_until((f) PASS_REGS )
|
||||
|
||||
static inline void __abolish_frozen_cps_until(Int frozen_offset USES_REGS) {
|
||||
choiceptr frozen_cp = (choiceptr)(LOCAL_LocalBase - frozen_offset);
|
||||
|
||||
B_FZ = frozen_cp;
|
||||
@ -1013,28 +1020,28 @@ static inline void abolish_frozen_cps_until(Int frozen_offset) {
|
||||
return;
|
||||
}
|
||||
|
||||
#define abolish_frozen_cps_all() __abolish_frozen_cps_all( PASS_REGS1 )
|
||||
|
||||
static inline void abolish_frozen_cps_all(void) {
|
||||
CACHE_REGS
|
||||
static inline void __abolish_frozen_cps_all( USES_REGS1 ) {
|
||||
B_FZ = (choiceptr) LOCAL_LocalBase;
|
||||
H_FZ = (CELL *) LOCAL_GlobalBase;
|
||||
TR_FZ = (tr_fr_ptr) LOCAL_TrailBase;
|
||||
return;
|
||||
}
|
||||
|
||||
#define adjust_freeze_registers() __adjust_freeze_registers( PASS_REGS1 )
|
||||
|
||||
static inline void adjust_freeze_registers(void) {
|
||||
CACHE_REGS
|
||||
static inline void __adjust_freeze_registers( USES_REGS1 ) {
|
||||
B_FZ = DepFr_cons_cp(LOCAL_top_dep_fr);
|
||||
H_FZ = B_FZ->cp_h;
|
||||
TR_FZ = B_FZ->cp_tr;
|
||||
return;
|
||||
}
|
||||
|
||||
#define mark_as_completed(sg) __mark_as_completed((sg) PASS_REGS )
|
||||
|
||||
static inline void mark_as_completed(sg_fr_ptr sg_fr) {
|
||||
static inline void __mark_as_completed(sg_fr_ptr sg_fr USES_REGS) {
|
||||
#if defined(MODE_DIRECTED_TABLING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING)
|
||||
CACHE_REGS
|
||||
#endif /* MODE_DIRECTED_TABLING && !THREADS_FULL_SHARING && !THREADS_CONSUMER_SHARING */
|
||||
|
||||
LOCK_SG_FR(sg_fr);
|
||||
@ -1079,9 +1086,9 @@ static inline void mark_as_completed(sg_fr_ptr sg_fr) {
|
||||
return;
|
||||
}
|
||||
|
||||
#define unbind_variables(u, e) __unbind_variables((u), (e) PASS_REGS)
|
||||
|
||||
static inline void unbind_variables(tr_fr_ptr unbind_tr, tr_fr_ptr end_tr) {
|
||||
CACHE_REGS
|
||||
static inline void __unbind_variables(tr_fr_ptr unbind_tr, tr_fr_ptr end_tr USES_REGS) {
|
||||
TABLING_ERROR_CHECKING(unbind_variables, unbind_tr < end_tr);
|
||||
/* unbind loop */
|
||||
while (unbind_tr != end_tr) {
|
||||
@ -1111,8 +1118,9 @@ static inline void unbind_variables(tr_fr_ptr unbind_tr, tr_fr_ptr end_tr) {
|
||||
}
|
||||
|
||||
|
||||
static inline void rebind_variables(tr_fr_ptr rebind_tr, tr_fr_ptr end_tr) {
|
||||
CACHE_REGS
|
||||
#define rebind_variables(u, e) __rebind_variables(u, e PASS_REGS)
|
||||
|
||||
static inline void __rebind_variables(tr_fr_ptr rebind_tr, tr_fr_ptr end_tr USES_REGS) {
|
||||
TABLING_ERROR_CHECKING(rebind_variables, rebind_tr < end_tr);
|
||||
/* rebind loop */
|
||||
Yap_NEW_MAHASH((ma_h_inner_struct *)H PASS_REGS);
|
||||
@ -1144,9 +1152,7 @@ static inline void rebind_variables(tr_fr_ptr rebind_tr, tr_fr_ptr end_tr) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
static inline void restore_bindings(tr_fr_ptr unbind_tr, tr_fr_ptr rebind_tr) {
|
||||
CACHE_REGS
|
||||
static inline void __restore_bindings(tr_fr_ptr unbind_tr, tr_fr_ptr rebind_tr USES_REGS) {
|
||||
CELL ref;
|
||||
tr_fr_ptr end_tr;
|
||||
|
||||
@ -1218,9 +1224,9 @@ static inline void restore_bindings(tr_fr_ptr unbind_tr, tr_fr_ptr rebind_tr) {
|
||||
return;
|
||||
}
|
||||
|
||||
#define expand_auxiliary_stack(s) __expand_auxiliary_stack((s) PASS_REGS)
|
||||
|
||||
static inline CELL *expand_auxiliary_stack(CELL *stack) {
|
||||
CACHE_REGS
|
||||
static inline CELL *__expand_auxiliary_stack(CELL *stack USES_REGS) {
|
||||
void *old_top = LOCAL_TrailTop;
|
||||
INFORMATION_MESSAGE("Expanding trail in 64 Kbytes");
|
||||
if (! Yap_growtrail(K64, TRUE)) { /* TRUE means 'contiguous_only' */
|
||||
@ -1234,9 +1240,10 @@ static inline CELL *expand_auxiliary_stack(CELL *stack) {
|
||||
}
|
||||
}
|
||||
|
||||
#define abolish_incomplete_subgoals(p) __abolish_incomplete_subgoals((p) PASS_REGS)
|
||||
|
||||
static inline void abolish_incomplete_subgoals(choiceptr prune_cp) {
|
||||
CACHE_REGS
|
||||
|
||||
static inline void __abolish_incomplete_subgoals(choiceptr prune_cp USES_REGS) {
|
||||
|
||||
#ifdef YAPOR
|
||||
if (EQUAL_OR_YOUNGER_CP(GetOrFr_node(LOCAL_top_susp_or_fr), prune_cp))
|
||||
@ -1389,8 +1396,9 @@ static inline void pruning_over_tabling_data_structures(void) {
|
||||
}
|
||||
|
||||
|
||||
static inline void collect_suspension_frames(or_fr_ptr or_fr) {
|
||||
CACHE_REGS
|
||||
#define collect_suspension_frames(o) __collect_suspension_frames((o) PASS_REGS)
|
||||
|
||||
static inline void __collect_suspension_frames(or_fr_ptr or_fr USES_REGS) {
|
||||
int depth;
|
||||
or_fr_ptr *susp_ptr;
|
||||
|
||||
|
10
configure
vendored
10
configure
vendored
@ -1539,7 +1539,7 @@ Optional Packages:
|
||||
--with-java=JAVA_HOME use Java instalation in JAVA_HOME
|
||||
--with-readline=DIR use GNU Readline Library in DIR
|
||||
--with-matlab=DIR use MATLAB package in DIR
|
||||
--with-mpi=DIR use MPI library in DIR
|
||||
--with-mpi=DIR use LAM/MPI library in DIR
|
||||
--with-mpe=DIR use MPE library in DIR
|
||||
--with-lam=DIR use LAM MPI library in DIR
|
||||
--with-heap-space=space default heap size in Kbytes
|
||||
@ -4860,16 +4860,16 @@ fi
|
||||
# Check whether --with-mpi was given.
|
||||
if test "${with_mpi+set}" = set; then :
|
||||
withval=$with_mpi; if test "$withval" = yes; then
|
||||
yap_cv_mpi=yes
|
||||
yap_cv_lam=yes
|
||||
elif test "$withval" = no; then
|
||||
yap_cv_mpi=no
|
||||
yap_cv_lam=no
|
||||
else
|
||||
yap_cv_mpi=$with_mpi
|
||||
yap_cv_lam=$with_mpi
|
||||
LDFLAGS="$LDFLAGS -L${yap_cv_mpi}/lib"
|
||||
CPPFLAGS="$CPPFLAGS -I${yap_cv_mpi}/include"
|
||||
fi
|
||||
else
|
||||
yap_cv_mpi=no
|
||||
yap_cv_lam=no
|
||||
fi
|
||||
|
||||
|
||||
|
10
configure.in
10
configure.in
@ -360,18 +360,18 @@ AC_ARG_WITH(matlab,
|
||||
[yap_cv_matlab=no])
|
||||
|
||||
AC_ARG_WITH(mpi,
|
||||
[ --with-mpi[=DIR] use MPI library in DIR],
|
||||
[ --with-mpi[=DIR] use LAM/MPI library in DIR],
|
||||
if test "$withval" = yes; then
|
||||
dnl handle UBUNTU systems
|
||||
yap_cv_mpi=yes
|
||||
yap_cv_lam=yes
|
||||
elif test "$withval" = no; then
|
||||
yap_cv_mpi=no
|
||||
yap_cv_lam=no
|
||||
else
|
||||
yap_cv_mpi=$with_mpi
|
||||
yap_cv_lam=$with_mpi
|
||||
LDFLAGS="$LDFLAGS -L${yap_cv_mpi}/lib"
|
||||
CPPFLAGS="$CPPFLAGS -I${yap_cv_mpi}/include"
|
||||
fi,
|
||||
[yap_cv_mpi=no])
|
||||
[yap_cv_lam=no])
|
||||
|
||||
|
||||
AC_ARG_WITH(mpe,
|
||||
|
@ -12840,6 +12840,13 @@ vertices in @var{Vs} map to vertices in @var{NewVs}.
|
||||
The path @var{Path} is a path starting at vertex @var{Vertex} in graph
|
||||
@var{Graph}.
|
||||
|
||||
@item dgraph_path(+@var{Vertex}, +@var{Vertex1}, +@var{Graph}, ?@var{Path})
|
||||
@findex dgraph_path/3
|
||||
@snindex dgraph_path/3
|
||||
@cnindex dgraph_path/3
|
||||
The path @var{Path} is a path starting at vertex @var{Vertex} in graph
|
||||
@var{Graph} and ending at path @var{Vertex2}.
|
||||
|
||||
@item dgraph_reachable(+@var{Vertex}, +@var{Graph}, ?@var{Edges})
|
||||
@findex dgraph_path/3
|
||||
@snindex dgraph_path/3
|
||||
|
@ -32,6 +32,7 @@
|
||||
dgraph_min_paths/3,
|
||||
dgraph_isomorphic/4,
|
||||
dgraph_path/3,
|
||||
dgraph_path/4,
|
||||
dgraph_leaves/2,
|
||||
dgraph_reachable/3
|
||||
]).
|
||||
@ -40,7 +41,8 @@
|
||||
[rb_new/1 as dgraph_new]).
|
||||
|
||||
:- use_module(library(rbtrees),
|
||||
[rb_empty/1,
|
||||
[rb_new/1,
|
||||
rb_empty/1,
|
||||
rb_lookup/3,
|
||||
rb_apply/4,
|
||||
rb_insert/4,
|
||||
@ -361,10 +363,21 @@ dgraph_min_paths(V1, Graph, Paths) :-
|
||||
dgraph_to_wdgraph(Graph, WGraph),
|
||||
wdgraph_min_paths(V1, WGraph, Paths).
|
||||
|
||||
dgraph_path(V, G, [V|P]) :-
|
||||
rb_lookup(V, Children, G),
|
||||
ord_del_element(Children, V, Ch),
|
||||
do_path(Ch, G, [V], P).
|
||||
dgraph_path(V1, V2, Graph, Path) :-
|
||||
rb_new(E0),
|
||||
rb_lookup(V1, Children, Graph),
|
||||
dgraph_path_children(Children, V2, E0, Graph, Path).
|
||||
|
||||
dgraph_path_children([V1|_], V2, _E1, _Graph, []) :- V1 == V2.
|
||||
dgraph_path_children([V1|_], V2, E1, Graph, [V1|Path]) :-
|
||||
V2 \== V1,
|
||||
\+ rb_lookup(V1, _, E0),
|
||||
rb_insert(E0, V2, [], E1),
|
||||
rb_lookup(V1, Children, Graph),
|
||||
dgraph_path_children(Children, V2, E1, Graph, Path).
|
||||
dgraph_path_children([_|Children], V2, E1, Graph, Path) :-
|
||||
dgraph_path_children(Children, V2, E1, Graph, Path).
|
||||
|
||||
|
||||
do_path([], _, _, []).
|
||||
do_path([C|Children], G, SoFar, Path) :-
|
||||
@ -378,6 +391,11 @@ do_children([V|_], G, SoFar, [V|Path]) :-
|
||||
do_children([_|Children], G, SoFar, Path) :-
|
||||
do_children(Children, G, SoFar, Path).
|
||||
|
||||
dgraph_path(V, G, [V|P]) :-
|
||||
rb_lookup(V, Children, G),
|
||||
ord_del_element(Children, V, Ch),
|
||||
do_path(Ch, G, [V], P).
|
||||
|
||||
|
||||
dgraph_isomorphic(Vs, Vs2, G1, G2) :-
|
||||
rb_new(Map0),
|
||||
|
@ -37,8 +37,8 @@ RANLIB=@RANLIB@
|
||||
srcdir=@srcdir@
|
||||
SO=@SO@
|
||||
CWD=$(PWD)
|
||||
MPILDF=`$(MPI_CC) -showme|sed "s/[^ ]*//"|sed "s/-pt/-lpt/"`
|
||||
MPICF=`$(MPI_CC) -showme| cut -d " " -f 2`
|
||||
MPILDF=`$(MPI_CC) --showme:link`
|
||||
MPICF=`$(MPI_CC) --showme:compile`
|
||||
#
|
||||
|
||||
OBJS=yap_mpi.o hash.o prologterms2c.o
|
||||
|
@ -32,23 +32,23 @@ static char *rcsid = "$Header: /Users/vitor/Yap/yap-cvsbackup/library/mpi/mpi.c,
|
||||
|
||||
void STD_PROTO(YAP_Write, (Term, void (*)(int), int));
|
||||
|
||||
STATIC_PROTO (Int p_mpi_open, (void));
|
||||
STATIC_PROTO (Int p_mpi_close, (void));
|
||||
STATIC_PROTO (Int p_mpi_send, (void));
|
||||
STATIC_PROTO (Int p_mpi_receive, (void));
|
||||
STATIC_PROTO (Int p_mpi_bcast3, (void));
|
||||
STATIC_PROTO (Int p_mpi_bcast2, (void));
|
||||
STATIC_PROTO (Int p_mpi_barrier, (void));
|
||||
STATIC_PROTO (Int p_mpi_open, ( USES_REGS1 ));
|
||||
STATIC_PROTO (Int p_mpi_close, ( USES_REGS1 ));
|
||||
STATIC_PROTO (Int p_mpi_send, ( USES_REGS1 ));
|
||||
STATIC_PROTO (Int p_mpi_receive, ( USES_REGS1 ));
|
||||
STATIC_PROTO (Int p_mpi_bcast3, ( USES_REGS1 ));
|
||||
STATIC_PROTO (Int p_mpi_bcast2, ( USES_REGS1 ));
|
||||
STATIC_PROTO (Int p_mpi_barrier, ( USES_REGS1 ));
|
||||
|
||||
|
||||
/*
|
||||
* Auxiliary Data
|
||||
*/
|
||||
|
||||
static Int rank, numprocs, namelen;
|
||||
static int rank, numprocs, namelen;
|
||||
static char processor_name[MPI_MAX_PROCESSOR_NAME];
|
||||
|
||||
static Int mpi_argc;
|
||||
static int mpi_argc;
|
||||
static char **mpi_argv;
|
||||
|
||||
/* this should eventually be moved to config.h */
|
||||
@ -111,7 +111,7 @@ mpi_putc(Int ch)
|
||||
|
||||
|
||||
static Int
|
||||
p_mpi_open(void) /* mpi_open(?rank, ?num_procs, ?proc_name) */
|
||||
p_mpi_open( USES_REGS1 ) /* mpi_open(?rank, ?num_procs, ?proc_name) */
|
||||
{
|
||||
Term t_rank = Deref(ARG1), t_numprocs = Deref(ARG2), t_procname = Deref(ARG3);
|
||||
Int retv;
|
||||
@ -156,7 +156,7 @@ Yap exit(FAILURE), whereas in Yap/LAM mpi_open/3 simply fails.
|
||||
|
||||
|
||||
static Int /* mpi_close */
|
||||
p_mpi_close()
|
||||
p_mpi_close( USES_REGS1 )
|
||||
{
|
||||
MPI_Finalize();
|
||||
return TRUE;
|
||||
@ -164,7 +164,7 @@ p_mpi_close()
|
||||
|
||||
|
||||
static Int
|
||||
p_mpi_send() /* mpi_send(+data, +destination, +tag) */
|
||||
p_mpi_send( USES_REGS1 ) /* mpi_send(+data, +destination, +tag) */
|
||||
{
|
||||
Term t_data = Deref(ARG1), t_dest = Deref(ARG2), t_tag = Deref(ARG3);
|
||||
int tag, dest, retv;
|
||||
@ -216,7 +216,7 @@ p_mpi_send() /* mpi_send(+data, +destination, +tag) */
|
||||
|
||||
|
||||
static Int
|
||||
p_mpi_receive() /* mpi_receive(-data, ?orig, ?tag) */
|
||||
p_mpi_receive( USES_REGS1 ) /* mpi_receive(-data, ?orig, ?tag) */
|
||||
{
|
||||
Term t, t_data = Deref(ARG1), t_orig = Deref(ARG2), t_tag = Deref(ARG3);
|
||||
int tag, orig, retv;
|
||||
@ -305,7 +305,7 @@ p_mpi_receive() /* mpi_receive(-data, ?orig, ?tag) */
|
||||
|
||||
|
||||
static Int
|
||||
p_mpi_bcast3() /* mpi_bcast( ?data, +root, +max_size ) */
|
||||
p_mpi_bcast3( USES_REGS1 ) /* mpi_bcast( ?data, +root, +max_size ) */
|
||||
{
|
||||
Term t_data = Deref(ARG1), t_root = Deref(ARG2), t_max_size = Deref(ARG3);
|
||||
int root, retv, max_size;
|
||||
@ -386,7 +386,7 @@ p_mpi_bcast3() /* mpi_bcast( ?data, +root, +max_size ) */
|
||||
*/
|
||||
|
||||
static Int
|
||||
p_mpi_bcast2() /* mpi_bcast( ?data, +root ) */
|
||||
p_mpi_bcast2( USES_REGS1 ) /* mpi_bcast( ?data, +root ) */
|
||||
{
|
||||
Term t_data = Deref(ARG1), t_root = Deref(ARG2);
|
||||
int root, retv;
|
||||
@ -460,7 +460,7 @@ p_mpi_bcast2() /* mpi_bcast( ?data, +root ) */
|
||||
|
||||
|
||||
static Int
|
||||
p_mpi_barrier() /* mpi_barrier/0 */
|
||||
p_mpi_barrier( USES_REGS1 ) /* mpi_barrier/0 */
|
||||
{
|
||||
int retv;
|
||||
|
||||
|
@ -473,7 +473,7 @@ factor_to_dist(Hash, f(bayes, Id, Ks)) :-
|
||||
maplist(key_to_var(Hash), Ks, [V|Parents]),
|
||||
Ks =[Key|_],
|
||||
pfl:skolem(Key, Domain),
|
||||
pfl:get_pfl_parameters(Id, CPT),
|
||||
pfl:get_pfl_parameters(Id, Ks, CPT),
|
||||
dist(p(Domain,CPT,Parents), DistInfo, Key, Parents),
|
||||
put_atts(V,[dist(DistInfo,Parents)]).
|
||||
|
||||
|
@ -122,7 +122,7 @@ evtotree(K=V,Ev0,Ev) :-
|
||||
rb_insert(Ev0, K, V, Ev).
|
||||
|
||||
ftotree(F, Fs0, Fs) :-
|
||||
F = f([K|_Parents],_,_,_),
|
||||
F = fn([K|_Parents],_,_,_,_),
|
||||
rb_insert(Fs0, K, F, Fs).
|
||||
|
||||
bdd([[]],_,_) :- !.
|
||||
@ -160,7 +160,7 @@ sort_keys(AllFs, AllVars, Leaves) :-
|
||||
dgraph_leaves(Graph, Leaves),
|
||||
dgraph_top_sort(Graph, AllVars).
|
||||
|
||||
add_node(f([K|Parents],_,_,_), Graph0, Graph) :-
|
||||
add_node(fn([K|Parents],_,_,_,_), Graph0, Graph) :-
|
||||
dgraph_add_vertex(Graph0, K, Graph1),
|
||||
foldl(add_edge(K), Parents, Graph1, Graph).
|
||||
|
||||
@ -190,7 +190,7 @@ add_parents([V0|Parents], V, Graph0, GraphF) :-
|
||||
get_keys_info([], _, _, _, Vs, Vs, Ps, Ps, _, _) --> [].
|
||||
get_keys_info([V|MoreVs], Evs, Fs, OrderVs, Vs, VsF, Ps, PsF, Lvs, Outs) -->
|
||||
{ rb_lookup(V, F, Fs) }, !,
|
||||
{ F = f([V|Parents], _, _, DistId) },
|
||||
{ F = fn([V|Parents], _, _, DistId, _) },
|
||||
%{writeln(v:DistId:Parents)},
|
||||
[DIST],
|
||||
{ get_key_info(V, F, Fs, Evs, OrderVs, DistId, Parents, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) },
|
||||
@ -200,7 +200,7 @@ get_key_info(V, F, Fs, Evs, OrderVs, DistId, Parents0, Vs, Vs2, Ps, Ps1, Lvs, Ou
|
||||
reorder_keys(Parents0, OrderVs, Parents, Map),
|
||||
check_key_p(DistId, F, Map, Parms, _ParmVars, Ps, Ps1),
|
||||
unbound_parms(Parms, ParmVars),
|
||||
F = f(_,[Size|_],_,_),
|
||||
F = fn(_,[Size|_],_,_,_),
|
||||
check_key(V, Size, DIST, Vs, Vs1),
|
||||
DIST = info(V, Tree, Ev, Values, Formula, ParmVars, Parms),
|
||||
% get a list of form [[P00,P01], [P10,P11], [P20,P21]]
|
||||
@ -599,7 +599,7 @@ to_disj2([V,V1|Vs], V0, Out) :-
|
||||
%
|
||||
check_key_p(DistId, _, Map, Parms, ParmVars, Ps, Ps) :-
|
||||
rb_lookup(DistId-Map, theta(Parms, ParmVars), Ps), !.
|
||||
check_key_p(DistId, f(_, Sizes, Parms0, DistId), Map, Parms, ParmVars, Ps, PsF) :-
|
||||
check_key_p(DistId, fn(_, Sizes, Parms0, DistId, _), Map, Parms, ParmVars, Ps, PsF) :-
|
||||
swap_parms(Parms0, Sizes, [0|Map], Parms1),
|
||||
length(Parms1, L0),
|
||||
Sizes = [Size|_],
|
||||
@ -693,7 +693,7 @@ get_parents(V.Parents, Values.PVars, Vs0, Vs) :-
|
||||
|
||||
get_key_parent(Fs, V, Values, Vs0, Vs) :-
|
||||
INFO = info(V, _Parent, _Ev, Values, _, _, _),
|
||||
rb_lookup(V, f(_, [Size|_], _, _), Fs),
|
||||
rb_lookup(V, fn(_, [Size|_], _, _, _), Fs),
|
||||
check_key(V, Size, INFO, Vs0, Vs).
|
||||
|
||||
check_key(V, _, INFO, Vs, Vs) :-
|
||||
|
@ -42,7 +42,7 @@ init_influences(Vs, G, RG) :-
|
||||
to_dgraph(Vs, G0, G),
|
||||
dgraph_transpose(G, RG).
|
||||
|
||||
factor_to_dgraph(f([V|Parents],_,_,_), G0, G) :-
|
||||
factor_to_dgraph(fn([V|Parents],_,_,_,_), G0, G) :-
|
||||
dgraph_add_vertex(G0, V, G00),
|
||||
build_edges(Parents, V, Edges),
|
||||
dgraph_add_edges(G00, Edges, G).
|
||||
|
@ -238,7 +238,7 @@ get_dist_all_sizes(Id, DSizes) :-
|
||||
|
||||
get_dist_domain_size(DistId, DSize) :-
|
||||
use_parfactors(on), !,
|
||||
pfl:get_pfl_parameters(DistId, Dist),
|
||||
pfl:get_pfl_parameters(DistId, _, Dist),
|
||||
length(Dist, DSize).
|
||||
get_dist_domain_size(avg(D,_), DSize) :- !,
|
||||
length(D, DSize).
|
||||
@ -297,7 +297,7 @@ empty_dist(Dist, TAB) :-
|
||||
dist_new_table(DistId, NewMat) :-
|
||||
use_parfactors(on), !,
|
||||
matrix_to_list(NewMat, List),
|
||||
pfl:new_pfl_parameters(DistId, List).
|
||||
pfl:new_pfl_parameters(DistId, _, List).
|
||||
dist_new_table(Id, NewMat) :-
|
||||
matrix_to_list(NewMat, List),
|
||||
recorded(clpbn_dist_db, db(Id, Key, _, A, B, C, D), R),
|
||||
|
@ -144,7 +144,6 @@ create_new_variable(K, V, Vf0, Vff, C0, Cf) :-
|
||||
Id =.. [Na,Dom],
|
||||
Dist =.. [Na,Dom,NTVs],
|
||||
{ V = K with Dist },
|
||||
writeln(done),
|
||||
add_stored_evidence(K, V),
|
||||
add_variables(TVs, NTVs, Vf0, Vff, C0, Cf).
|
||||
|
||||
|
@ -28,6 +28,8 @@
|
||||
sum_list/2
|
||||
]).
|
||||
|
||||
:- use_module(library(maplist)).
|
||||
|
||||
:- use_module(library(ordsets),
|
||||
[ord_subtract/3]).
|
||||
|
||||
@ -87,7 +89,7 @@ run_gibbs_solver(LVs, LPs, Vs) :-
|
||||
|
||||
initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
|
||||
init_keys(Keys0),
|
||||
gen_keys(LVs, 0, VLen, Keys0, Keys),
|
||||
foldl2(gen_key, LVs, 0, VLen, Keys0, Keys),
|
||||
functor(Graph,graph,VLen),
|
||||
graph_representation(LVs, Graph, 0, Keys, TGraph),
|
||||
compile_graph(Graph),
|
||||
@ -99,21 +101,18 @@ initialise(LVs, Graph, GVs, OutputVars, VarOrder) :-
|
||||
init_keys(Keys0) :-
|
||||
rb_new(Keys0).
|
||||
|
||||
gen_keys([], I, I, Keys, Keys).
|
||||
gen_keys([V|Vs], I0, If, Keys0, Keys) :-
|
||||
clpbn:get_atts(V,[evidence(_)]), !,
|
||||
gen_keys(Vs, I0, If, Keys0, Keys).
|
||||
gen_keys([V|Vs], I0, If, Keys0, Keys) :-
|
||||
gen_key(V, I0, I0, Keys0, Keys0) :-
|
||||
clpbn:get_atts(V,[evidence(_)]), !.
|
||||
gen_key(V, I0, I, Keys0, Keys) :-
|
||||
I is I0+1,
|
||||
rb_insert(Keys0,V,I,KeysI),
|
||||
gen_keys(Vs, I, If, KeysI, Keys).
|
||||
rb_insert(Keys0,V,I,Keys).
|
||||
|
||||
graph_representation([],_,_,_,[]).
|
||||
graph_representation([V|Vs], Graph, I0, Keys, TGraph) :-
|
||||
clpbn:get_atts(V,[evidence(_)]), !,
|
||||
clpbn:get_atts(V, [dist(Id,Parents)]),
|
||||
get_possibly_deterministic_dist_matrix(Id, Parents, _, Vals, Table),
|
||||
get_sizes(Parents, Szs),
|
||||
maplist(get_size, Parents, Szs),
|
||||
length(Vals,Sz),
|
||||
project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable),
|
||||
% all variables are parents
|
||||
@ -123,7 +122,7 @@ graph_representation([V|Vs], Graph, I0, Keys, [I-IParents|TGraph]) :-
|
||||
I is I0+1,
|
||||
clpbn:get_atts(V, [dist(Id,Parents)]),
|
||||
get_possibly_deterministic_dist_matrix(Id, Parents, _, Vals, Table),
|
||||
get_sizes(Parents, Szs),
|
||||
maplist( get_size, Parents, Szs),
|
||||
length(Vals,Sz),
|
||||
project_evidence_out([V|Parents],[V|Parents],Table,[Sz|Szs],Variables,NewTable),
|
||||
Variables = [V|NewParents],
|
||||
@ -131,7 +130,7 @@ graph_representation([V|Vs], Graph, I0, Keys, [I-IParents|TGraph]) :-
|
||||
reorder_CPT(Variables,NewTable,[V|SortedNVs],NewTable2,_),
|
||||
add2graph(V, Vals, NewTable2, SortedIndices, Graph, Keys),
|
||||
propagate2parents(NewParents, NewTable, Variables, Graph,Keys),
|
||||
parent_indices(NewParents, Keys, IVariables0),
|
||||
maplist(parent_index(Keys), NewParents, IVariables0),
|
||||
sort(IVariables0, IParents),
|
||||
arg(I, Graph, var(_,_,_,_,_,_,_,NewTable2,SortedIndices)),
|
||||
graph_representation(Vs, Graph, I, Keys, TGraph).
|
||||
@ -141,18 +140,12 @@ write_pars([V|Parents]) :-
|
||||
clpbn:get_atts(V, [key(K),dist(I,_)]),write(K:I),nl,
|
||||
write_pars(Parents).
|
||||
|
||||
get_sizes([], []).
|
||||
get_sizes([V|Parents], [Sz|Szs]) :-
|
||||
get_size(V, Sz) :-
|
||||
clpbn:get_atts(V, [dist(Id,_)]),
|
||||
get_dist_domain_size(Id, Sz),
|
||||
get_sizes(Parents, Szs).
|
||||
|
||||
parent_indices([], _, []).
|
||||
parent_indices([V|Parents], Keys, [I|IParents]) :-
|
||||
rb_lookup(V, I, Keys),
|
||||
parent_indices(Parents, Keys, IParents).
|
||||
|
||||
get_dist_domain_size(Id, Sz).
|
||||
|
||||
parent_index(Keys, V, I) :-
|
||||
rb_lookup(V, I, Keys).
|
||||
|
||||
%
|
||||
% first, remove nodes that have evidence from tables.
|
||||
@ -180,52 +173,36 @@ add2graph(V, Vals, Table, IParents, Graph, Keys) :-
|
||||
member(tabular(Table,Index,IParents), VarSlot), !.
|
||||
|
||||
sort_according_to_indices(NVs,Keys,SortedNVs,SortedIndices) :-
|
||||
vars2indices(NVs,Keys,ToSort),
|
||||
maplist(var2index(Keys), NVs, ToSort),
|
||||
keysort(ToSort, Sorted),
|
||||
split_parents(Sorted, SortedNVs,SortedIndices).
|
||||
maplist(split_parent, Sorted, SortedNVs,SortedIndices).
|
||||
|
||||
split_parents([], [], []).
|
||||
split_parents([I-V|Sorted], [V|SortedNVs],[I|SortedIndices]) :-
|
||||
split_parents(Sorted, SortedNVs, SortedIndices).
|
||||
split_parent(I-V, V, I).
|
||||
|
||||
|
||||
vars2indices([],_,[]).
|
||||
vars2indices([V|Parents],Keys,[I-V|IParents]) :-
|
||||
rb_lookup(V, I, Keys),
|
||||
vars2indices(Parents,Keys,IParents).
|
||||
var2index(Keys, V, I-V) :-
|
||||
rb_lookup(V, I, Keys).
|
||||
|
||||
%
|
||||
% This is the really cool bit.
|
||||
%
|
||||
compile_graph(Graph) :-
|
||||
Graph =.. [_|VarsInfo],
|
||||
compile_vars(VarsInfo,Graph).
|
||||
maplist( compile_var(Graph), VarsInfo).
|
||||
|
||||
compile_vars([],_).
|
||||
compile_vars([var(_,I,_,Vals,Sz,VarSlot,Parents,_,_)|VarsInfo],Graph)
|
||||
:-
|
||||
compile_var(I,Vals,Sz,VarSlot,Parents,Graph),
|
||||
compile_vars(VarsInfo,Graph).
|
||||
|
||||
compile_var(I,Vals,Sz,VarSlot,Parents,Graph) :-
|
||||
fetch_all_parents(VarSlot,Graph,[],Parents,[],Sizes),
|
||||
mult_list(Sizes,1,TotSize),
|
||||
compile_var(Graph, var(_,I,_,Vals,Sz,VarSlot,Parents,_,_)) :-
|
||||
foldl2( fetch_parent(Graph), VarSlot, [], Parents, [], Sizes),
|
||||
foldl( mult_list, Sizes,1,TotSize),
|
||||
compile_var(TotSize,I,Vals,Sz,VarSlot,Parents,Sizes,Graph).
|
||||
|
||||
fetch_all_parents([],_,Parents,Parents,Sizes,Sizes) :- !.
|
||||
fetch_all_parents([tabular(_,_,Ps)|CPTs],Graph,Parents0,ParentsF,Sizes0,SizesF) :-
|
||||
merge_these_parents(Ps,Graph,Parents0,ParentsI,Sizes0,SizesI),
|
||||
fetch_all_parents(CPTs,Graph,ParentsI,ParentsF,SizesI,SizesF).
|
||||
fetch_parent(Graph, tabular(_,_,Ps), Parents0, ParentsF, Sizes0, SizesF) :-
|
||||
foldl2( merge_these_parents(Graph), Ps, Parents0, ParentsF, Sizes0, SizesF).
|
||||
|
||||
merge_these_parents([],_,Parents,Parents,Sizes,Sizes).
|
||||
merge_these_parents([I|Ps],Graph,Parents0,ParentsF,Sizes0,SizesF) :-
|
||||
member(I,Parents0), !,
|
||||
merge_these_parents(Ps,Graph,Parents0,ParentsF,Sizes0,SizesF).
|
||||
merge_these_parents([I|Ps],Graph,Parents0,ParentsF,Sizes0,SizesF) :-
|
||||
merge_these_parents(_Graph, I,Parents0,Parents0,Sizes0,Sizes0) :-
|
||||
member(I,Parents0), !.
|
||||
merge_these_parents(Graph, I, Parents0,ParentsF,Sizes0,SizesF) :-
|
||||
arg(I,Graph,var(_,I,_,Vals,_,_,_,_,_)),
|
||||
length(Vals, Sz),
|
||||
add_parent(Parents0,I,ParentsI,Sizes0,Sz,SizesI),
|
||||
merge_these_parents(Ps,Graph,ParentsI,ParentsF,SizesI,SizesF).
|
||||
add_parent(Parents0,I,ParentsF,Sizes0,Sz,SizesF).
|
||||
|
||||
add_parent([],I,[I],[],Sz,[Sz]).
|
||||
add_parent([P|Parents0],I,[I,P|Parents0],Sizes0,Sz,[Sz|Sizes0]) :-
|
||||
@ -234,10 +211,8 @@ add_parent([P|Parents0],I,[P|ParentsI],[S|Sizes0],Sz,[S|SizesI]) :-
|
||||
add_parent(Parents0,I,ParentsI,Sizes0,Sz,SizesI).
|
||||
|
||||
|
||||
mult_list([],Mult,Mult).
|
||||
mult_list([Sz|Sizes],Mult0,Mult) :-
|
||||
MultI is Sz*Mult0,
|
||||
mult_list(Sizes,MultI,Mult).
|
||||
mult_list(Sz,Mult0,Mult) :-
|
||||
Mult is Sz*Mult0.
|
||||
|
||||
% compile node as set of facts, faster execution
|
||||
compile_var(TotSize,I,_Vals,Sz,CPTs,Parents,_Sizes,Graph) :-
|
||||
@ -247,7 +222,7 @@ compile_var(TotSize,I,_Vals,Sz,CPTs,Parents,_Sizes,Graph) :-
|
||||
compile_var(_,_,_,_,_,_,_,_).
|
||||
|
||||
multiply_all(I,Parents,CPTs,Sz,Graph) :-
|
||||
markov_blanket_instance(Parents,Graph,Values),
|
||||
maplist( markov_blanket_instance(Graph), Parents, Values),
|
||||
(
|
||||
multiply_all(CPTs,Graph,Probs)
|
||||
->
|
||||
@ -261,11 +236,9 @@ multiply_all(I,_,_,_,_) :-
|
||||
|
||||
% note: what matters is how this predicate instantiates the temp
|
||||
% slot in the graph!
|
||||
markov_blanket_instance([],_,[]).
|
||||
markov_blanket_instance([I|Parents],Graph,[Pos|Values]) :-
|
||||
arg(I,Graph,var(_,I,Pos,Vals,_,_,_,_,_)),
|
||||
fetch_val(Vals,0,Pos),
|
||||
markov_blanket_instance(Parents,Graph,Values).
|
||||
markov_blanket_instance(Graph, I, Pos) :-
|
||||
arg(I, Graph, var(_,I,Pos,Vals,_,_,_,_,_)),
|
||||
fetch_val(Vals, 0, Pos).
|
||||
|
||||
% backtrack through every value in domain
|
||||
%
|
||||
@ -275,21 +248,19 @@ fetch_val([_|Vals],I0,Pos) :-
|
||||
fetch_val(Vals,I,Pos).
|
||||
|
||||
multiply_all([tabular(Table,_,Parents)|CPTs],Graph,Probs) :-
|
||||
fetch_parents(Parents, Graph, Vals),
|
||||
maplist( fetch_parent(Graph), Parents, Vals),
|
||||
column_from_possibly_deterministic_CPT(Table,Vals,Probs0),
|
||||
multiply_more(CPTs,Graph,Probs0,Probs).
|
||||
|
||||
fetch_parents([], _, []).
|
||||
fetch_parents([P|Parents], Graph, [Val|Vals]) :-
|
||||
arg(P,Graph,var(_,_,Val,_,_,_,_,_,_)),
|
||||
fetch_parents(Parents, Graph, Vals).
|
||||
fetch_parent(Graph, P, Val) :-
|
||||
arg(P,Graph,var(_,_,Val,_,_,_,_,_,_)).
|
||||
|
||||
multiply_more([],_,Probs0,LProbs) :-
|
||||
normalise_possibly_deterministic_CPT(Probs0, Probs),
|
||||
list_from_CPT(Probs, LProbs0),
|
||||
accumulate_up_list(LProbs0, 0.0, LProbs).
|
||||
multiply_more([tabular(Table,_,Parents)|CPTs],Graph,Probs0,Probs) :-
|
||||
fetch_parents(Parents, Graph, Vals),
|
||||
maplist( fetch_parent(Graph), Parents, Vals),
|
||||
column_from_possibly_deterministic_CPT(Table, Vals, P0),
|
||||
multiply_possibly_deterministic_factors(Probs0, P0, ProbsI),
|
||||
multiply_more(CPTs,Graph,ProbsI,Probs).
|
||||
@ -378,7 +349,7 @@ process_chains(0,_,F,F,_,_,Est,Est) :- !.
|
||||
process_chains(ToDo,VarOrder,End,Start,Graph,Len,Est0,Estf) :-
|
||||
%format('ToDo = ~d~n',[ToDo]),
|
||||
process_chains(Start,VarOrder,Int,Graph,Len,Est0,Esti),
|
||||
% (ToDo mod 100 =:= 1 -> statistics,cvt2problist(Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true),
|
||||
% (ToDo mod 100 =:= 1 -> statistics,maplist(cvt2prob, Esti, Probs), Int =[S|_], format('did ~d: ~w~n ~w~n',[ToDo,Probs,S]) ; true),
|
||||
ToDo1 is ToDo-1,
|
||||
process_chains(ToDo1,VarOrder,End,Int,Graph,Len,Esti,Estf).
|
||||
|
||||
@ -388,7 +359,7 @@ process_chains([Sample0|Samples0], VarOrder, [Sample|Samples], Graph, SampLen,[E
|
||||
functor(Sample,sample,SampLen),
|
||||
do_sample(VarOrder,Sample,Sample0,Graph),
|
||||
% format('Sample = ~w~n',[Sample]),
|
||||
update_estimates(E0,Sample,Ef),
|
||||
maplist(update_estimate(Sample), E0, Ef),
|
||||
process_chains(Samples0, VarOrder, Samples, Graph, SampLen,E0s,Efs).
|
||||
|
||||
do_sample([],_,_,_).
|
||||
@ -439,15 +410,10 @@ pick_new_value([V|Vals],X,I0,Val) :-
|
||||
pick_new_value(Vals,X,I,Val)
|
||||
).
|
||||
|
||||
update_estimates([],_,[]).
|
||||
update_estimates([Est|E0],Sample,[NEst|Ef]) :-
|
||||
update_estimate(Est,Sample,NEst),
|
||||
update_estimates(E0,Sample,Ef).
|
||||
|
||||
update_estimate([I|E],Sample,[I|NE]) :-
|
||||
update_estimate(Sample, [I|E],[I|NE]) :-
|
||||
arg(I,Sample,V),
|
||||
update_estimate_for_var(V,E,NE).
|
||||
update_estimate(me(Is,Mult,E),Sample,me(Is,Mult,NE)) :-
|
||||
update_estimate(Sample,me(Is,Mult,E),me(Is,Mult,NE)) :-
|
||||
get_estimate_pos(Is, Sample, Mult, 0, V),
|
||||
update_estimate_for_var(V,E,NE).
|
||||
|
||||
@ -481,21 +447,15 @@ clean_up.
|
||||
|
||||
gibbs_params(5,100,1000).
|
||||
|
||||
cvt2problist([], []).
|
||||
cvt2problist([[[_|E]]|Est0], [Ps|Probs]) :-
|
||||
sum_all(E,0,Sum),
|
||||
do_probs(E,Sum,Ps),
|
||||
cvt2problist(Est0, Probs) .
|
||||
cvt2prob([[_|E]], Ps) :-
|
||||
foldl(sum_all, E, 0, Sum),
|
||||
maplist( do_prob(Sum), E, Ps).
|
||||
|
||||
sum_all([],Sum,Sum).
|
||||
sum_all([E|Es],S0,Sum) :-
|
||||
SI is S0+E,
|
||||
sum_all(Es,SI,Sum).
|
||||
sum_all(E, S0, Sum) :-
|
||||
Sum is S0+E.
|
||||
|
||||
do_probs([],_,[]).
|
||||
do_probs([E|Es],Sum,[P|Ps]) :-
|
||||
P is E/Sum,
|
||||
do_probs(Es,Sum,Ps).
|
||||
do_prob(Sum, E, P) :-
|
||||
P is E/Sum.
|
||||
|
||||
show_sorted([], _) :- nl.
|
||||
show_sorted([I|VarOrder], Graph) :-
|
||||
@ -506,13 +466,11 @@ show_sorted([I|VarOrder], Graph) :-
|
||||
|
||||
sum_up_all([[]|_], []).
|
||||
sum_up_all([[C|MoreC]|Chains], [Dist|Dists]) :-
|
||||
extract_sums(Chains, CurrentChains, LeftChains),
|
||||
maplist( extract_sum, Chains, CurrentChains, LeftChains),
|
||||
sum_up([C|CurrentChains], Dist),
|
||||
sum_up_all([MoreC|LeftChains], Dists).
|
||||
|
||||
extract_sums([], [], []).
|
||||
extract_sums([[C|Chains]|MoreChains], [C|CurrentChains], [Chains|LeftChains]) :-
|
||||
extract_sums(MoreChains, CurrentChains, LeftChains).
|
||||
extract_sum([C|Chains], C, Chains).
|
||||
|
||||
sum_up([[_|Counts]|Chains], Dist) :-
|
||||
add_up(Counts,Chains, Add),
|
||||
@ -523,25 +481,21 @@ sum_up([me(_,_,Counts)|Chains], Dist) :-
|
||||
|
||||
add_up(Counts,[],Counts).
|
||||
add_up(Counts,[[_|Cs]|Chains], Add) :-
|
||||
sum_lists(Counts, Cs, NCounts),
|
||||
maplist(sum, Counts, Cs, NCounts),
|
||||
add_up(NCounts, Chains, Add).
|
||||
|
||||
add_up_mes(Counts,[],Counts).
|
||||
add_up_mes(Counts,[me(_,_,Cs)|Chains], Add) :-
|
||||
sum_lists(Counts, Cs, NCounts),
|
||||
maplist( sum_list, Counts, Cs, NCounts),
|
||||
add_up_mes(NCounts, Chains, Add).
|
||||
|
||||
sum_lists([],[],[]).
|
||||
sum_lists([Count|Counts], [C|Cs], [NC|NCounts]) :-
|
||||
NC is Count+C,
|
||||
sum_lists(Counts, Cs, NCounts).
|
||||
sum(Count, C, NC) :-
|
||||
NC is Count+C.
|
||||
|
||||
normalise(Add, Dist) :-
|
||||
sum_list(Add, Sum),
|
||||
divide_list(Add, Sum, Dist).
|
||||
maplist(divide(Sum), Add, Dist).
|
||||
|
||||
divide_list([], _, []).
|
||||
divide_list([C|Add], Sum, [P|Dist]) :-
|
||||
P is C/Sum,
|
||||
divide_list(Add, Sum, Dist).
|
||||
divide(Sum, C, P) :-
|
||||
P is C/Sum.
|
||||
|
||||
|
@ -32,7 +32,7 @@
|
||||
[clpbn_bind_vals/3]).
|
||||
|
||||
:- use_module(library(pfl),
|
||||
[get_pfl_parameters/2,
|
||||
[get_pfl_parameters/3,
|
||||
skolem/2
|
||||
]).
|
||||
|
||||
@ -50,7 +50,7 @@ call_horus_ground_solver(QueryVars, QueryKeys, AllKeys, Factors, Evidence,
|
||||
end_horus_ground_solver(State).
|
||||
|
||||
|
||||
init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence,
|
||||
init_horus_ground_solver(_QueryKeys, AllKeys, Factors, Evidence,
|
||||
state(Network,Hash,Id,DistIds)) :-
|
||||
factors_type(Factors, Type),
|
||||
keys_to_numbers(AllKeys, Factors, Evidence, Hash, Id, FacIds, EvIds),
|
||||
@ -64,10 +64,10 @@ init_horus_ground_solver(QueryKeys, AllKeys, Factors, Evidence,
|
||||
|
||||
|
||||
run_horus_ground_solver(QueryKeys, Solutions,
|
||||
state(Network,Hash,Id, DistIds)) :-
|
||||
state(Network,Hash,Id, _DistIds)) :-
|
||||
lists_of_keys_to_ids(QueryKeys, QueryIds, Hash, _, Id, _),
|
||||
%maplist(get_pfl_parameters, DistIds, DistParams),
|
||||
%cpp_set_factors_params(Network, DistIds, DistParams),
|
||||
%maplist(get_pfl_parameters, _DistIds, _, DistParams),
|
||||
%cpp_set_factors_params(Network, _DistIds, DistParams),
|
||||
cpp_run_ground_solver(Network, QueryIds, Solutions).
|
||||
|
||||
|
||||
@ -79,7 +79,7 @@ factors_type([f(bayes, _, _)|_], bayes) :- ! .
|
||||
factors_type([f(markov, _, _)|_], markov) :- ! .
|
||||
|
||||
|
||||
get_dist_id(f(_, _, _, DistId), DistId).
|
||||
get_dist_id(fn(_, _, _, DistId, _), DistId).
|
||||
|
||||
|
||||
get_domain(_:Key, Domain) :- !,
|
||||
|
@ -28,7 +28,7 @@
|
||||
:- use_module(library(pfl),
|
||||
[factor/6,
|
||||
skolem/2,
|
||||
get_pfl_parameters/2
|
||||
get_pfl_parameters/3
|
||||
]).
|
||||
|
||||
:- use_module(library(maplist)).
|
||||
@ -50,9 +50,9 @@ init_horus_lifted_solver(_, AllVars, _, state(Network, DistIds)) :-
|
||||
sort(DistIds0, DistIds).
|
||||
|
||||
|
||||
run_horus_lifted_solver(QueryVars, Solutions, state(Network, DistIds)) :-
|
||||
run_horus_lifted_solver(QueryVars, Solutions, state(Network, _DistIds)) :-
|
||||
maplist(get_query_keys, QueryVars, QueryKeys),
|
||||
%maplist(get_pfl_parameters, DistIds,DistsParams),
|
||||
%maplist(get_pfl_parameters, DistIds, _, DistsParams),
|
||||
%cpp_set_parfactors_params(Network, DistIds, DistsParams),
|
||||
cpp_run_lifted_solver(Network, QueryKeys, Solutions).
|
||||
|
||||
|
@ -41,7 +41,7 @@ key_to_id(Key, I0, Hash0, Hash, I0, I) :-
|
||||
b_hash_insert(Hash0, Key, I0, Hash),
|
||||
I is I0+1.
|
||||
|
||||
factor_to_id(Ev, f(_, DistId, Keys), f(Ids, Ranges, CPT, DistId), Hash0, Hash, I0, I) :-
|
||||
factor_to_id(Ev, f(_, DistId, Keys), fn(Ids, Ranges, CPT, DistId, Keys), Hash0, Hash, I0, I) :-
|
||||
get_pfl_cpt(DistId, Keys, Ev, NKeys, CPT),
|
||||
foldl2(key_to_id, NKeys, Ids, Hash0, Hash, I0, I),
|
||||
maplist(get_range, Keys, Ranges).
|
||||
|
@ -131,9 +131,9 @@ init_ve(FactorIds, EvidenceIds, Hash, Id, ve(FactorIds, Hash, Id, Ev)) :-
|
||||
evtotree(K=V,Ev0,Ev) :-
|
||||
rb_insert(Ev0, K, V, Ev).
|
||||
|
||||
factor_to_graph( f(Nodes, Sizes, _Pars0, Id), Factors0, Factors, Edges0, Edges, I0, I) :-
|
||||
factor_to_graph( fn(Nodes, Sizes, _Pars0, Id, Keys), Factors0, Factors, Edges0, Edges, I0, I) :-
|
||||
I is I0+1,
|
||||
pfl:get_pfl_parameters(Id, Pars0),
|
||||
pfl:get_pfl_parameters(Id, Keys, Pars0),
|
||||
init_CPT(Pars0, Sizes, CPT0),
|
||||
reorder_CPT(Nodes, CPT0, FIPs, CPT, _),
|
||||
F = f(I0, FIPs, CPT),
|
||||
|
137
packages/CLPBN/examples/complex.fg
Normal file
137
packages/CLPBN/examples/complex.fg
Normal file
@ -0,0 +1,137 @@
|
||||
10
|
||||
|
||||
2
|
||||
0 1
|
||||
2 2
|
||||
4
|
||||
0 1.02
|
||||
1 0.87
|
||||
2 0.88
|
||||
3 0.45
|
||||
|
||||
4
|
||||
1 2 3 4
|
||||
2 2 3 3
|
||||
36
|
||||
0 0.11
|
||||
1 1.11
|
||||
2 0.41
|
||||
3 0.12
|
||||
4 0.1
|
||||
5 0.17
|
||||
6 1.21
|
||||
7 1.1
|
||||
8 0.11
|
||||
9 0.41
|
||||
10 0.8
|
||||
11 0.71
|
||||
12 0.14
|
||||
13 0.24
|
||||
14 0.54
|
||||
15 1.4
|
||||
16 0.23
|
||||
17 0.24
|
||||
18 0.65
|
||||
19 0.05
|
||||
20 0.32
|
||||
21 0.12
|
||||
22 0.99
|
||||
23 0.69
|
||||
24 0.29
|
||||
25 1.29
|
||||
26 0.15
|
||||
27 1.24
|
||||
28 0.42
|
||||
29 0.124
|
||||
30 0.67
|
||||
31 0.078
|
||||
32 0.14
|
||||
33 0.55
|
||||
34 0.45
|
||||
35 0.1
|
||||
|
||||
3
|
||||
2 5 6
|
||||
2 2 3
|
||||
12
|
||||
0 0.15
|
||||
1 0.55
|
||||
2 2.21
|
||||
3 5.71
|
||||
4 0.44
|
||||
5 0.14
|
||||
6 0.5
|
||||
7 1.75
|
||||
8 1.29
|
||||
9 3.29
|
||||
10 0.36
|
||||
11 1.56
|
||||
|
||||
2
|
||||
7 2
|
||||
4 2
|
||||
8
|
||||
0 0.11
|
||||
1 0.59
|
||||
2 0.15
|
||||
3 0.124
|
||||
4 0.41
|
||||
5 2.11
|
||||
6 1.06
|
||||
7 0.929
|
||||
|
||||
1
|
||||
3
|
||||
3
|
||||
3
|
||||
0 0.1
|
||||
1 0.58
|
||||
2 0.74
|
||||
|
||||
1
|
||||
4
|
||||
3
|
||||
3
|
||||
0 3.2
|
||||
1 0.28
|
||||
2 1.24
|
||||
|
||||
2
|
||||
8 4
|
||||
2 3
|
||||
6
|
||||
0 0.19
|
||||
1 3.1
|
||||
2 0.49
|
||||
3 1.5
|
||||
4 2.1
|
||||
5 2.8
|
||||
|
||||
1
|
||||
5
|
||||
2
|
||||
2
|
||||
0 0.74
|
||||
1 0.14
|
||||
|
||||
1
|
||||
6
|
||||
3
|
||||
3
|
||||
0 0.032
|
||||
1 0.028
|
||||
2 0.24
|
||||
|
||||
2
|
||||
9 7
|
||||
2 4
|
||||
8
|
||||
0 0.61
|
||||
1 0.61
|
||||
2 1.4
|
||||
3 0.24
|
||||
4 0.09
|
||||
5 0.19
|
||||
6 1.4
|
||||
7 0.6
|
||||
|
79
packages/CLPBN/examples/complex.pfl
Normal file
79
packages/CLPBN/examples/complex.pfl
Normal file
@ -0,0 +1,79 @@
|
||||
:- use_module(library(pfl)).
|
||||
|
||||
%:- set_solver(ve).
|
||||
%:- set_solver(hve).
|
||||
%:- set_solver(jt).
|
||||
%:- set_solver(bdd).
|
||||
%:- set_solver(bp).
|
||||
%:- set_solver(cbp).
|
||||
%:- set_solver(gibbs).
|
||||
%:- set_solver(lve).
|
||||
%:- set_solver(lkc).
|
||||
%:- set_solver(lbp).
|
||||
|
||||
/*
|
||||
|
||||
v01 v02
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
v03 v04 v05
|
||||
/ \ | / \
|
||||
/ \ | / \
|
||||
/ \ | / \
|
||||
v06 v07 v08
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
v09 v10
|
||||
|
||||
*/
|
||||
|
||||
markov v01::[a,b] ; table1 ; [].
|
||||
|
||||
markov v02::[a,b,c] ; table2 ; [].
|
||||
|
||||
markov v03::[a,b], v01, v02 ; table3 ; [].
|
||||
|
||||
markov v04::[a,b,c] ; table4 ; [].
|
||||
|
||||
markov v05::[a,b,c] ; table5 ; [].
|
||||
|
||||
markov v06::[a,b,c,d], v03 ; table6 ; [].
|
||||
|
||||
markov v07::[a,b], v03, v04, v05 ; table7 ; [].
|
||||
|
||||
markov v08::[a,b], v05 ; table8 ; [].
|
||||
|
||||
markov v09::[a,b], v06 ; table9 ; [].
|
||||
|
||||
markov v10::[a,b], v07 ; table10 ; [].
|
||||
|
||||
table1([ 0.74, 0.14 ]).
|
||||
|
||||
table2([ 0.032, 0.028, 0.24 ]).
|
||||
|
||||
table3([
|
||||
0.15, 0.44, 1.29, 2.21, 0.5, 0.36,
|
||||
0.55, 0.14, 3.29, 5.71, 1.75, 1.56
|
||||
]).
|
||||
|
||||
table4([ 0.1, 0.58, 0.74 ]).
|
||||
|
||||
table5([ 3.2, 0.28, 1.24 ]).
|
||||
|
||||
table6([ 0.11, 0.41, 0.59, 2.11, 0.15, 1.06, 0.124, 0.929 ]).
|
||||
|
||||
table7([
|
||||
0.11, 0.14, 0.29, 0.1, 0.23, 0.42, 0.11, 0.32, 0.14,
|
||||
0.41, 0.54, 0.15, 1.21, 0.65, 0.67, 0.8, 0.99, 0.45,
|
||||
1.11, 0.24, 1.29, 0.17, 0.24, 0.124, 0.41, 0.12, 0.55,
|
||||
0.12, 1.4, 1.24, 1.1, 0.05, 0.078, 0.71, 0.69, 0.1
|
||||
]).
|
||||
|
||||
table8([ 0.19, 0.49, 2.1, 3.1, 1.5, 2.8 ]).
|
||||
|
||||
table9([ 0.61, 1.4, 0.09, 1.4, 0.61, 0.24, 0.19, 0.6 ]).
|
||||
|
||||
table10([ 1.02, 0.88, 0.87, 0.45 ]).
|
||||
|
@ -3,6 +3,25 @@
|
||||
#include "BayesBall.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
BayesBall::BayesBall (FactorGraph& fg)
|
||||
: fg_(fg) , dag_(fg.getStructure())
|
||||
{
|
||||
dag_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
BayesBall::getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||
{
|
||||
BayesBall bb (fg);
|
||||
return bb.getMinimalFactorGraph (vids);
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
{
|
||||
@ -19,22 +38,22 @@ BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||
BBNode* n = sch.node;
|
||||
n->setAsVisited();
|
||||
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||
if (n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
if (n->isMarkedAbove() == false) {
|
||||
n->markAbove();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
if (n->isMarkedBelow() == false) {
|
||||
n->markBelow();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
if (sch.visitedFromParent) {
|
||||
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
|
||||
n->markOnTop();
|
||||
if (n->hasEvidence() && n->isMarkedAbove() == false) {
|
||||
n->markAbove();
|
||||
scheduleParents (n, scheduling);
|
||||
}
|
||||
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
|
||||
n->markOnBottom();
|
||||
if (n->hasEvidence() == false && n->isMarkedBelow() == false) {
|
||||
n->markBelow();
|
||||
scheduleChilds (n, scheduling);
|
||||
}
|
||||
}
|
||||
@ -55,7 +74,7 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
const BBNode* n = dag_.getNode (
|
||||
facNodes[i]->factor().argument (0));
|
||||
if (n->isMarkedOnTop()) {
|
||||
if (n->isMarkedAbove()) {
|
||||
fg->addFactor (facNodes[i]->factor());
|
||||
} else if (n->hasEvidence() && n->isVisited()) {
|
||||
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||
@ -76,3 +95,5 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_BAYESBALL_H
|
||||
#define HORUS_BAYESBALL_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
@ -9,41 +9,28 @@
|
||||
#include "BayesBallGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
struct ScheduleInfo
|
||||
{
|
||||
ScheduleInfo (BBNode* n, bool vfp, bool vfc)
|
||||
: node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
|
||||
BBNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
};
|
||||
|
||||
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
|
||||
|
||||
|
||||
class BayesBall
|
||||
{
|
||||
class BayesBall {
|
||||
public:
|
||||
BayesBall (FactorGraph& fg)
|
||||
: fg_(fg) , dag_(fg.getStructure())
|
||||
{
|
||||
dag_.clear();
|
||||
}
|
||||
BayesBall (FactorGraph& fg);
|
||||
|
||||
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||
|
||||
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||
{
|
||||
BayesBall bb (fg);
|
||||
return bb.getMinimalFactorGraph (vids);
|
||||
}
|
||||
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids);
|
||||
|
||||
private:
|
||||
struct ScheduleInfo {
|
||||
ScheduleInfo (BBNode* n, bool vfp, bool vfc)
|
||||
: node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||
|
||||
BBNode* node;
|
||||
bool visitedFromParent;
|
||||
bool visitedFromChild;
|
||||
};
|
||||
|
||||
typedef std::queue<ScheduleInfo, std::list<ScheduleInfo>> Scheduling;
|
||||
|
||||
void constructGraph (FactorGraph* fg) const;
|
||||
|
||||
@ -51,9 +38,8 @@ class BayesBall
|
||||
|
||||
void scheduleChilds (const BBNode* n, Scheduling& sch) const;
|
||||
|
||||
FactorGraph& fg_;
|
||||
|
||||
BayesBallGraph& dag_;
|
||||
FactorGraph& fg_;
|
||||
BayesBallGraph& dag_;
|
||||
};
|
||||
|
||||
|
||||
@ -61,8 +47,8 @@ class BayesBall
|
||||
inline void
|
||||
BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<BBNode*>& ps = n->parents();
|
||||
for (vector<BBNode*>::const_iterator it = ps.begin();
|
||||
const std::vector<BBNode*>& ps = n->parents();
|
||||
for (std::vector<BBNode*>::const_iterator it = ps.begin();
|
||||
it != ps.end(); ++it) {
|
||||
sch.push (ScheduleInfo (*it, false, true));
|
||||
}
|
||||
@ -73,12 +59,14 @@ BayesBall::scheduleParents (const BBNode* n, Scheduling& sch) const
|
||||
inline void
|
||||
BayesBall::scheduleChilds (const BBNode* n, Scheduling& sch) const
|
||||
{
|
||||
const vector<BBNode*>& cs = n->childs();
|
||||
for (vector<BBNode*>::const_iterator it = cs.begin();
|
||||
const std::vector<BBNode*>& cs = n->childs();
|
||||
for (std::vector<BBNode*>::const_iterator it = cs.begin();
|
||||
it != cs.end(); ++it) {
|
||||
sch.push (ScheduleInfo (*it, true, false));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HORUS_BAYESBALL_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_BAYESBALL_H_
|
||||
|
||||
|
@ -1,14 +1,15 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "BayesBallGraph.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
void
|
||||
BayesBallGraph::addNode (BBNode* n)
|
||||
{
|
||||
@ -22,8 +23,8 @@ BayesBallGraph::addNode (BBNode* n)
|
||||
void
|
||||
BayesBallGraph::addEdge (VarId vid1, VarId vid2)
|
||||
{
|
||||
unordered_map<VarId, BBNode*>::iterator it1;
|
||||
unordered_map<VarId, BBNode*>::iterator it2;
|
||||
std::unordered_map<VarId, BBNode*>::iterator it1;
|
||||
std::unordered_map<VarId, BBNode*>::iterator it2;
|
||||
it1 = varMap_.find (vid1);
|
||||
it2 = varMap_.find (vid2);
|
||||
assert (it1 != varMap_.end());
|
||||
@ -37,7 +38,7 @@ BayesBallGraph::addEdge (VarId vid1, VarId vid2)
|
||||
const BBNode*
|
||||
BayesBallGraph::getNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
std::unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
@ -47,7 +48,7 @@ BayesBallGraph::getNode (VarId vid) const
|
||||
BBNode*
|
||||
BayesBallGraph::getNode (VarId vid)
|
||||
{
|
||||
unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
std::unordered_map<VarId, BBNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
@ -55,7 +56,7 @@ BayesBallGraph::getNode (VarId vid)
|
||||
|
||||
|
||||
void
|
||||
BayesBallGraph::setIndexes (void)
|
||||
BayesBallGraph::setIndexes()
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
@ -65,7 +66,7 @@ BayesBallGraph::setIndexes (void)
|
||||
|
||||
|
||||
void
|
||||
BayesBallGraph::clear (void)
|
||||
BayesBallGraph::clear()
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->clear();
|
||||
@ -77,13 +78,14 @@ BayesBallGraph::clear (void)
|
||||
void
|
||||
BayesBallGraph::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
out << "ranksep=1" << endl;
|
||||
out << "digraph {" << std::endl;
|
||||
out << "ranksep=1" << std::endl;
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
out << nodes_[i]->varId() ;
|
||||
out << " [" ;
|
||||
@ -91,16 +93,18 @@ BayesBallGraph::exportToGraphViz (const char* fileName)
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << ",style=filled, fillcolor=yellow" ;
|
||||
}
|
||||
out << "]" << endl;
|
||||
out << "]" << std::endl;
|
||||
}
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
const vector<BBNode*>& childs = nodes_[i]->childs();
|
||||
const std::vector<BBNode*>& childs = nodes_[i]->childs();
|
||||
for (size_t j = 0; j < childs.size(); j++) {
|
||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||
out << " [style=bold]" << endl ;
|
||||
out << " [style=bold]" << std::endl;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out << "}" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_BAYESBALLGRAPH_H
|
||||
#define HORUS_BAYESBALLGRAPH_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
@ -7,54 +7,55 @@
|
||||
#include "Var.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BBNode : public Var
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class BBNode : public Var {
|
||||
public:
|
||||
BBNode (Var* v) : Var (v), visited_(false),
|
||||
markedOnTop_(false), markedOnBottom_(false) { }
|
||||
markedAbove_(false), markedBelow_(false) { }
|
||||
|
||||
const vector<BBNode*>& childs (void) const { return childs_; }
|
||||
const std::vector<BBNode*>& childs() const { return childs_; }
|
||||
|
||||
vector<BBNode*>& childs (void) { return childs_; }
|
||||
std::vector<BBNode*>& childs() { return childs_; }
|
||||
|
||||
const vector<BBNode*>& parents (void) const { return parents_; }
|
||||
const std::vector<BBNode*>& parents() const { return parents_; }
|
||||
|
||||
vector<BBNode*>& parents (void) { return parents_; }
|
||||
std::vector<BBNode*>& parents() { return parents_; }
|
||||
|
||||
void addParent (BBNode* p) { parents_.push_back (p); }
|
||||
|
||||
void addChild (BBNode* c) { childs_.push_back (c); }
|
||||
|
||||
bool isVisited (void) const { return visited_; }
|
||||
bool isVisited() const { return visited_; }
|
||||
|
||||
void setAsVisited (void) { visited_ = true; }
|
||||
void setAsVisited() { visited_ = true; }
|
||||
|
||||
bool isMarkedOnTop (void) const { return markedOnTop_; }
|
||||
bool isMarkedAbove() const { return markedAbove_; }
|
||||
|
||||
void markOnTop (void) { markedOnTop_ = true; }
|
||||
void markAbove() { markedAbove_ = true; }
|
||||
|
||||
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
|
||||
bool isMarkedBelow() const { return markedBelow_; }
|
||||
|
||||
void markOnBottom (void) { markedOnBottom_ = true; }
|
||||
void markBelow() { markedBelow_ = true; }
|
||||
|
||||
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
|
||||
void clear() { visited_ = markedAbove_ = markedBelow_ = false; }
|
||||
|
||||
private:
|
||||
bool visited_;
|
||||
bool markedOnTop_;
|
||||
bool markedOnBottom_;
|
||||
bool markedAbove_;
|
||||
bool markedBelow_;
|
||||
|
||||
vector<BBNode*> childs_;
|
||||
vector<BBNode*> parents_;
|
||||
std::vector<BBNode*> childs_;
|
||||
std::vector<BBNode*> parents_;
|
||||
};
|
||||
|
||||
|
||||
class BayesBallGraph
|
||||
{
|
||||
class BayesBallGraph {
|
||||
public:
|
||||
BayesBallGraph (void) { }
|
||||
BayesBallGraph() { }
|
||||
|
||||
bool empty() const { return nodes_.empty(); }
|
||||
|
||||
void addNode (BBNode* n);
|
||||
|
||||
@ -64,19 +65,18 @@ class BayesBallGraph
|
||||
|
||||
BBNode* getNode (VarId vid);
|
||||
|
||||
bool empty (void) const { return nodes_.empty(); }
|
||||
void setIndexes();
|
||||
|
||||
void setIndexes (void);
|
||||
|
||||
void clear (void);
|
||||
void clear();
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
vector<BBNode*> nodes_;
|
||||
|
||||
unordered_map<VarId, BBNode*> varMap_;
|
||||
std::vector<BBNode*> nodes_;
|
||||
std::unordered_map<VarId, BBNode*> varMap_;
|
||||
};
|
||||
|
||||
#endif // HORUS_BAYESBALLGRAPH_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_BAYESBALLGRAPH_H_
|
||||
|
||||
|
@ -1,37 +1,39 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
|
||||
#include "BeliefProp.h"
|
||||
#include "Indexer.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
double BeliefProp::accuracy_ = 0.0001;
|
||||
unsigned BeliefProp::maxIter_ = 1000;
|
||||
MsgSchedule BeliefProp::schedule_ = MsgSchedule::SEQ_FIXED;
|
||||
namespace Horus {
|
||||
|
||||
double BeliefProp::accuracy_ = 0.0001;
|
||||
unsigned BeliefProp::maxIter_ = 1000;
|
||||
|
||||
BeliefProp::MsgSchedule BeliefProp::schedule_ =
|
||||
MsgSchedule::seqFixedSch;
|
||||
|
||||
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg) : GroundSolver (fg)
|
||||
|
||||
BeliefProp::BeliefProp (const FactorGraph& fg)
|
||||
: GroundSolver (fg), nIters_(0), runned_(false)
|
||||
{
|
||||
runned_ = false;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
BeliefProp::~BeliefProp (void)
|
||||
BeliefProp::~BeliefProp()
|
||||
{
|
||||
for (size_t i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
for (size_t i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
links_.clear();
|
||||
}
|
||||
|
||||
|
||||
@ -48,22 +50,22 @@ BeliefProp::solveQuery (VarIds queryVids)
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::printSolverFlags (void) const
|
||||
BeliefProp::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "belief propagation [" ;
|
||||
ss << "bp_msg_schedule=" ;
|
||||
switch (schedule_) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",bp_max_iter=" << Util::toString (maxIter_);
|
||||
ss << ",bp_accuracy=" << Util::toString (accuracy_);
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << ",bp_max_iter=" << Util::toString (maxIter_);
|
||||
ss << ",bp_accuracy=" << Util::toString (accuracy_);
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -82,7 +84,7 @@ BeliefProp::getPosterioriOf (VarId vid)
|
||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->range(), LogAware::multIdenty());
|
||||
const BpLinks& links = ninf(var)->getLinks();
|
||||
const BpLinks& links = getLinks (var);
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
probs += links[i]->message();
|
||||
@ -133,7 +135,7 @@ BeliefProp::getFactorJoint (
|
||||
runSolver();
|
||||
}
|
||||
Factor res (fn->factor());
|
||||
const BpLinks& links = ninf(fn)->getLinks();
|
||||
const BpLinks& links = getLinks( fn);
|
||||
for (size_t i = 0; i < links.size(); i++) {
|
||||
Factor msg ({links[i]->varNode()->varId()},
|
||||
{links[i]->varNode()->range()},
|
||||
@ -152,26 +154,119 @@ BeliefProp::getFactorJoint (
|
||||
|
||||
|
||||
|
||||
BeliefProp::BpLink::BpLink (FacNode* fn, VarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::runSolver (void)
|
||||
BeliefProp::BpLink::clearResidual()
|
||||
{
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::BpLink::updateResidual()
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::BpLink::updateMessage()
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::string
|
||||
BeliefProp::BpLink::toString() const
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::calculateAndUpdateMessage (BpLink* link, bool calcResidual)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
std::cout << "calculating & updating " << link->toString();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
calcFactorToVarMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::calculateMessage (BpLink* link, bool calcResidual)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
std::cout << "calculating " << link->toString();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
calcFactorToVarMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::updateMessage (BpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Globals::verbosity > 2) {
|
||||
std::cout << "updating " << link->toString();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::runSolver()
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
while (!converged() && nIters_ < maxIter_) {
|
||||
nIters_ ++;
|
||||
if (Globals::verbosity > 1) {
|
||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||
Util::printHeader (std::string ("Iteration ")
|
||||
+ Util::toString (nIters_));
|
||||
}
|
||||
switch (schedule_) {
|
||||
case MsgSchedule::SEQ_RANDOM:
|
||||
case MsgSchedule::seqRandomSch:
|
||||
std::random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
case MsgSchedule::SEQ_FIXED:
|
||||
case MsgSchedule::seqFixedSch:
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
case MsgSchedule::PARALLEL:
|
||||
case MsgSchedule::parallelSch:
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
@ -179,20 +274,21 @@ BeliefProp::runSolver (void)
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
case MsgSchedule::MAX_RESIDUAL:
|
||||
case MsgSchedule::maxResidualSch:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
if (nIters_ < maxIter_) {
|
||||
cout << "Belief propagation converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
std::cout << "Belief propagation converged in " ;
|
||||
std::cout << nIters_ << " iterations" << std::endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
std::cout << "The maximum number of iterations was hit," ;
|
||||
std::cout << " terminating..." ;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
runned_ = true;
|
||||
}
|
||||
@ -200,7 +296,7 @@ BeliefProp::runSolver (void)
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::createLinks (void)
|
||||
BeliefProp::createLinks()
|
||||
{
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
@ -214,7 +310,7 @@ BeliefProp::createLinks (void)
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::maxResidualSchedule (void)
|
||||
BeliefProp::maxResidualSchedule()
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
@ -227,11 +323,13 @@ BeliefProp::maxResidualSchedule (void)
|
||||
|
||||
for (size_t c = 0; c < links_.size(); c++) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "current residuals:" << endl;
|
||||
std::cout << "current residuals:" << std::endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); ++it) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->residual() << endl;
|
||||
std::cout << " " << std::setw (30) << std::left;
|
||||
std::cout << (*it)->toString();
|
||||
std::cout << "residual = " << (*it)->residual();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -249,7 +347,7 @@ BeliefProp::maxResidualSchedule (void)
|
||||
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->facNode()) {
|
||||
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||
const BpLinks& links = getLinks (factorNeighbors[i]);
|
||||
for (size_t j = 0; j < links.size(); j++) {
|
||||
if (links[j]->varNode() != link->varNode()) {
|
||||
calculateMessage (links[j]);
|
||||
@ -273,7 +371,7 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
{
|
||||
FacNode* src = link->facNode();
|
||||
const VarNode* dst = link->varNode();
|
||||
const BpLinks& links = ninf(src)->getLinks();
|
||||
const BpLinks& links = getLinks (src);
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned reps = 1;
|
||||
@ -282,14 +380,14 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
if (Globals::logDomain) {
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
if (links[i]->varNode() != dst) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->varNode()->label();
|
||||
cout << ": " ;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::plus<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
reps *= links[i]->varNode()->range();
|
||||
@ -297,14 +395,14 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
} else {
|
||||
for (size_t i = links.size(); i-- > 0; ) {
|
||||
if (links[i]->varNode() != dst) {
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message from " << links[i]->varNode()->label();
|
||||
cout << ": " ;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message from " << links[i]->varNode()->label();
|
||||
std::cout << ": " ;
|
||||
}
|
||||
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||
reps, std::multiplies<double>());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
reps *= links[i]->varNode()->range();
|
||||
@ -313,27 +411,28 @@ BeliefProp::calcFactorToVarMsg (BpLink* link)
|
||||
Factor result (src->factor().arguments(),
|
||||
src->factor().ranges(), msgProduct);
|
||||
result.multiply (src->factor());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " message product: " << msgProduct << endl;
|
||||
cout << " original factor: " << src->factor().params() << endl;
|
||||
cout << " factor product: " << result.params() << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " message product: " << msgProduct << std::endl;
|
||||
std::cout << " original factor: " << src->factor().params();
|
||||
std::cout << std::endl;
|
||||
std::cout << " factor product: " << result.params() << std::endl;
|
||||
}
|
||||
result.sumOutAllExcept (dst->varId());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " marginalized: " << result.params() << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " marginalized: " << result.params() << std::endl;
|
||||
}
|
||||
link->nextMessage() = result.params();
|
||||
LogAware::normalize (link->nextMessage());
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " curr msg: " << link->message() << endl;
|
||||
cout << " next msg: " << link->nextMessage() << endl;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " curr msg: " << link->message() << std::endl;
|
||||
std::cout << " next msg: " << link->nextMessage() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Params
|
||||
BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
BeliefProp::getVarToFactorMsg (const BpLink* link)
|
||||
{
|
||||
const VarNode* src = link->varNode();
|
||||
Params msg;
|
||||
@ -343,18 +442,18 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
} else {
|
||||
msg.resize (src->range(), LogAware::one());
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << msg;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << msg;
|
||||
}
|
||||
BpLinks::const_iterator it;
|
||||
const BpLinks& links = ninf (src)->getLinks();
|
||||
const BpLinks& links = getLinks (src);
|
||||
if (Globals::logDomain) {
|
||||
for (it = links.begin(); it != links.end(); ++it) {
|
||||
if (*it != link) {
|
||||
msg += (*it)->message();
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << (*it)->message();
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " x " << (*it)->message();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -362,13 +461,13 @@ BeliefProp::getVarToFactorMsg (const BpLink* link) const
|
||||
if (*it != link) {
|
||||
msg *= (*it)->message();
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " x " << (*it)->message();
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " x " << (*it)->message();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (Constants::SHOW_BP_CALCS) {
|
||||
cout << " = " << msg;
|
||||
if (Constants::showBpCalcs) {
|
||||
std::cout << " = " << msg;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
@ -379,37 +478,37 @@ Params
|
||||
BeliefProp::getJointByConditioning (const VarIds& jointVarIds) const
|
||||
{
|
||||
return GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::BP, fg, jointVarIds);
|
||||
GroundSolverType::bpSolver, fg, jointVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::initializeSolver (void)
|
||||
BeliefProp::initializeSolver()
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
varsI_.reserve (varNodes.size());
|
||||
varsLinks_.reserve (varNodes.size());
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
varsLinks_.push_back (BpLinks());
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
facsI_.reserve (facNodes.size());
|
||||
facsLinks_.reserve (facNodes.size());
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
facsLinks_.push_back (BpLinks());
|
||||
}
|
||||
createLinks();
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
FacNode* src = links_[i]->facNode();
|
||||
VarNode* dst = links_[i]->varNode();
|
||||
ninf (dst)->addBpLink (links_[i]);
|
||||
ninf (src)->addBpLink (links_[i]);
|
||||
getLinks (dst).push_back (links_[i]);
|
||||
getLinks (src).push_back (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BeliefProp::converged (void)
|
||||
BeliefProp::converged()
|
||||
{
|
||||
if (links_.empty()) {
|
||||
return true;
|
||||
@ -418,16 +517,16 @@ BeliefProp::converged (void)
|
||||
return false;
|
||||
}
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if (nIters_ == 1) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "no residuals" << endl << endl;
|
||||
std::cout << "no residuals" << std::endl << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (schedule_ == MsgSchedule::MAX_RESIDUAL) {
|
||||
if (schedule_ == MsgSchedule::maxResidualSch) {
|
||||
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||
if (maxResidual > accuracy_) {
|
||||
converged = false;
|
||||
@ -438,7 +537,8 @@ BeliefProp::converged (void)
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->residual();
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
std::cout << links_[i]->toString() + " residual = " << residual;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if (residual > accuracy_) {
|
||||
converged = false;
|
||||
@ -448,7 +548,7 @@ BeliefProp::converged (void)
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
@ -457,8 +557,10 @@ BeliefProp::converged (void)
|
||||
|
||||
|
||||
void
|
||||
BeliefProp::printLinkInformation (void) const
|
||||
BeliefProp::printLinkInformation() const
|
||||
{
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
for (size_t i = 0; i < links_.size(); i++) {
|
||||
BpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
@ -470,3 +572,5 @@ BeliefProp::printLinkInformation (void) const
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,111 +1,35 @@
|
||||
#ifndef HORUS_BELIEFPROP_H
|
||||
#define HORUS_BELIEFPROP_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include <sstream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "GroundSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
enum MsgSchedule {
|
||||
SEQ_FIXED,
|
||||
SEQ_RANDOM,
|
||||
PARALLEL,
|
||||
MAX_RESIDUAL
|
||||
};
|
||||
|
||||
|
||||
class BpLink
|
||||
{
|
||||
public:
|
||||
BpLink (FacNode* fn, VarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
virtual ~BpLink (void) { };
|
||||
|
||||
FacNode* facNode (void) const { return fac_; }
|
||||
|
||||
VarNode* varNode (void) const { return var_; }
|
||||
|
||||
const Params& message (void) const { return *currMsg_; }
|
||||
|
||||
Params& nextMessage (void) { return *nextMsg_; }
|
||||
|
||||
double residual (void) const { return residual_; }
|
||||
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = LogAware::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
virtual void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
FacNode* fac_;
|
||||
VarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
double residual_;
|
||||
namespace Horus {
|
||||
|
||||
class BeliefProp : public GroundSolver {
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpLink);
|
||||
};
|
||||
class SPNodeInfo;
|
||||
|
||||
typedef vector<BpLink*> BpLinks;
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
SPNodeInfo (void) { }
|
||||
void addBpLink (BpLink* link) { links_.push_back (link); }
|
||||
const BpLinks& getLinks (void) { return links_; }
|
||||
private:
|
||||
BpLinks links_;
|
||||
DISALLOW_COPY_AND_ASSIGN (SPNodeInfo);
|
||||
};
|
||||
enum class MsgSchedule {
|
||||
seqFixedSch,
|
||||
seqRandomSch,
|
||||
parallelSch,
|
||||
maxResidualSch
|
||||
};
|
||||
|
||||
|
||||
class BeliefProp : public GroundSolver
|
||||
{
|
||||
public:
|
||||
BeliefProp (const FactorGraph&);
|
||||
|
||||
virtual ~BeliefProp (void);
|
||||
virtual ~BeliefProp();
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
virtual void printSolverFlags (void) const;
|
||||
virtual void printSolverFlags() const;
|
||||
|
||||
virtual Params getPosterioriOf (VarId);
|
||||
|
||||
@ -113,105 +37,128 @@ class BeliefProp : public GroundSolver
|
||||
|
||||
Params getFactorJoint (FacNode* fn, const VarIds&);
|
||||
|
||||
static double accuracy (void) { return accuracy_; }
|
||||
static double accuracy() { return accuracy_; }
|
||||
|
||||
static void setAccuracy (double acc) { accuracy_ = acc; }
|
||||
|
||||
static unsigned maxIterations (void) { return maxIter_; }
|
||||
static unsigned maxIterations() { return maxIter_; }
|
||||
|
||||
static void setMaxIterations (unsigned mi) { maxIter_ = mi; }
|
||||
|
||||
static MsgSchedule msgSchedule (void) { return schedule_; }
|
||||
static MsgSchedule msgSchedule() { return schedule_; }
|
||||
|
||||
static void setMsgSchedule (MsgSchedule sch) { schedule_ = sch; }
|
||||
|
||||
protected:
|
||||
SPNodeInfo* ninf (const VarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
}
|
||||
class BpLink {
|
||||
public:
|
||||
BpLink (FacNode* fn, VarNode* vn);
|
||||
|
||||
SPNodeInfo* ninf (const FacNode* fac) const
|
||||
{
|
||||
return facsI_[fac->getIndex()];
|
||||
}
|
||||
virtual ~BpLink() { };
|
||||
|
||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calcFactorToVarMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
FacNode* facNode() const { return fac_; }
|
||||
|
||||
void calculateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calcFactorToVarMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
VarNode* varNode() const { return var_; }
|
||||
|
||||
void updateMessage (BpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
const Params& message() const { return *currMsg_; }
|
||||
|
||||
struct CompareResidual
|
||||
{
|
||||
inline bool operator() (const BpLink* link1, const BpLink* link2)
|
||||
{
|
||||
return link1->residual() > link2->residual();
|
||||
}
|
||||
Params& nextMessage() { return *nextMsg_; }
|
||||
|
||||
double residual() const { return residual_; }
|
||||
|
||||
void clearResidual();
|
||||
|
||||
void updateResidual();
|
||||
|
||||
virtual void updateMessage();
|
||||
|
||||
std::string toString() const;
|
||||
|
||||
protected:
|
||||
FacNode* fac_;
|
||||
VarNode* var_;
|
||||
Params v1_;
|
||||
Params v2_;
|
||||
Params* currMsg_;
|
||||
Params* nextMsg_;
|
||||
double residual_;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpLink);
|
||||
};
|
||||
|
||||
void runSolver (void);
|
||||
struct CmpResidual {
|
||||
bool operator() (const BpLink* l1, const BpLink* l2) {
|
||||
return l1->residual() > l2->residual();
|
||||
}};
|
||||
|
||||
virtual void createLinks (void);
|
||||
typedef std::vector<BeliefProp::BpLink*> BpLinks;
|
||||
typedef std::multiset<BpLink*, CmpResidual> SortedOrder;
|
||||
typedef std::unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||
|
||||
virtual void maxResidualSchedule (void);
|
||||
BpLinks& getLinks (const VarNode* var);
|
||||
|
||||
BpLinks& getLinks (const FacNode* fac);
|
||||
|
||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true);
|
||||
|
||||
void calculateMessage (BpLink* link, bool calcResidual = true);
|
||||
|
||||
void updateMessage (BpLink* link);
|
||||
|
||||
void runSolver();
|
||||
|
||||
virtual void createLinks();
|
||||
|
||||
virtual void maxResidualSchedule();
|
||||
|
||||
virtual void calcFactorToVarMsg (BpLink*);
|
||||
|
||||
virtual Params getVarToFactorMsg (const BpLink*) const;
|
||||
virtual Params getVarToFactorMsg (const BpLink*);
|
||||
|
||||
virtual Params getJointByConditioning (const VarIds&) const;
|
||||
|
||||
BpLinks links_;
|
||||
unsigned nIters_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
bool runned_;
|
||||
BpLinks links_;
|
||||
unsigned nIters_;
|
||||
bool runned_;
|
||||
SortedOrder sortedOrder_;
|
||||
BpLinkMap linkMap_;
|
||||
|
||||
typedef multiset<BpLink*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||
BpLinkMap linkMap_;
|
||||
|
||||
static double accuracy_;
|
||||
static unsigned maxIter_;
|
||||
static MsgSchedule schedule_;
|
||||
static double accuracy_;
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
void initializeSolver();
|
||||
|
||||
bool converged (void);
|
||||
bool converged();
|
||||
|
||||
virtual void printLinkInformation (void) const;
|
||||
virtual void printLinkInformation() const;
|
||||
|
||||
std::vector<BpLinks> varsLinks_;
|
||||
std::vector<BpLinks> facsLinks_;
|
||||
|
||||
static unsigned maxIter_;
|
||||
static MsgSchedule schedule_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (BeliefProp);
|
||||
};
|
||||
|
||||
#endif // HORUS_BELIEFPROP_H
|
||||
|
||||
|
||||
inline BeliefProp::BpLinks&
|
||||
BeliefProp::getLinks (const VarNode* var)
|
||||
{
|
||||
return varsLinks_[var->getIndex()];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline BeliefProp::BpLinks&
|
||||
BeliefProp::getLinks (const FacNode* fac)
|
||||
{
|
||||
return facsLinks_[fac->getIndex()];
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_BELIEFPROP_H_
|
||||
|
||||
|
@ -1,11 +1,88 @@
|
||||
#include <queue>
|
||||
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <fstream>
|
||||
|
||||
#include "ConstraintTree.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class CTNode {
|
||||
public:
|
||||
CTNode (const CTNode& n, const CTChilds& chs = CTChilds())
|
||||
: symbol_(n.symbol()), childs_(chs), level_(n.level()) { }
|
||||
|
||||
CTNode (Symbol s, unsigned l, const CTChilds& chs = CTChilds())
|
||||
: symbol_(s), childs_(chs), level_(l) { }
|
||||
|
||||
unsigned level() const { return level_; }
|
||||
|
||||
void setLevel (unsigned level) { level_ = level; }
|
||||
|
||||
Symbol symbol() const { return symbol_; }
|
||||
|
||||
void setSymbol (Symbol s) { symbol_ = s; }
|
||||
|
||||
CTChilds& childs() { return childs_; }
|
||||
|
||||
const CTChilds& childs() const { return childs_; }
|
||||
|
||||
size_t nrChilds() const { return childs_.size(); }
|
||||
|
||||
bool isRoot() const { return level_ == 0; }
|
||||
|
||||
bool isLeaf() const { return childs_.empty(); }
|
||||
|
||||
CTChilds::iterator findSymbol (Symbol symb);
|
||||
|
||||
void mergeSubtree (CTNode*, bool = true);
|
||||
|
||||
void removeChild (CTNode*);
|
||||
|
||||
void removeChilds();
|
||||
|
||||
void removeAndDeleteChild (CTNode*);
|
||||
|
||||
void removeAndDeleteAllChilds();
|
||||
|
||||
SymbolSet childSymbols() const;
|
||||
|
||||
static CTNode* copySubtree (const CTNode*);
|
||||
|
||||
static void deleteSubtree (CTNode*);
|
||||
|
||||
private:
|
||||
void updateChildLevels (CTNode*, unsigned);
|
||||
|
||||
Symbol symbol_;
|
||||
CTChilds childs_;
|
||||
unsigned level_;
|
||||
|
||||
DISALLOW_ASSIGN (CTNode);
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline CTChilds::iterator
|
||||
CTNode::findSymbol (Symbol symb)
|
||||
{
|
||||
CTNode tmp (symb, 0);
|
||||
return childs_.find (&tmp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
CmpSymbol::operator() (const CTNode* n1, const CTNode* n2) const
|
||||
{
|
||||
return n1->symbol() < n2->symbol();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CTNode::mergeSubtree (CTNode* n, bool updateLevels)
|
||||
{
|
||||
@ -38,7 +115,7 @@ CTNode::removeChild (CTNode* child)
|
||||
|
||||
|
||||
void
|
||||
CTNode::removeChilds (void)
|
||||
CTNode::removeChilds()
|
||||
{
|
||||
childs_.clear();
|
||||
}
|
||||
@ -55,7 +132,7 @@ CTNode::removeAndDeleteChild (CTNode* child)
|
||||
|
||||
|
||||
void
|
||||
CTNode::removeAndDeleteAllChilds (void)
|
||||
CTNode::removeAndDeleteAllChilds()
|
||||
{
|
||||
for (CTChilds::const_iterator chIt = childs_.begin();
|
||||
chIt != childs_.end(); ++ chIt) {
|
||||
@ -67,7 +144,7 @@ CTNode::removeAndDeleteAllChilds (void)
|
||||
|
||||
|
||||
SymbolSet
|
||||
CTNode::childSymbols (void) const
|
||||
CTNode::childSymbols() const
|
||||
{
|
||||
SymbolSet symbols;
|
||||
for (CTChilds::const_iterator chIt = childs_.begin();
|
||||
@ -106,14 +183,14 @@ CTNode::copySubtree (const CTNode* root1)
|
||||
return new CTNode (*root1);
|
||||
}
|
||||
CTNode* root2 = new CTNode (*root1);
|
||||
typedef pair<const CTNode*, CTNode*> StackPair;
|
||||
vector<StackPair> stack = { StackPair (root1, root2) };
|
||||
typedef std::pair<const CTNode*, CTNode*> StackPair;
|
||||
std::vector<StackPair> stack = { StackPair (root1, root2) };
|
||||
while (stack.empty() == false) {
|
||||
const CTNode* n1 = stack.back().first;
|
||||
CTNode* n2 = stack.back().second;
|
||||
stack.pop_back();
|
||||
// cout << "n2 childs: " << n2->childs();
|
||||
// cout << "n1 childs: " << n1->childs();
|
||||
// std::cout << "n2 childs: " << n2->childs();
|
||||
// std::cout << "n1 childs: " << n1->childs();
|
||||
n2->childs().reserve (n1->nrChilds());
|
||||
stack.reserve (n1->nrChilds());
|
||||
for (CTChilds::const_iterator chIt = n1->childs().begin();
|
||||
@ -144,7 +221,8 @@ CTNode::deleteSubtree (CTNode* n)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &out, const CTNode& n)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& out, const CTNode& n)
|
||||
{
|
||||
out << "(" << n.level() << ") " ;
|
||||
out << n.symbol();
|
||||
@ -187,7 +265,8 @@ ConstraintTree::ConstraintTree (
|
||||
|
||||
|
||||
|
||||
ConstraintTree::ConstraintTree (vector<vector<string>> names)
|
||||
ConstraintTree::ConstraintTree (
|
||||
std::vector<std::vector<std::string>> names)
|
||||
{
|
||||
assert (names.empty() == false);
|
||||
assert (names.front().empty() == false);
|
||||
@ -216,13 +295,33 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
||||
|
||||
|
||||
|
||||
ConstraintTree::~ConstraintTree (void)
|
||||
ConstraintTree::ConstraintTree (
|
||||
const CTChilds& rootChilds,
|
||||
const LogVars& logVars)
|
||||
: root_(new CTNode (Symbol (0), unsigned (0), rootChilds)),
|
||||
logVars_(logVars),
|
||||
logVarSet_(logVars)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
ConstraintTree::~ConstraintTree()
|
||||
{
|
||||
CTNode::deleteSubtree (root_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ConstraintTree::empty() const
|
||||
{
|
||||
return root_->childs().empty();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ConstraintTree::addTuple (const Tuple& tuple)
|
||||
{
|
||||
@ -448,7 +547,7 @@ ConstraintTree::ConstraintTree::isSingleton (LogVar X)
|
||||
|
||||
|
||||
LogVarSet
|
||||
ConstraintTree::singletons (void)
|
||||
ConstraintTree::singletons()
|
||||
{
|
||||
LogVarSet singletons;
|
||||
for (size_t i = 0; i < logVars_.size(); i++) {
|
||||
@ -491,7 +590,7 @@ ConstraintTree::tupleSet (const LogVars& originalLvs)
|
||||
getTuples (root_, Tuples(), stopLevel, tuples, CTNodes() = {});
|
||||
|
||||
if (originalLvs.size() != uniqueLvs.size()) {
|
||||
vector<size_t> indexes;
|
||||
std::vector<size_t> indexes;
|
||||
indexes.reserve (originalLvs.size());
|
||||
for (size_t i = 0; i < originalLvs.size(); i++) {
|
||||
indexes.push_back (Util::indexOf (uniqueLvs, originalLvs[i]));
|
||||
@ -519,21 +618,22 @@ ConstraintTree::exportToGraphViz (
|
||||
const char* fileName,
|
||||
bool showLogVars) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
out << "digraph {" << std::endl;
|
||||
ConstraintTree copy (*this);
|
||||
copy.moveToTop (copy.logVarSet_.elements());
|
||||
CTNodes nodes = getNodesBelow (copy.root_);
|
||||
out << "\"" << copy.root_ << "\"" << " [label=\"R\"]" << endl;
|
||||
out << "\"" << copy.root_ << "\"" << " [label=\"R\"]" << std::endl;
|
||||
for (CTNodes::const_iterator it = ++ nodes.begin();
|
||||
it != nodes.end(); ++ it) {
|
||||
out << "\"" << *it << "\"";
|
||||
out << " [label=\"" << **it << "\"]" ;
|
||||
out << endl;
|
||||
out << std::endl;
|
||||
}
|
||||
for (CTNodes::const_iterator it = nodes.begin();
|
||||
it != nodes.end(); ++ it) {
|
||||
@ -542,24 +642,24 @@ ConstraintTree::exportToGraphViz (
|
||||
chIt != childs.end(); ++ chIt) {
|
||||
out << "\"" << *it << "\"" ;
|
||||
out << " -> " ;
|
||||
out << "\"" << *chIt << "\"" << endl ;
|
||||
out << "\"" << *chIt << "\"" << std::endl ;
|
||||
}
|
||||
}
|
||||
if (showLogVars) {
|
||||
out << "Root [label=\"\", shape=plaintext]" << endl;
|
||||
out << "Root [label=\"\", shape=plaintext]" << std::endl;
|
||||
for (size_t i = 0; i < copy.logVars_.size(); i++) {
|
||||
out << copy.logVars_[i] << " [label=" ;
|
||||
out << copy.logVars_[i] << ", " ;
|
||||
out << "shape=plaintext, fontsize=14]" << endl;
|
||||
out << "shape=plaintext, fontsize=14]" << std::endl;
|
||||
}
|
||||
out << "Root -> " << copy.logVars_[0];
|
||||
out << " [style=invis]" << endl;
|
||||
out << " [style=invis]" << std::endl;
|
||||
for (size_t i = 0; i < copy.logVars_.size() - 1; i++) {
|
||||
out << copy.logVars_[i] << " -> " << copy.logVars_[i + 1];
|
||||
out << " [style=invis]" << endl;
|
||||
out << " [style=invis]" << std::endl;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out << "}" <<std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
@ -690,9 +790,9 @@ ConstraintTree::split (
|
||||
split (root_, ct->root(), commChilds, exclChilds, stopLevel);
|
||||
ConstraintTree* commCt = new ConstraintTree (commChilds, logVars_);
|
||||
ConstraintTree* exclCt = new ConstraintTree (exclChilds, logVars_);
|
||||
// cout << commCt->tupleSet() << " + " ;
|
||||
// cout << exclCt->tupleSet() << " = " ;
|
||||
// cout << tupleSet() << endl;
|
||||
// std::cout << commCt->tupleSet() << " + " ;
|
||||
// std::cout << exclCt->tupleSet() << " = " ;
|
||||
// std::cout << tupleSet() << std::endl;
|
||||
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
||||
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
||||
return {commCt, exclCt};
|
||||
@ -710,20 +810,20 @@ ConstraintTree::countNormalize (const LogVarSet& Ys)
|
||||
}
|
||||
moveToTop (Zs.elements());
|
||||
ConstraintTrees cts;
|
||||
unordered_map<unsigned, ConstraintTree*> countMap;
|
||||
std::unordered_map<unsigned, ConstraintTree*> countMap;
|
||||
unsigned stopLevel = getLevel (Zs.back());
|
||||
const CTChilds& childs = root_->childs();
|
||||
|
||||
for (CTChilds::const_iterator chIt = childs.begin();
|
||||
chIt != childs.end(); ++ chIt) {
|
||||
const vector<pair<CTNode*, unsigned>>& res =
|
||||
const std::vector<std::pair<CTNode*, unsigned>>& res =
|
||||
countNormalize (*chIt, stopLevel);
|
||||
for (size_t j = 0; j < res.size(); j++) {
|
||||
unordered_map<unsigned, ConstraintTree*>::iterator it
|
||||
std::unordered_map<unsigned, ConstraintTree*>::iterator it
|
||||
= countMap.find (res[j].second);
|
||||
if (it == countMap.end()) {
|
||||
ConstraintTree* newCt = new ConstraintTree (logVars_);
|
||||
it = countMap.insert (make_pair (res[j].second, newCt)).first;
|
||||
it = countMap.insert (std::make_pair (res[j].second, newCt)).first;
|
||||
cts.push_back (newCt);
|
||||
}
|
||||
it->second->root_->mergeSubtree (res[j].first);
|
||||
@ -743,31 +843,31 @@ ConstraintTree::jointCountNormalize (
|
||||
LogVar X_new2)
|
||||
{
|
||||
unsigned N = getConditionalCount (X);
|
||||
// cout << "My tuples: " << tupleSet() << endl;
|
||||
// cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
||||
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
||||
// cout << "Counted Lv: " << X << endl;
|
||||
// cout << "X_new1: " << X_new1 << endl;
|
||||
// cout << "X_new2: " << X_new2 << endl;
|
||||
// cout << "Original N: " << N << endl;
|
||||
// cout << endl;
|
||||
// std::cout << "My tuples: " << tupleSet() << std::endl;
|
||||
// std::cout << "CommCt tuples: " << commCt->tupleSet() << std::endl;
|
||||
// std::cout << "ExclCt tuples: " << exclCt->tupleSet() << std::endl;
|
||||
// std::cout << "Counted Lv: " << X << std::endl;
|
||||
// std::cout << "X_new1: " << X_new1 << std::endl;
|
||||
// std::cout << "X_new2: " << X_new2 << std::endl;
|
||||
// std::cout << "Original N: " << N << std::endl;
|
||||
// std::cout << endl;
|
||||
|
||||
ConstraintTrees normCts1 = commCt->countNormalize (X);
|
||||
vector<unsigned> counts1 (normCts1.size());
|
||||
std::vector<unsigned> counts1 (normCts1.size());
|
||||
for (size_t i = 0; i < normCts1.size(); i++) {
|
||||
counts1[i] = normCts1[i]->getConditionalCount (X);
|
||||
// cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||
// cout << " " << normCts1[i]->tupleSet() << endl;
|
||||
// std::cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||
// std::cout << " " << normCts1[i]->tupleSet() << std::endl;
|
||||
}
|
||||
|
||||
ConstraintTrees normCts2 = exclCt->countNormalize (X);
|
||||
vector<unsigned> counts2 (normCts2.size());
|
||||
std::vector<unsigned> counts2 (normCts2.size());
|
||||
for (size_t i = 0; i < normCts2.size(); i++) {
|
||||
counts2[i] = normCts2[i]->getConditionalCount (X);
|
||||
// cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||
// cout << " " << normCts2[i]->tupleSet() << endl;
|
||||
// std::cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||
// std::cout << " " << normCts2[i]->tupleSet() << std::endl;
|
||||
}
|
||||
// cout << endl;
|
||||
// std::cout << std::endl;
|
||||
|
||||
ConstraintTree* excl1 = 0;
|
||||
for (size_t i = 0; i < normCts1.size(); i++) {
|
||||
@ -775,7 +875,7 @@ ConstraintTree::jointCountNormalize (
|
||||
excl1 = normCts1[i];
|
||||
normCts1.erase (normCts1.begin() + i);
|
||||
counts1.erase (counts1.begin() + i);
|
||||
// cout << "joint-count(" << N << ",0)" << endl;
|
||||
// std::cout << "joint-count(" << N << ",0)" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -786,7 +886,7 @@ ConstraintTree::jointCountNormalize (
|
||||
excl2 = normCts2[i];
|
||||
normCts2.erase (normCts2.begin() + i);
|
||||
counts2.erase (counts2.begin() + i);
|
||||
// cout << "joint-count(0," << N << ")" << endl;
|
||||
// std::cout << "joint-count(0," << N << ")" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -794,8 +894,8 @@ ConstraintTree::jointCountNormalize (
|
||||
for (size_t i = 0; i < normCts1.size(); i++) {
|
||||
unsigned j;
|
||||
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
|
||||
// cout << "joint-count(" << counts1[i] ;
|
||||
// cout << "," << counts2[j] << ")" << endl;
|
||||
// std::cout << "joint-count(" << counts1[i] ;
|
||||
// std::cout << "," << counts2[j] << ")" << std::endl;
|
||||
const CTChilds& childs = normCts2[j]->root_->childs();
|
||||
for (CTChilds::const_iterator chIt = childs.begin();
|
||||
chIt != childs.end(); ++ chIt) {
|
||||
@ -930,7 +1030,7 @@ CTNodes
|
||||
ConstraintTree::getNodesBelow (CTNode* fromHere) const
|
||||
{
|
||||
CTNodes nodes;
|
||||
queue<CTNode*> queue;
|
||||
std::queue<CTNode*> queue;
|
||||
queue.push (fromHere);
|
||||
while (queue.empty() == false) {
|
||||
CTNode* node = queue.front();
|
||||
@ -1016,7 +1116,7 @@ ConstraintTree::swapLogVar (LogVar X)
|
||||
{
|
||||
size_t pos = Util::indexOf (logVars_, X);
|
||||
assert (pos != logVars_.size());
|
||||
const CTNodes& nodes = getNodesAtLevel (pos);
|
||||
CTNodes nodes = getNodesAtLevel (pos);
|
||||
for (CTNodes::const_iterator nodeIt = nodes.begin();
|
||||
nodeIt != nodes.end(); ++ nodeIt) {
|
||||
CTChilds childsCopy = (*nodeIt)->childs();
|
||||
@ -1098,7 +1198,7 @@ ConstraintTree::getTuples (
|
||||
|
||||
|
||||
unsigned
|
||||
ConstraintTree::size (void) const
|
||||
ConstraintTree::size() const
|
||||
{
|
||||
return countTuples (root_);
|
||||
}
|
||||
@ -1114,26 +1214,26 @@ ConstraintTree::nrSymbols (LogVar X)
|
||||
|
||||
|
||||
|
||||
vector<pair<CTNode*, unsigned>>
|
||||
std::vector<std::pair<CTNode*, unsigned>>
|
||||
ConstraintTree::countNormalize (
|
||||
const CTNode* n,
|
||||
unsigned stopLevel)
|
||||
{
|
||||
if (n->level() == stopLevel) {
|
||||
return vector<pair<CTNode*, unsigned>>() = {
|
||||
make_pair (CTNode::copySubtree (n), countTuples (n))
|
||||
return std::vector<std::pair<CTNode*, unsigned>>() = {
|
||||
std::make_pair (CTNode::copySubtree (n), countTuples (n))
|
||||
};
|
||||
}
|
||||
vector<pair<CTNode*, unsigned>> res;
|
||||
std::vector<std::pair<CTNode*, unsigned>> res;
|
||||
const CTChilds& childs = n->childs();
|
||||
for (CTChilds::const_iterator chIt = childs.begin();
|
||||
chIt != childs.end(); ++ chIt) {
|
||||
const vector<pair<CTNode*, unsigned>>& lowerRes =
|
||||
const std::vector<std::pair<CTNode*, unsigned>>& lowerRes =
|
||||
countNormalize (*chIt, stopLevel);
|
||||
for (size_t j = 0; j < lowerRes.size(); j++) {
|
||||
CTNode* newNode = new CTNode (*n);
|
||||
newNode->mergeSubtree (lowerRes[j].first);
|
||||
res.push_back (make_pair (newNode, lowerRes[j].second));
|
||||
res.push_back (std::make_pair (newNode, lowerRes[j].second));
|
||||
}
|
||||
}
|
||||
return res;
|
||||
@ -1172,3 +1272,5 @@ ConstraintTree::split (
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,104 +1,35 @@
|
||||
#ifndef HORUS_CONSTRAINTTREE_H
|
||||
#define HORUS_CONSTRAINTTREE_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "TinySet.h"
|
||||
#include "LiftedUtils.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class CTNode;
|
||||
typedef vector<CTNode*> CTNodes;
|
||||
|
||||
class ConstraintTree;
|
||||
typedef vector<ConstraintTree*> ConstraintTrees;
|
||||
|
||||
|
||||
class CTNode
|
||||
{
|
||||
public:
|
||||
struct CompareSymbol
|
||||
{
|
||||
bool operator() (const CTNode* n1, const CTNode* n2) const
|
||||
{
|
||||
return n1->symbol() < n2->symbol();
|
||||
}
|
||||
};
|
||||
typedef std::vector<CTNode*> CTNodes;
|
||||
typedef std::vector<ConstraintTree*> ConstraintTrees;
|
||||
|
||||
private:
|
||||
typedef TinySet<CTNode*, CompareSymbol> CTChilds_;
|
||||
|
||||
public:
|
||||
CTNode (const CTNode& n, const CTChilds_& chs = CTChilds_())
|
||||
: symbol_(n.symbol()), childs_(chs), level_(n.level()) { }
|
||||
|
||||
CTNode (Symbol s, unsigned l, const CTChilds_& chs = CTChilds_())
|
||||
: symbol_(s), childs_(chs), level_(l) { }
|
||||
|
||||
unsigned level (void) const { return level_; }
|
||||
|
||||
void setLevel (unsigned level) { level_ = level; }
|
||||
|
||||
Symbol symbol (void) const { return symbol_; }
|
||||
|
||||
void setSymbol (const Symbol s) { symbol_ = s; }
|
||||
|
||||
CTChilds_& childs (void) { return childs_; }
|
||||
|
||||
const CTChilds_& childs (void) const { return childs_; }
|
||||
|
||||
size_t nrChilds (void) const { return childs_.size(); }
|
||||
|
||||
bool isRoot (void) const { return level_ == 0; }
|
||||
|
||||
bool isLeaf (void) const { return childs_.empty(); }
|
||||
|
||||
CTChilds_::iterator findSymbol (Symbol symb)
|
||||
{
|
||||
CTNode tmp (symb, 0);
|
||||
return childs_.find (&tmp);
|
||||
}
|
||||
|
||||
void mergeSubtree (CTNode*, bool = true);
|
||||
|
||||
void removeChild (CTNode*);
|
||||
|
||||
void removeChilds (void);
|
||||
|
||||
void removeAndDeleteChild (CTNode*);
|
||||
|
||||
void removeAndDeleteAllChilds (void);
|
||||
|
||||
SymbolSet childSymbols (void) const;
|
||||
|
||||
static CTNode* copySubtree (const CTNode*);
|
||||
|
||||
static void deleteSubtree (CTNode*);
|
||||
|
||||
private:
|
||||
void updateChildLevels (CTNode*, unsigned);
|
||||
|
||||
Symbol symbol_;
|
||||
CTChilds_ childs_;
|
||||
unsigned level_;
|
||||
|
||||
DISALLOW_ASSIGN (CTNode);
|
||||
struct CmpSymbol {
|
||||
bool operator() (const CTNode* n1, const CTNode* n2) const;
|
||||
};
|
||||
|
||||
ostream& operator<< (ostream &out, const CTNode&);
|
||||
|
||||
typedef TinySet<CTNode*, CmpSymbol> CTChilds;
|
||||
|
||||
|
||||
typedef TinySet<CTNode*, CTNode::CompareSymbol> CTChilds;
|
||||
|
||||
|
||||
class ConstraintTree
|
||||
{
|
||||
class ConstraintTree {
|
||||
public:
|
||||
ConstraintTree (unsigned);
|
||||
|
||||
@ -106,38 +37,23 @@ class ConstraintTree
|
||||
|
||||
ConstraintTree (const LogVars&, const Tuples&);
|
||||
|
||||
ConstraintTree (vector<vector<string>> names);
|
||||
ConstraintTree (std::vector<std::vector<std::string>> names);
|
||||
|
||||
ConstraintTree (const ConstraintTree&);
|
||||
|
||||
ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars)
|
||||
: root_(new CTNode (0, 0, rootChilds)),
|
||||
logVars_(logVars),
|
||||
logVarSet_(logVars) { }
|
||||
ConstraintTree (const CTChilds& rootChilds, const LogVars& logVars);
|
||||
|
||||
~ConstraintTree (void);
|
||||
~ConstraintTree();
|
||||
|
||||
CTNode* root (void) const { return root_; }
|
||||
CTNode* root() const { return root_; }
|
||||
|
||||
bool empty (void) const { return root_->childs().empty(); }
|
||||
bool empty() const;
|
||||
|
||||
const LogVars& logVars (void) const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVars_;
|
||||
}
|
||||
const LogVars& logVars() const;
|
||||
|
||||
const LogVarSet& logVarSet (void) const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVarSet_;
|
||||
}
|
||||
const LogVarSet& logVarSet() const;
|
||||
|
||||
size_t nrLogVars (void) const
|
||||
{
|
||||
return logVars_.size();
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
}
|
||||
size_t nrLogVars() const;
|
||||
|
||||
void addTuple (const Tuple&);
|
||||
|
||||
@ -163,13 +79,13 @@ class ConstraintTree
|
||||
|
||||
bool isSingleton (LogVar);
|
||||
|
||||
LogVarSet singletons (void);
|
||||
LogVarSet singletons();
|
||||
|
||||
TupleSet tupleSet (unsigned = 0) const;
|
||||
|
||||
TupleSet tupleSet (const LogVars&);
|
||||
|
||||
unsigned size (void) const;
|
||||
unsigned size() const;
|
||||
|
||||
unsigned nrSymbols (LogVar);
|
||||
|
||||
@ -218,11 +134,10 @@ class ConstraintTree
|
||||
|
||||
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
|
||||
|
||||
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||
std::vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||
const CTNode*, unsigned);
|
||||
|
||||
static void split (
|
||||
CTNode*, CTNode*, CTChilds&, CTChilds&, unsigned);
|
||||
static void split (CTNode*, CTNode*, CTChilds&, CTChilds&, unsigned);
|
||||
|
||||
CTNode* root_;
|
||||
LogVars logVars_;
|
||||
@ -230,5 +145,33 @@ class ConstraintTree
|
||||
};
|
||||
|
||||
|
||||
#endif // HORUS_CONSTRAINTTREE_H
|
||||
|
||||
inline const LogVars&
|
||||
ConstraintTree::logVars() const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVars_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline const LogVarSet&
|
||||
ConstraintTree::logVarSet() const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVarSet_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline size_t
|
||||
ConstraintTree::nrLogVars() const
|
||||
{
|
||||
assert (LogVarSet (logVars_) == logVarSet_);
|
||||
return logVars_.size();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_CONSTRAINTTREE_H_
|
||||
|
||||
|
@ -1,7 +1,62 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "CountingBp.h"
|
||||
#include "WeightedBp.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class VarCluster {
|
||||
public:
|
||||
VarCluster (const VarNodes& vs) : members_(vs) { }
|
||||
|
||||
const VarNode* first() const { return members_.front(); }
|
||||
|
||||
const VarNodes& members() const { return members_; }
|
||||
|
||||
VarNode* representative() const { return repr_; }
|
||||
|
||||
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
||||
|
||||
private:
|
||||
VarNodes members_;
|
||||
VarNode* repr_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (VarCluster);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class FacCluster {
|
||||
private:
|
||||
typedef std::vector<VarCluster*> VarClusters;
|
||||
|
||||
public:
|
||||
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
||||
: members_(fcs), varClusters_(vcs) { }
|
||||
|
||||
const FacNode* first() const { return members_.front(); }
|
||||
|
||||
const FacNodes& members() const { return members_; }
|
||||
|
||||
FacNode* representative() const { return repr_; }
|
||||
|
||||
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
||||
|
||||
VarClusters& varClusters() { return varClusters_; }
|
||||
|
||||
FacNodes members_;
|
||||
FacNode* repr_;
|
||||
VarClusters varClusters_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (FacCluster);
|
||||
};
|
||||
|
||||
|
||||
|
||||
bool CountingBp::fif_ = true;
|
||||
|
||||
|
||||
@ -17,7 +72,7 @@ CountingBp::CountingBp (const FactorGraph& fg)
|
||||
|
||||
|
||||
|
||||
CountingBp::~CountingBp (void)
|
||||
CountingBp::~CountingBp()
|
||||
{
|
||||
delete solver_;
|
||||
delete compressedFg_;
|
||||
@ -32,23 +87,24 @@ CountingBp::~CountingBp (void)
|
||||
|
||||
|
||||
void
|
||||
CountingBp::printSolverFlags (void) const
|
||||
CountingBp::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "counting bp [" ;
|
||||
ss << "bp_msg_schedule=" ;
|
||||
typedef WeightedBp::MsgSchedule MsgSchedule;
|
||||
switch (WeightedBp::msgSchedule()) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << ",fif=" << Util::toString (CountingBp::fif_);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -69,11 +125,10 @@ CountingBp::solveQuery (VarIds queryVids)
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
if (idx == facNodes.size()) {
|
||||
res = GroundSolver::getJointByConditioning (
|
||||
GroundSolverType::CBP, fg, queryVids);
|
||||
GroundSolverType::CbpSolver, fg, queryVids);
|
||||
} else {
|
||||
VarIds reprArgs;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
@ -124,7 +179,7 @@ CountingBp::findIdenticalFactors()
|
||||
|
||||
|
||||
void
|
||||
CountingBp::setInitialColors (void)
|
||||
CountingBp::setInitialColors()
|
||||
{
|
||||
varColors_.resize (fg.nrVarNodes());
|
||||
facColors_.resize (fg.nrFacNodes());
|
||||
@ -135,7 +190,7 @@ CountingBp::setInitialColors (void)
|
||||
unsigned range = varNodes[i]->range();
|
||||
VarColorMap::iterator it = colorMap.find (range);
|
||||
if (it == colorMap.end()) {
|
||||
it = colorMap.insert (make_pair (
|
||||
it = colorMap.insert (std::make_pair (
|
||||
range, Colors (range + 1, -1))).first;
|
||||
}
|
||||
unsigned idx = varNodes[i]->hasEvidence()
|
||||
@ -154,7 +209,8 @@ CountingBp::setInitialColors (void)
|
||||
unsigned distId = facNodes[i]->factor().distId();
|
||||
DistColorMap::iterator it = distColors.find (distId);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (distId, getNewColor())).first;
|
||||
it = distColors.insert (std::make_pair (
|
||||
distId, getNewColor())).first;
|
||||
}
|
||||
setColor (facNodes[i], it->second);
|
||||
}
|
||||
@ -163,7 +219,7 @@ CountingBp::setInitialColors (void)
|
||||
|
||||
|
||||
void
|
||||
CountingBp::createGroups (void)
|
||||
CountingBp::createGroups()
|
||||
{
|
||||
VarSignMap varGroups;
|
||||
FacSignMap facGroups;
|
||||
@ -179,10 +235,11 @@ CountingBp::createGroups (void)
|
||||
size_t prevVarGroupsSize = varGroups.size();
|
||||
varGroups.clear();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
const VarSignature& signature = getSignature (varNodes[i]);
|
||||
VarSignature signature = getSignature (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signature);
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||
it = varGroups.insert (std::make_pair (
|
||||
signature, VarNodes())).first;
|
||||
}
|
||||
it->second.push_back (varNodes[i]);
|
||||
}
|
||||
@ -199,10 +256,11 @@ CountingBp::createGroups (void)
|
||||
facGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
const FacSignature& signature = getSignature (facNodes[i]);
|
||||
FacSignature signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = facGroups.find (signature);
|
||||
if (it == facGroups.end()) {
|
||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||
it = facGroups.insert (std::make_pair (
|
||||
signature, FacNodes())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
@ -235,7 +293,8 @@ CountingBp::createClusters (
|
||||
const VarNodes& groupVars = it->second;
|
||||
VarCluster* vc = new VarCluster (groupVars);
|
||||
for (size_t i = 0; i < groupVars.size(); i++) {
|
||||
varClusterMap_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||
varClusterMap_.insert (std::make_pair (
|
||||
groupVars[i]->varId(), vc));
|
||||
}
|
||||
varClusters_.push_back (vc);
|
||||
}
|
||||
@ -257,29 +316,29 @@ CountingBp::createClusters (
|
||||
|
||||
|
||||
|
||||
VarSignature
|
||||
CountingBp::VarSignature
|
||||
CountingBp::getSignature (const VarNode* varNode)
|
||||
{
|
||||
const FacNodes& neighs = varNode->neighbors();
|
||||
VarSignature sign;
|
||||
const FacNodes& neighs = varNode->neighbors();
|
||||
sign.reserve (neighs.size() + 1);
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
sign.push_back (make_pair (
|
||||
sign.push_back (std::make_pair (
|
||||
getColor (neighs[i]),
|
||||
neighs[i]->factor().indexOf (varNode->varId())));
|
||||
}
|
||||
std::sort (sign.begin(), sign.end());
|
||||
sign.push_back (make_pair (getColor (varNode), 0));
|
||||
sign.push_back (std::make_pair (getColor (varNode), 0));
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FacSignature
|
||||
CountingBp::FacSignature
|
||||
CountingBp::getSignature (const FacNode* facNode)
|
||||
{
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
FacSignature sign;
|
||||
const VarNodes& neighs = facNode->neighbors();
|
||||
sign.reserve (neighs.size() + 1);
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
sign.push_back (getColor (neighs[i]));
|
||||
@ -314,7 +373,7 @@ CountingBp::getRepresentative (FacNode* fn)
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CountingBp::getCompressedFactorGraph (void)
|
||||
CountingBp::getCompressedFactorGraph()
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||
@ -342,10 +401,10 @@ CountingBp::getCompressedFactorGraph (void)
|
||||
|
||||
|
||||
|
||||
vector<vector<unsigned>>
|
||||
CountingBp::getWeights (void) const
|
||||
std::vector<std::vector<unsigned>>
|
||||
CountingBp::getWeights() const
|
||||
{
|
||||
vector<vector<unsigned>> weights;
|
||||
std::vector<std::vector<unsigned>> weights;
|
||||
weights.reserve (facClusters_.size());
|
||||
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||
const VarClusters& neighs = facClusters_[i]->varClusters();
|
||||
@ -390,32 +449,34 @@ CountingBp::printGroups (
|
||||
const FacSignMap& facGroups) const
|
||||
{
|
||||
unsigned count = 1;
|
||||
cout << "variable groups:" << endl;
|
||||
std::cout << "variable groups:" << std::endl;
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); ++it) {
|
||||
const VarNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << count << ": " ;
|
||||
std::cout << count << ": " ;
|
||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->label() << " " ;
|
||||
std::cout << groupMembers[i]->label() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
count = 1;
|
||||
cout << endl << "factor groups:" << endl;
|
||||
std::cout << std::endl << "factor groups:" << std::endl;
|
||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||
it != facGroups.end(); ++it) {
|
||||
const FacNodes& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
std::cout << ++count << ": " ;
|
||||
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->getLabel() << " " ;
|
||||
std::cout << groupMembers[i]->getLabel() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,155 +1,99 @@
|
||||
#ifndef HORUS_COUNTINGBP_H
|
||||
#define HORUS_COUNTINGBP_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "GroundSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class VarCluster;
|
||||
class FacCluster;
|
||||
class WeightedBp;
|
||||
|
||||
typedef long Color;
|
||||
typedef vector<Color> Colors;
|
||||
typedef vector<std::pair<Color,unsigned>> VarSignature;
|
||||
typedef vector<Color> FacSignature;
|
||||
|
||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||
typedef unordered_map<unsigned, Colors> VarColorMap;
|
||||
|
||||
typedef unordered_map<VarSignature, VarNodes> VarSignMap;
|
||||
typedef unordered_map<FacSignature, FacNodes> FacSignMap;
|
||||
|
||||
typedef unordered_map<VarId, VarCluster*> VarClusterMap;
|
||||
|
||||
typedef vector<VarCluster*> VarClusters;
|
||||
typedef vector<FacCluster*> FacClusters;
|
||||
|
||||
template <class T>
|
||||
inline size_t hash_combine (size_t seed, const T& v)
|
||||
template <class T> inline size_t
|
||||
hash_combine (size_t seed, const T& v)
|
||||
{
|
||||
return seed ^ (hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
|
||||
return seed ^ (std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
||||
namespace std {
|
||||
template <typename T1, typename T2> struct hash<std::pair<T1,T2>>
|
||||
{
|
||||
size_t operator() (const std::pair<T1,T2>& p) const
|
||||
{
|
||||
return hash_combine (std::hash<T1>()(p.first), p.second);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct hash<std::vector<T>>
|
||||
{
|
||||
size_t operator() (const std::vector<T>& vec) const
|
||||
{
|
||||
size_t h = 0;
|
||||
typename vector<T>::const_iterator first = vec.begin();
|
||||
typename vector<T>::const_iterator last = vec.end();
|
||||
for (; first != last; ++first) {
|
||||
h = hash_combine (h, *first);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
}
|
||||
template <typename T1, typename T2> struct hash<std::pair<T1,T2>> {
|
||||
size_t operator() (const std::pair<T1,T2>& p) const {
|
||||
return Horus::hash_combine (std::hash<T1>()(p.first), p.second);
|
||||
}};
|
||||
|
||||
|
||||
class VarCluster
|
||||
template <typename T> struct hash<std::vector<T>>
|
||||
{
|
||||
public:
|
||||
VarCluster (const VarNodes& vs) : members_(vs) { }
|
||||
|
||||
const VarNode* first (void) const { return members_.front(); }
|
||||
|
||||
const VarNodes& members (void) const { return members_; }
|
||||
|
||||
VarNode* representative (void) const { return repr_; }
|
||||
|
||||
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
||||
|
||||
private:
|
||||
VarNodes members_;
|
||||
VarNode* repr_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (VarCluster);
|
||||
size_t operator() (const std::vector<T>& vec) const
|
||||
{
|
||||
size_t h = 0;
|
||||
typename std::vector<T>::const_iterator first = vec.begin();
|
||||
typename std::vector<T>::const_iterator last = vec.end();
|
||||
for (; first != last; ++first) {
|
||||
h = Horus::hash_combine (h, *first);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class FacCluster
|
||||
{
|
||||
public:
|
||||
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
||||
: members_(fcs), varClusters_(vcs) { }
|
||||
|
||||
const FacNode* first (void) const { return members_.front(); }
|
||||
|
||||
const FacNodes& members (void) const { return members_; }
|
||||
|
||||
FacNode* representative (void) const { return repr_; }
|
||||
|
||||
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
||||
|
||||
VarClusters& varClusters (void) { return varClusters_; }
|
||||
|
||||
private:
|
||||
FacNodes members_;
|
||||
FacNode* repr_;
|
||||
VarClusters varClusters_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (FacCluster);
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
|
||||
class CountingBp : public GroundSolver
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class CountingBp : public GroundSolver {
|
||||
public:
|
||||
CountingBp (const FactorGraph& fg);
|
||||
|
||||
~CountingBp (void);
|
||||
~CountingBp();
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
void printSolverFlags() const;
|
||||
|
||||
Params solveQuery (VarIds);
|
||||
|
||||
static void setFindIdenticalFactorsFlag (bool fif) { fif_ = fif; }
|
||||
|
||||
private:
|
||||
Color getNewColor (void)
|
||||
{
|
||||
++ freeColor_;
|
||||
return freeColor_ - 1;
|
||||
}
|
||||
typedef long Color;
|
||||
typedef std::vector<Color> Colors;
|
||||
|
||||
Color getColor (const VarNode* vn) const
|
||||
{
|
||||
return varColors_[vn->getIndex()];
|
||||
}
|
||||
typedef std::vector<std::pair<Color,unsigned>> VarSignature;
|
||||
typedef std::vector<Color> FacSignature;
|
||||
|
||||
Color getColor (const FacNode* fn) const
|
||||
{
|
||||
return facColors_[fn->getIndex()];
|
||||
}
|
||||
typedef std::vector<VarCluster*> VarClusters;
|
||||
typedef std::vector<FacCluster*> FacClusters;
|
||||
|
||||
void setColor (const VarNode* vn, Color c)
|
||||
{
|
||||
varColors_[vn->getIndex()] = c;
|
||||
}
|
||||
typedef std::unordered_map<unsigned, Color> DistColorMap;
|
||||
typedef std::unordered_map<unsigned, Colors> VarColorMap;
|
||||
typedef std::unordered_map<VarSignature, VarNodes> VarSignMap;
|
||||
typedef std::unordered_map<FacSignature, FacNodes> FacSignMap;
|
||||
typedef std::unordered_map<VarId, VarCluster*> VarClusterMap;
|
||||
|
||||
void setColor (const FacNode* fn, Color c)
|
||||
{
|
||||
facColors_[fn->getIndex()] = c;
|
||||
}
|
||||
Color getNewColor();
|
||||
|
||||
void findIdenticalFactors (void);
|
||||
Color getColor (const VarNode* vn) const;
|
||||
|
||||
void setInitialColors (void);
|
||||
Color getColor (const FacNode* fn) const;
|
||||
|
||||
void createGroups (void);
|
||||
void setColor (const VarNode* vn, Color c);
|
||||
|
||||
void setColor (const FacNode* fn, Color c);
|
||||
|
||||
void findIdenticalFactors();
|
||||
|
||||
void setInitialColors();
|
||||
|
||||
void createGroups();
|
||||
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
|
||||
@ -163,12 +107,12 @@ class CountingBp : public GroundSolver
|
||||
|
||||
FacNode* getRepresentative (FacNode*);
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
FactorGraph* getCompressedFactorGraph();
|
||||
|
||||
vector<vector<unsigned>> getWeights (void) const;
|
||||
std::vector<std::vector<unsigned>> getWeights() const;
|
||||
|
||||
unsigned getWeight (const FacCluster*,
|
||||
const VarCluster*, size_t index) const;
|
||||
unsigned getWeight (const FacCluster*, const VarCluster*,
|
||||
size_t index) const;
|
||||
|
||||
Color freeColor_;
|
||||
Colors varColors_;
|
||||
@ -184,5 +128,48 @@ class CountingBp : public GroundSolver
|
||||
DISALLOW_COPY_AND_ASSIGN (CountingBp);
|
||||
};
|
||||
|
||||
#endif // HORUS_COUNTINGBP_H
|
||||
|
||||
|
||||
inline CountingBp::Color
|
||||
CountingBp::getNewColor()
|
||||
{
|
||||
++ freeColor_;
|
||||
return freeColor_ - 1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline CountingBp::Color
|
||||
CountingBp::getColor (const VarNode* vn) const
|
||||
{
|
||||
return varColors_[vn->getIndex()];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline CountingBp::Color
|
||||
CountingBp::getColor (const FacNode* fn) const
|
||||
{
|
||||
return facColors_[fn->getIndex()];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
CountingBp::setColor (const VarNode* vn, CountingBp::Color c)
|
||||
{
|
||||
varColors_[vn->getIndex()] = c;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
CountingBp::setColor (const FacNode* fn, Color c)
|
||||
{
|
||||
facColors_[fn->getIndex()] = c;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_COUNTINGBP_H_
|
||||
|
||||
|
@ -1,25 +1,30 @@
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
#include "ElimGraph.h"
|
||||
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
ElimGraph::ElimHeuristic ElimGraph::elimHeuristic_ =
|
||||
ElimHeuristic::minNeighborsEh;
|
||||
|
||||
|
||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
ElimGraph::ElimGraph (const std::vector<Factor*>& factors)
|
||||
{
|
||||
for (size_t i = 0; i < factors.size(); i++) {
|
||||
if (factors[i]) {
|
||||
const VarIds& args = factors[i]->arguments();
|
||||
for (size_t j = 0; j < args.size() - 1; j++) {
|
||||
EgNode* n1 = getEgNode (args[j]);
|
||||
EGNode* n1 = getEGNode (args[j]);
|
||||
if (!n1) {
|
||||
n1 = new EgNode (args[j], factors[i]->range (j));
|
||||
n1 = new EGNode (args[j], factors[i]->range (j));
|
||||
addNode (n1);
|
||||
}
|
||||
for (size_t k = j + 1; k < args.size(); k++) {
|
||||
EgNode* n2 = getEgNode (args[k]);
|
||||
EGNode* n2 = getEGNode (args[k]);
|
||||
if (!n2) {
|
||||
n2 = new EgNode (args[k], factors[i]->range (k));
|
||||
n2 = new EGNode (args[k], factors[i]->range (k));
|
||||
addNode (n2);
|
||||
}
|
||||
if (!neighbors (n1, n2)) {
|
||||
@ -27,8 +32,8 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (args.size() == 1 && !getEgNode (args[0])) {
|
||||
addNode (new EgNode (args[0], factors[i]->range (0)));
|
||||
if (args.size() == 1 && !getEGNode (args[0])) {
|
||||
addNode (new EGNode (args[0], factors[i]->range (0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -36,7 +41,7 @@ ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||
|
||||
|
||||
|
||||
ElimGraph::~ElimGraph (void)
|
||||
ElimGraph::~ElimGraph()
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
delete nodes_[i];
|
||||
@ -57,7 +62,7 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
|
||||
}
|
||||
size_t nrVarsToEliminate = nodes_.size() - excludedVids.size();
|
||||
for (size_t i = 0; i < nrVarsToEliminate; i++) {
|
||||
EgNode* node = getLowestCostNode();
|
||||
EGNode* node = getLowestCostNode();
|
||||
unmarked_.remove (node);
|
||||
const EGNeighs& neighs = node->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
@ -72,15 +77,15 @@ ElimGraph::getEliminatingOrder (const VarIds& excludedVids)
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::print (void) const
|
||||
ElimGraph::print() const
|
||||
{
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
std::cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
EGNeighs neighs = nodes_[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
std::cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,25 +97,27 @@ ElimGraph::exportToGraphViz (
|
||||
bool showNeighborless,
|
||||
const VarIds& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "strict graph {" << endl;
|
||||
out << "strict graph {" << std::endl;
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().empty() == false) {
|
||||
out << '"' << nodes_[i]->label() << '"' << endl;
|
||||
out << '"' << nodes_[i]->label() << '"' << std::endl;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
EGNode* node =getEGNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
out << " [shape=box3d]" << std::endl;
|
||||
} else {
|
||||
cerr << "Error: invalid variable id: " << highlightVarIds[i] << "." ;
|
||||
cerr << endl;
|
||||
std::cerr << "Error: invalid variable id: " ;
|
||||
std::cerr << highlightVarIds[i] << "." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
@ -118,10 +125,10 @@ ElimGraph::exportToGraphViz (
|
||||
EGNeighs neighs = nodes_[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
out << '"' << neighs[j]->label() << '"' << std::endl;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out << "}" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
@ -132,7 +139,7 @@ ElimGraph::getEliminationOrder (
|
||||
const Factors& factors,
|
||||
VarIds excludedVids)
|
||||
{
|
||||
if (elimHeuristic_ == ElimHeuristic::SEQUENTIAL) {
|
||||
if (elimHeuristic_ == ElimHeuristic::sequentialEh) {
|
||||
VarIds allVids;
|
||||
Factors::const_iterator first = factors.begin();
|
||||
Factors::const_iterator end = factors.end();
|
||||
@ -150,33 +157,33 @@ ElimGraph::getEliminationOrder (
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
ElimGraph::addNode (EGNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
n->setIndex (nodes_.size() - 1);
|
||||
varMap_.insert (make_pair (n->varId(), n));
|
||||
varMap_.insert (std::make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
ElimGraph::EGNode*
|
||||
ElimGraph::getEGNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId, EgNode*>::const_iterator it;
|
||||
std::unordered_map<VarId, EGNode*>::const_iterator it;
|
||||
it = varMap_.find (vid);
|
||||
return (it != varMap_.end()) ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getLowestCostNode (void) const
|
||||
ElimGraph::EGNode*
|
||||
ElimGraph::getLowestCostNode() const
|
||||
{
|
||||
EgNode* bestNode = 0;
|
||||
EGNode* bestNode = 0;
|
||||
unsigned minCost = Util::maxUnsigned();
|
||||
EGNeighs::const_iterator it;
|
||||
switch (elimHeuristic_) {
|
||||
case MIN_NEIGHBORS: {
|
||||
case ElimHeuristic::minNeighborsEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getNeighborsCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -185,7 +192,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
}
|
||||
}}
|
||||
break;
|
||||
case MIN_WEIGHT: {
|
||||
case ElimHeuristic::minWeightEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getWeightCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -194,7 +201,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
}
|
||||
}}
|
||||
break;
|
||||
case MIN_FILL: {
|
||||
case ElimHeuristic::minFillEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getFillCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -203,7 +210,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
}
|
||||
}}
|
||||
break;
|
||||
case WEIGHTED_MIN_FILL: {
|
||||
case ElimHeuristic::weightedMinFillEh: {
|
||||
for (it = unmarked_.begin(); it != unmarked_.end(); ++ it) {
|
||||
unsigned cost = getWeightedFillCost (*it);
|
||||
if (cost < minCost) {
|
||||
@ -222,7 +229,7 @@ ElimGraph::getLowestCostNode (void) const
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::connectAllNeighbors (const EgNode* n)
|
||||
ElimGraph::connectAllNeighbors (const EGNode* n)
|
||||
{
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
@ -236,3 +243,5 @@ ElimGraph::connectAllNeighbors (const EgNode* n)
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,143 +1,177 @@
|
||||
#ifndef HORUS_ELIMGRAPH_H
|
||||
#define HORUS_ELIMGRAPH_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
|
||||
|
||||
#include "unordered_map"
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "TinySet.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
enum ElimHeuristic
|
||||
{
|
||||
SEQUENTIAL,
|
||||
MIN_NEIGHBORS,
|
||||
MIN_WEIGHT,
|
||||
MIN_FILL,
|
||||
WEIGHTED_MIN_FILL
|
||||
};
|
||||
namespace Horus {
|
||||
|
||||
|
||||
class EgNode;
|
||||
|
||||
typedef TinySet<EgNode*> EGNeighs;
|
||||
|
||||
|
||||
class EgNode : public Var
|
||||
{
|
||||
class ElimGraph {
|
||||
public:
|
||||
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||
enum class ElimHeuristic {
|
||||
sequentialEh,
|
||||
minNeighborsEh,
|
||||
minWeightEh,
|
||||
minFillEh,
|
||||
weightedMinFillEh
|
||||
};
|
||||
|
||||
void addNeighbor (EgNode* n) { neighs_.insert (n); }
|
||||
|
||||
void removeNeighbor (EgNode* n) { neighs_.remove (n); }
|
||||
|
||||
bool isNeighbor (EgNode* n) const { return neighs_.contains (n); }
|
||||
|
||||
const EGNeighs& neighbors (void) const { return neighs_; }
|
||||
|
||||
private:
|
||||
EGNeighs neighs_;
|
||||
};
|
||||
|
||||
|
||||
class ElimGraph
|
||||
{
|
||||
public:
|
||||
ElimGraph (const Factors&);
|
||||
|
||||
~ElimGraph (void);
|
||||
~ElimGraph();
|
||||
|
||||
VarIds getEliminatingOrder (const VarIds&);
|
||||
|
||||
void print (void) const;
|
||||
void print() const;
|
||||
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIds& = VarIds()) const;
|
||||
|
||||
static VarIds getEliminationOrder (const Factors&, VarIds);
|
||||
|
||||
static ElimHeuristic elimHeuristic (void) { return elimHeuristic_; }
|
||||
static ElimHeuristic elimHeuristic() { return elimHeuristic_; }
|
||||
|
||||
static void setElimHeuristic (ElimHeuristic eh) { elimHeuristic_ = eh; }
|
||||
|
||||
private:
|
||||
void addEdge (EgNode* n1, EgNode* n2)
|
||||
{
|
||||
assert (n1 != n2);
|
||||
n1->addNeighbor (n2);
|
||||
n2->addNeighbor (n1);
|
||||
}
|
||||
class EGNode;
|
||||
|
||||
unsigned getNeighborsCost (const EgNode* n) const
|
||||
{
|
||||
return n->neighbors().size();
|
||||
}
|
||||
typedef TinySet<EGNode*> EGNeighs;
|
||||
|
||||
unsigned getWeightCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 1;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
cost *= neighs[i]->range();
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
class EGNode : public Var {
|
||||
public:
|
||||
EGNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||
|
||||
unsigned getFillCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||
cost ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
void addNeighbor (EGNode* n) { neighs_.insert (n); }
|
||||
|
||||
unsigned getWeightedFillCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||
cost += neighs[i]->range() * neighs[j]->range();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
void removeNeighbor (EGNode* n) { neighs_.remove (n); }
|
||||
|
||||
bool neighbors (EgNode* n1, EgNode* n2) const
|
||||
{
|
||||
return n1->isNeighbor (n2);
|
||||
}
|
||||
bool isNeighbor (EGNode* n) const { return neighs_.contains (n); }
|
||||
|
||||
void addNode (EgNode*);
|
||||
const EGNeighs& neighbors() const { return neighs_; }
|
||||
|
||||
EgNode* getEgNode (VarId) const;
|
||||
private:
|
||||
EGNeighs neighs_;
|
||||
};
|
||||
|
||||
EgNode* getLowestCostNode (void) const;
|
||||
void addEdge (EGNode* n1, EGNode* n2);
|
||||
|
||||
void connectAllNeighbors (const EgNode*);
|
||||
unsigned getNeighborsCost (const EGNode* n) const;
|
||||
|
||||
vector<EgNode*> nodes_;
|
||||
TinySet<EgNode*> unmarked_;
|
||||
unordered_map<VarId, EgNode*> varMap_;
|
||||
unsigned getWeightCost (const EGNode* n) const;
|
||||
|
||||
unsigned getFillCost (const EGNode* n) const;
|
||||
|
||||
unsigned getWeightedFillCost (const EGNode* n) const;
|
||||
|
||||
bool neighbors (EGNode* n1, EGNode* n2) const;
|
||||
|
||||
void addNode (EGNode*);
|
||||
|
||||
EGNode* getEGNode (VarId) const;
|
||||
|
||||
EGNode* getLowestCostNode() const;
|
||||
|
||||
void connectAllNeighbors (const EGNode*);
|
||||
|
||||
std::vector<EGNode*> nodes_;
|
||||
EGNeighs unmarked_;
|
||||
std::unordered_map<VarId, EGNode*> varMap_;
|
||||
|
||||
static ElimHeuristic elimHeuristic_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (ElimGraph);
|
||||
};
|
||||
|
||||
#endif // HORUS_ELIMGRAPH_H
|
||||
|
||||
|
||||
/* Profiling shows that we should inline the following functions */
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
ElimGraph::addEdge (EGNode* n1, EGNode* n2)
|
||||
{
|
||||
assert (n1 != n2);
|
||||
n1->addNeighbor (n2);
|
||||
n2->addNeighbor (n1);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
ElimGraph::getNeighborsCost (const EGNode* n) const
|
||||
{
|
||||
return n->neighbors().size();
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
ElimGraph::getWeightCost (const EGNode* n) const
|
||||
{
|
||||
unsigned cost = 1;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
for (size_t i = 0; i < neighs.size(); i++) {
|
||||
cost *= neighs[i]->range();
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
ElimGraph::getFillCost (const EGNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||
cost ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
ElimGraph::getWeightedFillCost (const EGNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const EGNeighs& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (size_t i = 0; i < neighs.size() - 1; i++) {
|
||||
for (size_t j = i + 1; j < neighs.size(); j++) {
|
||||
if ( ! neighbors (neighs[i], neighs[j])) {
|
||||
cost += neighs[i]->range() * neighs[j]->range();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
ElimGraph::neighbors (EGNode* n1, EGNode* n2) const
|
||||
{
|
||||
return n1->isNeighbor (n2);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_ELIMGRAPH_H_
|
||||
|
||||
|
@ -1,21 +1,15 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "Indexer.h"
|
||||
#include "Var.h"
|
||||
|
||||
|
||||
Factor::Factor (const Factor& g)
|
||||
{
|
||||
clone (g);
|
||||
}
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
Factor::Factor (
|
||||
const VarIds& vids,
|
||||
@ -77,7 +71,7 @@ Factor::sumOutAllExcept (VarId vid)
|
||||
void
|
||||
Factor::sumOutAllExcept (const VarIds& vids)
|
||||
{
|
||||
vector<bool> mask (args_.size(), false);
|
||||
std::vector<bool> mask (args_.size(), false);
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
assert (indexOf (vids[i]) != args_.size());
|
||||
mask[indexOf (vids[i])] = true;
|
||||
@ -91,28 +85,30 @@ void
|
||||
Factor::sumOutAllExceptIndex (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
vector<bool> mask (args_.size(), false);
|
||||
std::vector<bool> mask (args_.size(), false);
|
||||
mask[idx] = true;
|
||||
sumOutArgs (mask);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
Factor::multiply (Factor& g)
|
||||
|
||||
Factor&
|
||||
Factor::multiply (const Factor& g)
|
||||
{
|
||||
if (args_.empty()) {
|
||||
clone (g);
|
||||
operator= (g);
|
||||
} else {
|
||||
TFactor<VarId>::multiply (g);
|
||||
GenericFactor<VarId>::multiply (g);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Factor::getLabel (void) const
|
||||
std::string
|
||||
Factor::getLabel() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "f(" ;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
@ -125,19 +121,19 @@ Factor::getLabel (void) const
|
||||
|
||||
|
||||
void
|
||||
Factor::print (void) const
|
||||
Factor::print() const
|
||||
{
|
||||
Vars vars;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||
}
|
||||
vector<string> jointStrings = Util::getStateLines (vars);
|
||||
std::vector<std::string> jointStrings = Util::getStateLines (vars);
|
||||
for (size_t i = 0; i < params_.size(); i++) {
|
||||
// cout << "[" << distId_ << "] " ;
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << params_[i] << endl;
|
||||
std::cout << "f(" << jointStrings[i] << ")" ;
|
||||
std::cout << " = " << params_[i] << std::endl;
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
for (size_t i = 0; i < vars.size(); i++) {
|
||||
delete vars[i];
|
||||
}
|
||||
@ -146,8 +142,9 @@ Factor::print (void) const
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutFirstVariable (void)
|
||||
Factor::sumOutFirstVariable()
|
||||
{
|
||||
assert (ranges_.front() == 2);
|
||||
size_t sep = params_.size() / 2;
|
||||
if (Globals::logDomain) {
|
||||
std::transform (
|
||||
@ -169,19 +166,21 @@ Factor::sumOutFirstVariable (void)
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutLastVariable (void)
|
||||
Factor::sumOutLastVariable()
|
||||
{
|
||||
assert (ranges_.back() == 2);
|
||||
Params::iterator first1 = params_.begin();
|
||||
Params::iterator first2 = params_.begin();
|
||||
Params::iterator last = params_.end();
|
||||
if (Globals::logDomain) {
|
||||
while (first2 != last) {
|
||||
// the arguments can be swaped, but that is ok
|
||||
*first1++ = Util::logSum (*first2++, *first2++);
|
||||
double tmp = *first2++;
|
||||
*first1++ = Util::logSum (tmp, *first2++);
|
||||
}
|
||||
} else {
|
||||
while (first2 != last) {
|
||||
*first1++ = (*first2++) + (*first2++);
|
||||
*first1 = *first2++;
|
||||
*first1++ += *first2++;
|
||||
}
|
||||
}
|
||||
params_.resize (params_.size() / 2);
|
||||
@ -192,7 +191,7 @@ Factor::sumOutLastVariable (void)
|
||||
|
||||
|
||||
void
|
||||
Factor::sumOutArgs (const vector<bool>& mask)
|
||||
Factor::sumOutArgs (const std::vector<bool>& mask)
|
||||
{
|
||||
assert (mask.size() == args_.size());
|
||||
size_t new_size = 1;
|
||||
@ -224,14 +223,5 @@ Factor::sumOutArgs (const vector<bool>& mask)
|
||||
params_ = newps;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::clone (const Factor& g)
|
||||
{
|
||||
args_ = g.arguments();
|
||||
ranges_ = g.ranges();
|
||||
params_ = g.params();
|
||||
distId_ = g.distId();
|
||||
}
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,262 +1,20 @@
|
||||
#ifndef HORUS_FACTOR_H
|
||||
#define HORUS_FACTOR_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "Indexer.h"
|
||||
#include "GenericFactor.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
|
||||
template <typename T>
|
||||
class TFactor
|
||||
{
|
||||
class Factor : public GenericFactor<VarId> {
|
||||
public:
|
||||
const vector<T>& arguments (void) const { return args_; }
|
||||
|
||||
vector<T>& arguments (void) { return args_; }
|
||||
|
||||
const Ranges& ranges (void) const { return ranges_; }
|
||||
|
||||
const Params& params (void) const { return params_; }
|
||||
|
||||
Params& params (void) { return params_; }
|
||||
|
||||
size_t nrArguments (void) const { return args_.size(); }
|
||||
|
||||
size_t size (void) const { return params_.size(); }
|
||||
|
||||
unsigned distId (void) const { return distId_; }
|
||||
|
||||
void setDistId (unsigned id) { distId_ = id; }
|
||||
|
||||
void normalize (void) { LogAware::normalize (params_); }
|
||||
|
||||
void randomize (void)
|
||||
{
|
||||
for (size_t i = 0; i < params_.size(); ++i) {
|
||||
params_[i] = (double) std::rand() / RAND_MAX;
|
||||
}
|
||||
}
|
||||
|
||||
void setParams (const Params& newParams)
|
||||
{
|
||||
params_ = newParams;
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
size_t indexOf (const T& t) const
|
||||
{
|
||||
return Util::indexOf (args_, t);
|
||||
}
|
||||
|
||||
const T& argument (size_t idx) const
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
T& argument (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
unsigned range (size_t idx) const
|
||||
{
|
||||
assert (idx < ranges_.size());
|
||||
return ranges_[idx];
|
||||
}
|
||||
|
||||
void multiply (TFactor<T>& g)
|
||||
{
|
||||
if (args_ == g.arguments()) {
|
||||
// optimization
|
||||
Globals::logDomain
|
||||
? params_ += g.params()
|
||||
: params_ *= g.params();
|
||||
return;
|
||||
}
|
||||
unsigned range_prod = 1;
|
||||
bool share_arguments = false;
|
||||
const vector<T>& g_args = g.arguments();
|
||||
const Ranges& g_ranges = g.ranges();
|
||||
const Params& g_params = g.params();
|
||||
for (size_t i = 0; i < g_args.size(); i++) {
|
||||
size_t idx = indexOf (g_args[i]);
|
||||
if (idx == args_.size()) {
|
||||
range_prod *= g_ranges[i];
|
||||
args_.push_back (g_args[i]);
|
||||
ranges_.push_back (g_ranges[i]);
|
||||
} else {
|
||||
share_arguments = true;
|
||||
}
|
||||
}
|
||||
if (share_arguments == false) {
|
||||
// optimization
|
||||
cartesianProduct (g_params.begin(), g_params.end());
|
||||
} else {
|
||||
extend (range_prod);
|
||||
Params::iterator it = params_.begin();
|
||||
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
|
||||
if (Globals::logDomain) {
|
||||
for (; indexer.valid(); ++it, ++indexer) {
|
||||
*it += g_params[indexer];
|
||||
}
|
||||
} else {
|
||||
for (; indexer.valid(); ++it, ++indexer) {
|
||||
*it *= g_params[indexer];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void sumOutIndex (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
assert (args_.size() > 1);
|
||||
size_t new_size = params_.size() / ranges_[idx];
|
||||
Params newps (new_size, LogAware::addIdenty());
|
||||
Params::const_iterator first = params_.begin();
|
||||
Params::const_iterator last = params_.end();
|
||||
MapIndexer indexer (ranges_, idx);
|
||||
if (Globals::logDomain) {
|
||||
for (; first != last; ++indexer) {
|
||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||
}
|
||||
} else {
|
||||
for (; first != last; ++indexer) {
|
||||
newps[indexer] += *first++;
|
||||
}
|
||||
}
|
||||
params_ = newps;
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
void absorveEvidence (const T& arg, unsigned obsIdx)
|
||||
{
|
||||
size_t idx = indexOf (arg);
|
||||
assert (idx != args_.size());
|
||||
assert (obsIdx < ranges_[idx]);
|
||||
Params newps;
|
||||
newps.reserve (params_.size() / ranges_[idx]);
|
||||
Indexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < obsIdx; ++i) {
|
||||
indexer.incrementDimension (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
newps.push_back (params_[indexer]);
|
||||
indexer.incrementExceptDimension (idx);
|
||||
}
|
||||
params_ = newps;
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
void reorderArguments (const vector<T> new_args)
|
||||
{
|
||||
assert (new_args.size() == args_.size());
|
||||
if (new_args == args_) {
|
||||
return; // already on the desired order
|
||||
}
|
||||
Ranges new_ranges;
|
||||
for (size_t i = 0; i < new_args.size(); i++) {
|
||||
size_t idx = indexOf (new_args[i]);
|
||||
assert (idx != args_.size());
|
||||
new_ranges.push_back (ranges_[idx]);
|
||||
}
|
||||
Params newps;
|
||||
newps.reserve (params_.size());
|
||||
MapIndexer indexer (new_args, new_ranges, args_, ranges_);
|
||||
for (; indexer.valid(); ++indexer) {
|
||||
newps.push_back (params_[indexer]);
|
||||
}
|
||||
params_ = newps;
|
||||
args_ = new_args;
|
||||
ranges_ = new_ranges;
|
||||
}
|
||||
|
||||
bool contains (const T& arg) const
|
||||
{
|
||||
return Util::contains (args_, arg);
|
||||
}
|
||||
|
||||
bool contains (const vector<T>& args) const
|
||||
{
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (contains (args[i]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
double& operator[] (size_t idx)
|
||||
{
|
||||
assert (idx < params_.size());
|
||||
return params_[idx];
|
||||
}
|
||||
|
||||
|
||||
protected:
|
||||
vector<T> args_;
|
||||
Ranges ranges_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
|
||||
private:
|
||||
void extend (unsigned range_prod)
|
||||
{
|
||||
Params backup = params_;
|
||||
params_.clear();
|
||||
params_.reserve (backup.size() * range_prod);
|
||||
Params::const_iterator first = backup.begin();
|
||||
Params::const_iterator last = backup.end();
|
||||
for (; first != last; ++first) {
|
||||
for (unsigned reps = 0; reps < range_prod; ++reps) {
|
||||
params_.push_back (*first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cartesianProduct (
|
||||
Params::const_iterator first2,
|
||||
Params::const_iterator last2)
|
||||
{
|
||||
Params backup = params_;
|
||||
params_.clear();
|
||||
params_.reserve (params_.size() * (last2 - first2));
|
||||
Params::const_iterator first1 = backup.begin();
|
||||
Params::const_iterator last1 = backup.end();
|
||||
Params::const_iterator tmp;
|
||||
if (Globals::logDomain) {
|
||||
for (; first1 != last1; ++first1) {
|
||||
for (tmp = first2; tmp != last2; ++tmp) {
|
||||
params_.push_back ((*first1) + (*tmp));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (; first1 != last1; ++first1) {
|
||||
for (tmp = first2; tmp != last2; ++tmp) {
|
||||
params_.push_back ((*first1) * (*tmp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
class Factor : public TFactor<VarId>
|
||||
{
|
||||
public:
|
||||
Factor (void) { }
|
||||
|
||||
Factor (const Factor&);
|
||||
Factor() { }
|
||||
|
||||
Factor (const VarIds&, const Ranges&, const Params&,
|
||||
unsigned = Util::maxUnsigned());
|
||||
@ -272,23 +30,21 @@ class Factor : public TFactor<VarId>
|
||||
|
||||
void sumOutAllExceptIndex (size_t idx);
|
||||
|
||||
void multiply (Factor&);
|
||||
Factor& multiply (const Factor&);
|
||||
|
||||
string getLabel (void) const;
|
||||
std::string getLabel() const;
|
||||
|
||||
void print (void) const;
|
||||
void print() const;
|
||||
|
||||
private:
|
||||
void sumOutFirstVariable (void);
|
||||
void sumOutFirstVariable();
|
||||
|
||||
void sumOutLastVariable (void);
|
||||
void sumOutLastVariable();
|
||||
|
||||
void sumOutArgs (const vector<bool>& mask);
|
||||
|
||||
void clone (const Factor& f);
|
||||
|
||||
DISALLOW_ASSIGN (Factor);
|
||||
void sumOutArgs (const std::vector<bool>& mask);
|
||||
};
|
||||
|
||||
#endif // HORUS_FACTOR_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_FACTOR_H_
|
||||
|
||||
|
@ -1,17 +1,15 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesBall.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
bool FactorGraph::exportLd_ = false;
|
||||
bool FactorGraph::exportUai_ = false;
|
||||
bool FactorGraph::exportGv_ = false;
|
||||
@ -20,25 +18,12 @@ bool FactorGraph::printFg_ = false;
|
||||
|
||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||
addFacNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||
}
|
||||
}
|
||||
bayesFactors_ = fg.bayesianFactors();
|
||||
clone (fg);
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph::~FactorGraph (void)
|
||||
FactorGraph::~FactorGraph()
|
||||
{
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
delete varNodes_[i];
|
||||
@ -50,152 +35,6 @@ FactorGraph::~FactorGraph (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
ignoreLines (is);
|
||||
string line;
|
||||
getline (is, line);
|
||||
if (line == "BAYES") {
|
||||
bayesFactors_ = true;
|
||||
} else if (line == "MARKOV") {
|
||||
bayesFactors_ = false;
|
||||
} else {
|
||||
cerr << "Error: the type of network is missing." << endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
vector<VarIds> allVarIds;
|
||||
vector<Ranges> allRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
allVarIds.push_back ({ });
|
||||
allRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
if (vid >= ranges.size()) {
|
||||
cerr << "Error: invalid variable identifier `" << vid << "'. " ;
|
||||
cerr << "Identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||
cerr << "." << endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
allVarIds.back().push_back (vid);
|
||||
allRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::sizeExpected (allRanges[i])) {
|
||||
cerr << "Error: invalid number of parameters for factor nº " << i ;
|
||||
cerr << ", " << Util::sizeExpected (allRanges[i]);
|
||||
cerr << " expected, " << nrParams << " given." << endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
Factor f (allVarIds[i], allRanges[i], params);
|
||||
if (bayesFactors_ && allVarIds[i].size() > 1) {
|
||||
// In this format the child is the last variable,
|
||||
// move it to be the first
|
||||
std::swap (allVarIds[i].front(), allVarIds[i].back());
|
||||
f.reorderArguments (allVarIds[i]);
|
||||
}
|
||||
addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = getVarNode (vids[j]);
|
||||
if (var && ranges[j] != var->range()) {
|
||||
cerr << "Error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with a different range." << endl;
|
||||
}
|
||||
}
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
is >> nNonzeros;
|
||||
Params params (Util::sizeExpected (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
is >> index;
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
Factor f (vids, ranges, params);
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
f.reorderArguments (vids);
|
||||
addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (const Factor& factor)
|
||||
{
|
||||
@ -221,7 +60,7 @@ FactorGraph::addVarNode (VarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
varMap_.insert (make_pair (vn->varId(), vn));
|
||||
varMap_.insert (std::make_pair (vn->varId(), vn));
|
||||
}
|
||||
|
||||
|
||||
@ -245,7 +84,7 @@ FactorGraph::addEdge (VarNode* vn, FacNode* fn)
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::isTree (void) const
|
||||
FactorGraph::isTree() const
|
||||
{
|
||||
return !containsCycle();
|
||||
}
|
||||
@ -253,7 +92,7 @@ FactorGraph::isTree (void) const
|
||||
|
||||
|
||||
BayesBallGraph&
|
||||
FactorGraph::getStructure (void)
|
||||
FactorGraph::getStructure()
|
||||
{
|
||||
assert (bayesFactors_);
|
||||
if (structure_.empty()) {
|
||||
@ -273,8 +112,10 @@ FactorGraph::getStructure (void)
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::print (void) const
|
||||
FactorGraph::print() const
|
||||
{
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||
cout << "label = " << varNodes_[i]->label() << endl;
|
||||
@ -296,28 +137,29 @@ FactorGraph::print (void) const
|
||||
void
|
||||
FactorGraph::exportToLibDai (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << facNodes_.size() << endl << endl;
|
||||
out << facNodes_.size() << std::endl << std::endl;
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
Factor f (facNodes_[i]->factor());
|
||||
out << f.nrArguments() << endl;
|
||||
out << Util::elementsToString (f.arguments()) << endl;
|
||||
out << Util::elementsToString (f.ranges()) << endl;
|
||||
out << f.nrArguments() << std::endl;
|
||||
out << Util::elementsToString (f.arguments()) << std::endl;
|
||||
out << Util::elementsToString (f.ranges()) << std::endl;
|
||||
VarIds args = f.arguments();
|
||||
std::reverse (args.begin(), args.end());
|
||||
f.reorderArguments (args);
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (f.params());
|
||||
}
|
||||
out << f.size() << endl;
|
||||
out << f.size() << std::endl;
|
||||
for (size_t j = 0; j < f.size(); j++) {
|
||||
out << j << " " << f[j] << endl;
|
||||
out << j << " " << f[j] << std::endl;
|
||||
}
|
||||
out << endl;
|
||||
out << std::endl;
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
@ -327,28 +169,30 @@ FactorGraph::exportToLibDai (const char* fileName) const
|
||||
void
|
||||
FactorGraph::exportToUai (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << (bayesFactors_ ? "BAYES" : "MARKOV") ;
|
||||
out << endl << endl;
|
||||
out << varNodes_.size() << endl;
|
||||
out << std::endl << std::endl;
|
||||
out << varNodes_.size() << std::endl;
|
||||
VarNodes sortedVns = varNodes_;
|
||||
std::sort (sortedVns.begin(), sortedVns.end(), sortByVarId());
|
||||
for (size_t i = 0; i < sortedVns.size(); i++) {
|
||||
out << ((i != 0) ? " " : "") << sortedVns[i]->range();
|
||||
}
|
||||
out << endl << facNodes_.size() << endl;
|
||||
out << std::endl << facNodes_.size() << std::endl;
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
VarIds args = facNodes_[i]->factor().arguments();
|
||||
if (bayesFactors_) {
|
||||
std::swap (args.front(), args.back());
|
||||
}
|
||||
out << args.size() << " " << Util::elementsToString (args) << endl;
|
||||
out << args.size() << " " << Util::elementsToString (args);
|
||||
out << std::endl;
|
||||
}
|
||||
out << endl;
|
||||
out << std::endl;
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
Factor f = facNodes_[i]->factor();
|
||||
if (bayesFactors_) {
|
||||
@ -360,8 +204,9 @@ FactorGraph::exportToUai (const char* fileName) const
|
||||
if (Globals::logDomain) {
|
||||
Util::exp (params);
|
||||
}
|
||||
out << params.size() << endl << " " ;
|
||||
out << Util::elementsToString (params) << endl << endl;
|
||||
out << params.size() << std::endl << " " ;
|
||||
out << Util::elementsToString (params);
|
||||
out << std::endl << std::endl;
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
@ -371,53 +216,239 @@ FactorGraph::exportToUai (const char* fileName) const
|
||||
void
|
||||
FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "graph \"" << fileName << "\" {" << endl;
|
||||
out << "graph \"" << fileName << "\" {" << std::endl;
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i]->hasEvidence()) {
|
||||
out << '"' << varNodes_[i]->label() << '"' ;
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
out << " [style=filled, fillcolor=yellow]" << std::endl;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||
out << "\"" << ", shape=box]" << endl;
|
||||
out << "\"" << ", shape=box]" << std::endl;
|
||||
}
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
const VarNodes& myVars = facNodes_[i]->neighbors();
|
||||
for (size_t j = 0; j < myVars.size(); j++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " -- " ;
|
||||
out << '"' << myVars[j]->label() << '"' << endl;
|
||||
out << '"' << myVars[j]->label() << '"' << std::endl;
|
||||
}
|
||||
}
|
||||
out << "}" << endl;
|
||||
out << "}" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||
FactorGraph&
|
||||
FactorGraph::operator= (const FactorGraph& fg)
|
||||
{
|
||||
string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
if (this != &fg) {
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
delete varNodes_[i];
|
||||
}
|
||||
varNodes_.clear();
|
||||
for (size_t i = 0; i < facNodes_.size(); i++) {
|
||||
delete facNodes_[i];
|
||||
}
|
||||
facNodes_.clear();
|
||||
varMap_.clear();
|
||||
clone (fg);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
FactorGraph fg;
|
||||
ignoreLines (is);
|
||||
std::string line;
|
||||
getline (is, line);
|
||||
if (line == "BAYES") {
|
||||
fg.bayesFactors_ = true;
|
||||
} else if (line == "MARKOV") {
|
||||
fg.bayesFactors_ = false;
|
||||
} else {
|
||||
std::cerr << "Error: the type of network is missing." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
// read the number of vars
|
||||
ignoreLines (is);
|
||||
unsigned nrVars;
|
||||
is >> nrVars;
|
||||
// read the range of each var
|
||||
ignoreLines (is);
|
||||
Ranges ranges (nrVars);
|
||||
for (unsigned i = 0; i < nrVars; i++) {
|
||||
is >> ranges[i];
|
||||
}
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
unsigned vid;
|
||||
is >> nrFactors;
|
||||
std::vector<VarIds> allVarIds;
|
||||
std::vector<Ranges> allRanges;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrArgs;
|
||||
allVarIds.push_back ({ });
|
||||
allRanges.push_back ({ });
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
is >> vid;
|
||||
if (vid >= ranges.size()) {
|
||||
std::cerr << "Error: invalid variable identifier `" << vid << "'" ;
|
||||
std::cerr << ". Identifiers must be between 0 and " ;
|
||||
std::cerr << ranges.size() - 1 << "." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
allVarIds.back().push_back (vid);
|
||||
allRanges.back().push_back (ranges[vid]);
|
||||
}
|
||||
}
|
||||
// read the parameters
|
||||
unsigned nrParams;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
is >> nrParams;
|
||||
if (nrParams != Util::sizeExpected (allRanges[i])) {
|
||||
std::cerr << "Error: invalid number of parameters for factor nº " ;
|
||||
std::cerr << i << ", " << Util::sizeExpected (allRanges[i]);
|
||||
std::cerr << " expected, " << nrParams << " given." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
Params params (nrParams);
|
||||
for (unsigned j = 0; j < nrParams; j++) {
|
||||
is >> params[j];
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
Factor f (allVarIds[i], allRanges[i], params);
|
||||
if (fg.bayesFactors_ && allVarIds[i].size() > 1) {
|
||||
// In this format the child is the last variable,
|
||||
// move it to be the first
|
||||
std::swap (allVarIds[i].front(), allVarIds[i].back());
|
||||
f.reorderArguments (allVarIds[i]);
|
||||
}
|
||||
fg.addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph
|
||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
{
|
||||
std::ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
FactorGraph fg;
|
||||
ignoreLines (is);
|
||||
unsigned nrFactors;
|
||||
unsigned nrArgs;
|
||||
VarId vid;
|
||||
is >> nrFactors;
|
||||
for (unsigned i = 0; i < nrFactors; i++) {
|
||||
ignoreLines (is);
|
||||
// read the factor arguments
|
||||
is >> nrArgs;
|
||||
VarIds vids;
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
// read ranges
|
||||
Ranges ranges (nrArgs);
|
||||
for (unsigned j = 0; j < nrArgs; j++) {
|
||||
ignoreLines (is);
|
||||
is >> ranges[j];
|
||||
VarNode* var = fg.getVarNode (vids[j]);
|
||||
if (var && ranges[j] != var->range()) {
|
||||
std::cerr << "Error: variable `" << vids[j] << "' appears" ;
|
||||
std::cerr << " in two or more factors with a different range." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
// read parameters
|
||||
ignoreLines (is);
|
||||
unsigned nNonzeros;
|
||||
is >> nNonzeros;
|
||||
Params params (Util::sizeExpected (ranges), 0);
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
ignoreLines (is);
|
||||
unsigned index;
|
||||
is >> index;
|
||||
ignoreLines (is);
|
||||
double val;
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
Util::log (params);
|
||||
}
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
std::reverse (ranges.begin(), ranges.end());
|
||||
Factor f (vids, ranges, params);
|
||||
std::reverse (vids.begin(), vids.end());
|
||||
f.reorderArguments (vids);
|
||||
fg.addFactor (f);
|
||||
}
|
||||
is.close();
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::clone (const FactorGraph& fg)
|
||||
{
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
addVarNode (new VarNode (varNodes[i]));
|
||||
}
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||
addFacNode (facNode);
|
||||
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||
for (size_t j = 0; j < neighs.size(); j++) {
|
||||
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||
}
|
||||
}
|
||||
bayesFactors_ = fg.bayesianFactors();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (void) const
|
||||
FactorGraph::containsCycle() const
|
||||
{
|
||||
vector<bool> visitedVars (varNodes_.size(), false);
|
||||
vector<bool> visitedFactors (facNodes_.size(), false);
|
||||
std::vector<bool> visitedVars (varNodes_.size(), false);
|
||||
std::vector<bool> visitedFactors (facNodes_.size(), false);
|
||||
for (size_t i = 0; i < varNodes_.size(); i++) {
|
||||
int v = varNodes_[i]->getIndex();
|
||||
if (!visitedVars[v]) {
|
||||
@ -435,8 +466,8 @@ bool
|
||||
FactorGraph::containsCycle (
|
||||
const VarNode* v,
|
||||
const FacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
std::vector<bool>& visitedVars,
|
||||
std::vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedVars[v->getIndex()] = true;
|
||||
const FacNodes& adjacencies = v->neighbors();
|
||||
@ -460,8 +491,8 @@ bool
|
||||
FactorGraph::containsCycle (
|
||||
const FacNode* v,
|
||||
const VarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
std::vector<bool>& visitedVars,
|
||||
std::vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedFactors[v->getIndex()] = true;
|
||||
const VarNodes& adjacencies = v->neighbors();
|
||||
@ -479,3 +510,16 @@ FactorGraph::containsCycle (
|
||||
return false; // no cycle detected in this component
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::ignoreLines (std::ifstream& is)
|
||||
{
|
||||
std::string ignoreStr;
|
||||
while (is.peek() == '#' || is.peek() == '\n') {
|
||||
getline (is, ignoreStr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,28 +1,32 @@
|
||||
#ifndef HORUS_FACTORGRAPH_H
|
||||
#define HORUS_FACTORGRAPH_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "BayesBallGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class FacNode;
|
||||
|
||||
class VarNode : public Var
|
||||
{
|
||||
|
||||
class VarNode : public Var {
|
||||
public:
|
||||
VarNode (VarId varId, unsigned nrStates,
|
||||
int evidence = Constants::NO_EVIDENCE)
|
||||
int evidence = Constants::unobserved)
|
||||
: Var (varId, nrStates, evidence) { }
|
||||
|
||||
VarNode (const Var* v) : Var (v) { }
|
||||
|
||||
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
|
||||
|
||||
const FacNodes& neighbors (void) const { return neighs_; }
|
||||
const FacNodes& neighbors() const { return neighs_; }
|
||||
|
||||
private:
|
||||
FacNodes neighs_;
|
||||
@ -32,24 +36,23 @@ class VarNode : public Var
|
||||
|
||||
|
||||
|
||||
class FacNode
|
||||
{
|
||||
class FacNode {
|
||||
public:
|
||||
FacNode (const Factor& f) : factor_(f), index_(-1) { }
|
||||
|
||||
const Factor& factor (void) const { return factor_; }
|
||||
const Factor& factor() const { return factor_; }
|
||||
|
||||
Factor& factor (void) { return factor_; }
|
||||
Factor& factor() { return factor_; }
|
||||
|
||||
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||
|
||||
const VarNodes& neighbors (void) const { return neighs_; }
|
||||
const VarNodes& neighbors() const { return neighs_; }
|
||||
|
||||
size_t getIndex (void) const { return index_; }
|
||||
size_t getIndex() const { return index_; }
|
||||
|
||||
void setIndex (size_t index) { index_ = index; }
|
||||
|
||||
string getLabel (void) { return factor_.getLabel(); }
|
||||
std::string getLabel() { return factor_.getLabel(); }
|
||||
|
||||
private:
|
||||
VarNodes neighs_;
|
||||
@ -61,36 +64,27 @@ class FacNode
|
||||
|
||||
|
||||
|
||||
class FactorGraph
|
||||
{
|
||||
class FactorGraph {
|
||||
public:
|
||||
FactorGraph (void) : bayesFactors_(false) { }
|
||||
FactorGraph() : bayesFactors_(false) { }
|
||||
|
||||
FactorGraph (const FactorGraph&);
|
||||
|
||||
~FactorGraph (void);
|
||||
~FactorGraph();
|
||||
|
||||
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||
const VarNodes& varNodes() const { return varNodes_; }
|
||||
|
||||
const FacNodes& facNodes (void) const { return facNodes_; }
|
||||
const FacNodes& facNodes() const { return facNodes_; }
|
||||
|
||||
void setFactorsAsBayesian (void) { bayesFactors_ = true; }
|
||||
void setFactorsAsBayesian() { bayesFactors_ = true; }
|
||||
|
||||
bool bayesianFactors (void) const { return bayesFactors_; }
|
||||
bool bayesianFactors() const { return bayesFactors_; }
|
||||
|
||||
size_t nrVarNodes (void) const { return varNodes_.size(); }
|
||||
size_t nrVarNodes() const { return varNodes_.size(); }
|
||||
|
||||
size_t nrFacNodes (void) const { return facNodes_.size(); }
|
||||
size_t nrFacNodes() const { return facNodes_.size(); }
|
||||
|
||||
VarNode* getVarNode (VarId vid) const
|
||||
{
|
||||
VarMap::const_iterator it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
|
||||
void readFromLibDaiFormat (const char*);
|
||||
VarNode* getVarNode (VarId vid) const;
|
||||
|
||||
void addFactor (const Factor& factor);
|
||||
|
||||
@ -100,11 +94,11 @@ class FactorGraph
|
||||
|
||||
void addEdge (VarNode*, FacNode*);
|
||||
|
||||
bool isTree (void) const;
|
||||
bool isTree() const;
|
||||
|
||||
BayesBallGraph& getStructure (void);
|
||||
BayesBallGraph& getStructure();
|
||||
|
||||
void print (void) const;
|
||||
void print() const;
|
||||
|
||||
void exportToLibDai (const char*) const;
|
||||
|
||||
@ -112,67 +106,80 @@ class FactorGraph
|
||||
|
||||
void exportToGraphViz (const char*) const;
|
||||
|
||||
static bool exportToLibDai (void) { return exportLd_; }
|
||||
FactorGraph& operator= (const FactorGraph&);
|
||||
|
||||
static bool exportToUai (void) { return exportUai_; }
|
||||
static FactorGraph readFromUaiFormat (const char*);
|
||||
|
||||
static bool exportGraphViz (void) { return exportGv_; }
|
||||
static FactorGraph readFromLibDaiFormat (const char*);
|
||||
|
||||
static bool printFactorGraph (void) { return printFg_; }
|
||||
static bool exportToLibDai() { return exportLd_; }
|
||||
|
||||
static void enableExportToLibDai (void) { exportLd_ = true; }
|
||||
static bool exportToUai() { return exportUai_; }
|
||||
|
||||
static void disableExportToLibDai (void) { exportLd_ = false; }
|
||||
static bool exportGraphViz() { return exportGv_; }
|
||||
|
||||
static void enableExportToUai (void) { exportUai_ = true; }
|
||||
static bool printFactorGraph() { return printFg_; }
|
||||
|
||||
static void disableExportToUai (void) { exportUai_ = false; }
|
||||
static void enableExportToLibDai() { exportLd_ = true; }
|
||||
|
||||
static void enableExportToGraphViz (void) { exportGv_ = true; }
|
||||
static void disableExportToLibDai() { exportLd_ = false; }
|
||||
|
||||
static void disableExportToGraphViz (void) { exportGv_ = false; }
|
||||
static void enableExportToUai() { exportUai_ = true; }
|
||||
|
||||
static void enablePrintFactorGraph (void) { printFg_ = true; }
|
||||
static void disableExportToUai() { exportUai_ = false; }
|
||||
|
||||
static void disablePrintFactorGraph (void) { printFg_ = false; }
|
||||
static void enableExportToGraphViz() { exportGv_ = true; }
|
||||
|
||||
static void disableExportToGraphViz() { exportGv_ = false; }
|
||||
|
||||
static void enablePrintFactorGraph() { printFg_ = true; }
|
||||
|
||||
static void disablePrintFactorGraph() { printFg_ = false; }
|
||||
|
||||
private:
|
||||
void ignoreLines (std::ifstream&) const;
|
||||
typedef std::unordered_map<unsigned, VarNode*> VarMap;
|
||||
|
||||
bool containsCycle (void) const;
|
||||
void clone (const FactorGraph& fg);
|
||||
|
||||
bool containsCycle() const;
|
||||
|
||||
bool containsCycle (const VarNode*, const FacNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
std::vector<bool>&, std::vector<bool>&) const;
|
||||
|
||||
bool containsCycle (const FacNode*, const VarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
std::vector<bool>&, std::vector<bool>&) const;
|
||||
|
||||
VarNodes varNodes_;
|
||||
FacNodes facNodes_;
|
||||
static void ignoreLines (std::ifstream&);
|
||||
|
||||
VarNodes varNodes_;
|
||||
FacNodes facNodes_;
|
||||
VarMap varMap_;
|
||||
BayesBallGraph structure_;
|
||||
bool bayesFactors_;
|
||||
|
||||
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||
VarMap varMap_;
|
||||
|
||||
static bool exportLd_;
|
||||
static bool exportUai_;
|
||||
static bool exportGv_;
|
||||
static bool printFg_;
|
||||
|
||||
DISALLOW_ASSIGN (FactorGraph);
|
||||
static bool exportLd_;
|
||||
static bool exportUai_;
|
||||
static bool exportGv_;
|
||||
static bool printFg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
struct sortByVarId
|
||||
inline VarNode*
|
||||
FactorGraph::getVarNode (VarId vid) const
|
||||
{
|
||||
VarMap::const_iterator it = varMap_.find (vid);
|
||||
return it != varMap_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
struct sortByVarId {
|
||||
bool operator()(VarNode* vn1, VarNode* vn2) {
|
||||
return vn1->varId() < vn2->varId();
|
||||
}
|
||||
};
|
||||
}};
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // HORUS_FACTORGRAPH_H
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_FACTORGRAPH_H_
|
||||
|
||||
|
256
packages/CLPBN/horus/GenericFactor.cpp
Normal file
256
packages/CLPBN/horus/GenericFactor.cpp
Normal file
@ -0,0 +1,256 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "GenericFactor.h"
|
||||
#include "ProbFormula.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
template <typename T> const T&
|
||||
GenericFactor<T>::argument (size_t idx) const
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> T&
|
||||
GenericFactor<T>::argument (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
return args_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> unsigned
|
||||
GenericFactor<T>::range (size_t idx) const
|
||||
{
|
||||
assert (idx < ranges_.size());
|
||||
return ranges_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
GenericFactor<T>::contains (const T& arg) const
|
||||
{
|
||||
return Util::contains (args_, arg);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
GenericFactor<T>::contains (const std::vector<T>& args) const
|
||||
{
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (contains (args[i]) == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::setParams (const Params& newParams)
|
||||
{
|
||||
params_ = newParams;
|
||||
assert (params_.size() == Util::sizeExpected (ranges_));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> double
|
||||
GenericFactor<T>::operator[] (size_t idx) const
|
||||
{
|
||||
assert (idx < params_.size());
|
||||
return params_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> double&
|
||||
GenericFactor<T>::operator[] (size_t idx)
|
||||
{
|
||||
assert (idx < params_.size());
|
||||
return params_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> GenericFactor<T>&
|
||||
GenericFactor<T>::multiply (const GenericFactor<T>& g)
|
||||
{
|
||||
if (args_ == g.arguments()) {
|
||||
// optimization
|
||||
Globals::logDomain
|
||||
? params_ += g.params()
|
||||
: params_ *= g.params();
|
||||
return *this;
|
||||
}
|
||||
unsigned range_prod = 1;
|
||||
bool share_arguments = false;
|
||||
const std::vector<T>& g_args = g.arguments();
|
||||
const Ranges& g_ranges = g.ranges();
|
||||
const Params& g_params = g.params();
|
||||
for (size_t i = 0; i < g_args.size(); i++) {
|
||||
size_t idx = indexOf (g_args[i]);
|
||||
if (idx == args_.size()) {
|
||||
range_prod *= g_ranges[i];
|
||||
args_.push_back (g_args[i]);
|
||||
ranges_.push_back (g_ranges[i]);
|
||||
} else {
|
||||
share_arguments = true;
|
||||
}
|
||||
}
|
||||
if (share_arguments == false) {
|
||||
// optimization
|
||||
cartesianProduct (g_params.begin(), g_params.end());
|
||||
} else {
|
||||
extend (range_prod);
|
||||
Params::iterator it = params_.begin();
|
||||
MapIndexer indexer (args_, ranges_, g_args, g_ranges);
|
||||
if (Globals::logDomain) {
|
||||
for (; indexer.valid(); ++it, ++indexer) {
|
||||
*it += g_params[indexer];
|
||||
}
|
||||
} else {
|
||||
for (; indexer.valid(); ++it, ++indexer) {
|
||||
*it *= g_params[indexer];
|
||||
}
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::sumOutIndex (size_t idx)
|
||||
{
|
||||
assert (idx < args_.size());
|
||||
assert (args_.size() > 1);
|
||||
size_t new_size = params_.size() / ranges_[idx];
|
||||
Params newps (new_size, LogAware::addIdenty());
|
||||
Params::const_iterator first = params_.begin();
|
||||
Params::const_iterator last = params_.end();
|
||||
MapIndexer indexer (ranges_, idx);
|
||||
if (Globals::logDomain) {
|
||||
for (; first != last; ++indexer) {
|
||||
newps[indexer] = Util::logSum (newps[indexer], *first++);
|
||||
}
|
||||
} else {
|
||||
for (; first != last; ++indexer) {
|
||||
newps[indexer] += *first++;
|
||||
}
|
||||
}
|
||||
params_ = newps;
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::absorveEvidence (const T& arg, unsigned obsIdx)
|
||||
{
|
||||
size_t idx = indexOf (arg);
|
||||
assert (idx != args_.size());
|
||||
assert (obsIdx < ranges_[idx]);
|
||||
Params newps;
|
||||
newps.reserve (params_.size() / ranges_[idx]);
|
||||
Indexer indexer (ranges_);
|
||||
for (unsigned i = 0; i < obsIdx; ++i) {
|
||||
indexer.incrementDimension (idx);
|
||||
}
|
||||
while (indexer.valid()) {
|
||||
newps.push_back (params_[indexer]);
|
||||
indexer.incrementExceptDimension (idx);
|
||||
}
|
||||
params_ = newps;
|
||||
args_.erase (args_.begin() + idx);
|
||||
ranges_.erase (ranges_.begin() + idx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::reorderArguments (const std::vector<T>& new_args)
|
||||
{
|
||||
assert (new_args.size() == args_.size());
|
||||
if (new_args == args_) {
|
||||
return; // already on the desired order
|
||||
}
|
||||
Ranges new_ranges;
|
||||
for (size_t i = 0; i < new_args.size(); i++) {
|
||||
size_t idx = indexOf (new_args[i]);
|
||||
assert (idx != args_.size());
|
||||
new_ranges.push_back (ranges_[idx]);
|
||||
}
|
||||
Params newps;
|
||||
newps.reserve (params_.size());
|
||||
MapIndexer indexer (new_args, new_ranges, args_, ranges_);
|
||||
for (; indexer.valid(); ++indexer) {
|
||||
newps.push_back (params_[indexer]);
|
||||
}
|
||||
params_ = newps;
|
||||
args_ = new_args;
|
||||
ranges_ = new_ranges;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::extend (unsigned range_prod)
|
||||
{
|
||||
Params backup = params_;
|
||||
params_.clear();
|
||||
params_.reserve (backup.size() * range_prod);
|
||||
Params::const_iterator first = backup.begin();
|
||||
Params::const_iterator last = backup.end();
|
||||
for (; first != last; ++first) {
|
||||
for (unsigned reps = 0; reps < range_prod; ++reps) {
|
||||
params_.push_back (*first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
GenericFactor<T>::cartesianProduct (
|
||||
Params::const_iterator first2,
|
||||
Params::const_iterator last2)
|
||||
{
|
||||
Params backup = params_;
|
||||
params_.clear();
|
||||
params_.reserve (params_.size() * (last2 - first2));
|
||||
Params::const_iterator first1 = backup.begin();
|
||||
Params::const_iterator last1 = backup.end();
|
||||
Params::const_iterator tmp;
|
||||
if (Globals::logDomain) {
|
||||
for (; first1 != last1; ++first1) {
|
||||
for (tmp = first2; tmp != last2; ++tmp) {
|
||||
params_.push_back ((*first1) + (*tmp));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (; first1 != last1; ++first1) {
|
||||
for (tmp = first2; tmp != last2; ++tmp) {
|
||||
params_.push_back ((*first1) * (*tmp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template class GenericFactor<VarId>;
|
||||
template class GenericFactor<ProbFormula>;
|
||||
|
||||
} // namespace Horus
|
||||
|
76
packages/CLPBN/horus/GenericFactor.h
Normal file
76
packages/CLPBN/horus/GenericFactor.h
Normal file
@ -0,0 +1,76 @@
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
template <typename T>
|
||||
class GenericFactor {
|
||||
public:
|
||||
const std::vector<T>& arguments() const { return args_; }
|
||||
|
||||
std::vector<T>& arguments() { return args_; }
|
||||
|
||||
const Ranges& ranges() const { return ranges_; }
|
||||
|
||||
const Params& params() const { return params_; }
|
||||
|
||||
Params& params() { return params_; }
|
||||
|
||||
size_t nrArguments() const { return args_.size(); }
|
||||
|
||||
size_t size() const { return params_.size(); }
|
||||
|
||||
unsigned distId() const { return distId_; }
|
||||
|
||||
void setDistId (unsigned id) { distId_ = id; }
|
||||
|
||||
void normalize() { LogAware::normalize (params_); }
|
||||
|
||||
size_t indexOf (const T& t) const { return Util::indexOf (args_, t); }
|
||||
|
||||
const T& argument (size_t idx) const;
|
||||
|
||||
T& argument (size_t idx);
|
||||
|
||||
unsigned range (size_t idx) const;
|
||||
|
||||
bool contains (const T& arg) const;
|
||||
|
||||
bool contains (const std::vector<T>& args) const;
|
||||
|
||||
void setParams (const Params& newParams);
|
||||
|
||||
double operator[] (size_t idx) const;
|
||||
|
||||
double& operator[] (size_t idx);
|
||||
|
||||
GenericFactor<T>& multiply (const GenericFactor<T>& g);
|
||||
|
||||
void sumOutIndex (size_t idx);
|
||||
|
||||
void absorveEvidence (const T& arg, unsigned obsIdx);
|
||||
|
||||
void reorderArguments (const std::vector<T>& new_args);
|
||||
|
||||
protected:
|
||||
std::vector<T> args_;
|
||||
Ranges ranges_;
|
||||
Params params_;
|
||||
unsigned distId_;
|
||||
|
||||
private:
|
||||
void extend (unsigned range_prod);
|
||||
|
||||
void cartesianProduct (
|
||||
Params::const_iterator first2, Params::const_iterator last2);
|
||||
};
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_GENERICFACTOR_H_
|
||||
|
@ -1,10 +1,20 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "GroundSolver.h"
|
||||
#include "VarElim.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
#include "Indexer.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
void
|
||||
GroundSolver::printAnswer (const VarIds& vids)
|
||||
{
|
||||
@ -19,20 +29,21 @@ GroundSolver::printAnswer (const VarIds& vids)
|
||||
}
|
||||
if (unobservedVids.empty() == false) {
|
||||
Params res = solveQuery (unobservedVids);
|
||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||
std::vector<std::string> stateLines =
|
||||
Util::getStateLines (unobservedVars);
|
||||
for (size_t i = 0; i < res.size(); i++) {
|
||||
cout << "P(" << stateLines[i] << ") = " ;
|
||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||
cout << endl;
|
||||
std::cout << "P(" << stateLines[i] << ") = " ;
|
||||
std::cout << std::setprecision (Constants::precision) << res[i];
|
||||
std::cout << std::endl;
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
GroundSolver::printAllPosterioris (void)
|
||||
GroundSolver::printAllPosterioris()
|
||||
{
|
||||
VarNodes vars = fg.varNodes();
|
||||
std::sort (vars.begin(), vars.end(), sortByVarId());
|
||||
@ -57,9 +68,9 @@ GroundSolver::getJointByConditioning (
|
||||
|
||||
GroundSolver* solver = 0;
|
||||
switch (solverType) {
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params prevBeliefs = solver->solveQuery ({jointVarIds[0]});
|
||||
VarIds observedVids = {jointVars[0]->varId()};
|
||||
@ -80,9 +91,9 @@ GroundSolver::getJointByConditioning (
|
||||
}
|
||||
delete solver;
|
||||
switch (solverType) {
|
||||
case GroundSolverType::BP: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::VE: solver = new VarElim (fg); break;
|
||||
case GroundSolverType::bpSolver: solver = new BeliefProp (fg); break;
|
||||
case GroundSolverType::CbpSolver: solver = new CountingBp (fg); break;
|
||||
case GroundSolverType::veSolver: solver = new VarElim (fg); break;
|
||||
}
|
||||
Params beliefs = solver->solveQuery ({jointVarIds[i]});
|
||||
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||
@ -105,3 +116,5 @@ GroundSolver::getJointByConditioning (
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,16 +1,13 @@
|
||||
#ifndef HORUS_GROUNDSOLVER_H
|
||||
#define HORUS_GROUNDSOLVER_H
|
||||
|
||||
#include <iomanip>
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
class GroundSolver
|
||||
{
|
||||
class GroundSolver {
|
||||
public:
|
||||
GroundSolver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||
|
||||
@ -18,11 +15,11 @@ class GroundSolver
|
||||
|
||||
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||
|
||||
virtual void printSolverFlags (void) const = 0;
|
||||
virtual void printSolverFlags() const = 0;
|
||||
|
||||
void printAnswer (const VarIds& vids);
|
||||
|
||||
void printAllPosterioris (void);
|
||||
void printAllPosterioris();
|
||||
|
||||
static Params getJointByConditioning (GroundSolverType,
|
||||
FactorGraph, const VarIds& jointVarIds);
|
||||
@ -30,8 +27,11 @@ class GroundSolver
|
||||
protected:
|
||||
const FactorGraph& fg;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (GroundSolver);
|
||||
};
|
||||
|
||||
#endif // HORUS_GROUNDSOLVER_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_GROUNDSOLVER_H_
|
||||
|
||||
|
@ -7,6 +7,8 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
HistogramSet::HistogramSet (unsigned size, unsigned range)
|
||||
{
|
||||
size_ = size;
|
||||
@ -17,7 +19,7 @@ HistogramSet::HistogramSet (unsigned size, unsigned range)
|
||||
|
||||
|
||||
void
|
||||
HistogramSet::nextHistogram (void)
|
||||
HistogramSet::nextHistogram()
|
||||
{
|
||||
for (size_t i = hist_.size() - 1; i-- > 0; ) {
|
||||
if (hist_[i] > 0) {
|
||||
@ -43,7 +45,7 @@ HistogramSet::operator[] (size_t idx) const
|
||||
|
||||
|
||||
unsigned
|
||||
HistogramSet::nrHistograms (void) const
|
||||
HistogramSet::nrHistograms() const
|
||||
{
|
||||
return HistogramSet::nrHistograms (size_, hist_.size());
|
||||
}
|
||||
@ -51,7 +53,7 @@ HistogramSet::nrHistograms (void) const
|
||||
|
||||
|
||||
void
|
||||
HistogramSet::reset (void)
|
||||
HistogramSet::reset()
|
||||
{
|
||||
std::fill (hist_.begin() + 1, hist_.end(), 0);
|
||||
hist_[0] = size_;
|
||||
@ -59,12 +61,12 @@ HistogramSet::reset (void)
|
||||
|
||||
|
||||
|
||||
vector<Histogram>
|
||||
std::vector<Histogram>
|
||||
HistogramSet::getHistograms (unsigned N, unsigned R)
|
||||
{
|
||||
HistogramSet hs (N, R);
|
||||
unsigned H = hs.nrHistograms();
|
||||
vector<Histogram> histograms;
|
||||
std::vector<Histogram> histograms;
|
||||
histograms.reserve (H);
|
||||
for (unsigned i = 0; i < H; i++) {
|
||||
histograms.push_back (hs.hist_);
|
||||
@ -86,9 +88,9 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
|
||||
size_t
|
||||
HistogramSet::findIndex (
|
||||
const Histogram& h,
|
||||
const vector<Histogram>& hists)
|
||||
const std::vector<Histogram>& hists)
|
||||
{
|
||||
vector<Histogram>::const_iterator it = std::lower_bound (
|
||||
std::vector<Histogram>::const_iterator it = std::lower_bound (
|
||||
hists.begin(), hists.end(), h, std::greater<Histogram>());
|
||||
assert (it != hists.end() && *it == h);
|
||||
return std::distance (hists.begin(), it);
|
||||
@ -96,13 +98,13 @@ HistogramSet::findIndex (
|
||||
|
||||
|
||||
|
||||
vector<double>
|
||||
std::vector<double>
|
||||
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||
{
|
||||
HistogramSet hs (N, R);
|
||||
double N_fac = Util::logFactorial (N);
|
||||
unsigned H = hs.nrHistograms();
|
||||
vector<double> numAssigns;
|
||||
std::vector<double> numAssigns;
|
||||
numAssigns.reserve (H);
|
||||
for (unsigned h = 0; h < H; h++) {
|
||||
double prod = 0.0;
|
||||
@ -118,14 +120,6 @@ HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const HistogramSet& hs)
|
||||
{
|
||||
os << "#" << hs.hist_;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
HistogramSet::maxCount (size_t idx) const
|
||||
{
|
||||
@ -144,3 +138,14 @@ HistogramSet::clearAfter (size_t idx)
|
||||
std::fill (hist_.begin() + idx + 1, hist_.end(), 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const HistogramSet& hs)
|
||||
{
|
||||
os << "#" << hs.hist_;
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,50 +1,51 @@
|
||||
#ifndef HORUS_HISTOGRAM_H
|
||||
#define HORUS_HISTOGRAM_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <ostream>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
typedef std::vector<unsigned> Histogram;
|
||||
|
||||
typedef vector<unsigned> Histogram;
|
||||
|
||||
class HistogramSet
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class HistogramSet {
|
||||
public:
|
||||
HistogramSet (unsigned, unsigned);
|
||||
|
||||
void nextHistogram (void);
|
||||
void nextHistogram();
|
||||
|
||||
unsigned operator[] (size_t idx) const;
|
||||
|
||||
unsigned nrHistograms (void) const;
|
||||
unsigned nrHistograms() const;
|
||||
|
||||
void reset (void);
|
||||
void reset();
|
||||
|
||||
static vector<Histogram> getHistograms (unsigned, unsigned);
|
||||
static std::vector<Histogram> getHistograms (unsigned, unsigned);
|
||||
|
||||
static unsigned nrHistograms (unsigned, unsigned);
|
||||
|
||||
static size_t findIndex (
|
||||
const Histogram&, const vector<Histogram>&);
|
||||
const Histogram&, const std::vector<Histogram>&);
|
||||
|
||||
static vector<double> getNumAssigns (unsigned, unsigned);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
||||
static std::vector<double> getNumAssigns (unsigned, unsigned);
|
||||
|
||||
private:
|
||||
unsigned maxCount (size_t) const;
|
||||
|
||||
void clearAfter (size_t);
|
||||
|
||||
friend std::ostream& operator<< (std::ostream&, const HistogramSet&);
|
||||
|
||||
unsigned size_;
|
||||
Histogram hist_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (HistogramSet);
|
||||
};
|
||||
|
||||
#endif // HORUS_HISTOGRAM_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_HISTOGRAM_H_
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_HORUS_H
|
||||
#define HORUS_HORUS_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
|
||||
|
||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||
TypeName(const TypeName&); \
|
||||
@ -14,6 +14,9 @@
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class Var;
|
||||
class Factor;
|
||||
class VarNode;
|
||||
@ -31,19 +34,17 @@ typedef std::vector<unsigned> Ranges;
|
||||
typedef unsigned long long ullong;
|
||||
|
||||
|
||||
enum LiftedSolverType
|
||||
{
|
||||
LVE, // generalized counting first-order variable elimination (GC-FOVE)
|
||||
LBP, // lifted first-order belief propagation
|
||||
LKC // lifted first-order knowledge compilation
|
||||
enum class LiftedSolverType {
|
||||
lveSolver, // generalized counting first-order variable elimination
|
||||
lbpSolver, // lifted first-order belief propagation
|
||||
lkcSolver // lifted first-order knowledge compilation
|
||||
};
|
||||
|
||||
|
||||
enum GroundSolverType
|
||||
{
|
||||
VE, // variable elimination
|
||||
BP, // belief propagation
|
||||
CBP // counting belief propagation
|
||||
enum class GroundSolverType {
|
||||
veSolver, // variable elimination
|
||||
bpSolver, // belief propagation
|
||||
CbpSolver // counting belief propagation
|
||||
};
|
||||
|
||||
|
||||
@ -57,20 +58,22 @@ extern unsigned verbosity;
|
||||
extern LiftedSolverType liftedSolver;
|
||||
extern GroundSolverType groundSolver;
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
namespace Constants {
|
||||
|
||||
// show message calculation for belief propagation
|
||||
const bool SHOW_BP_CALCS = false;
|
||||
const bool showBpCalcs = false;
|
||||
|
||||
const int NO_EVIDENCE = -1;
|
||||
const int unobserved = -1;
|
||||
|
||||
// number of digits to show when printing a parameter
|
||||
const unsigned PRECISION = 6;
|
||||
const unsigned precision = 8;
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
#endif // HORUS_HORUS_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_HORUS_H_
|
||||
|
||||
|
@ -1,53 +1,61 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "VarElim.h"
|
||||
#include "BeliefProp.h"
|
||||
#include "CountingBp.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
|
||||
int readHorusFlags (int, const char* []);
|
||||
void readFactorGraph (FactorGraph&, const char*);
|
||||
VarIds readQueryAndEvidence (FactorGraph&, int, const char* [], int);
|
||||
|
||||
void runSolver (const FactorGraph&, const VarIds&);
|
||||
void readFactorGraph (Horus::FactorGraph&, const char*);
|
||||
|
||||
const string USAGE = "usage: ./hcli [solver=hve|bp|cbp] \
|
||||
Horus::VarIds readQueryAndEvidence (
|
||||
Horus::FactorGraph&, int, const char* [], int);
|
||||
|
||||
void runSolver (const Horus::FactorGraph&, const Horus::VarIds&);
|
||||
|
||||
const std::string usage = "usage: ./hcli [solver=hve|bp|cbp] \
|
||||
[<OPTION>=<VALUE>]... <FILE> [<VAR>|<VAR>=<EVIDENCE>]... " ;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
main (int argc, const char* argv[])
|
||||
{
|
||||
if (argc <= 1) {
|
||||
cerr << "Error: no probabilistic graphical model was given." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: no probabilistic graphical model was given." ;
|
||||
std::cerr << std::endl << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
int idx = readHorusFlags (argc, argv);
|
||||
FactorGraph fg;
|
||||
Horus::FactorGraph fg;
|
||||
readFactorGraph (fg, argv[idx]);
|
||||
VarIds queryIds = readQueryAndEvidence (fg, argc, argv, idx + 1);
|
||||
if (FactorGraph::exportToLibDai()) {
|
||||
Horus::VarIds queryIds
|
||||
= readQueryAndEvidence (fg, argc, argv, idx + 1);
|
||||
if (Horus::FactorGraph::exportToLibDai()) {
|
||||
fg.exportToLibDai ("model.fg");
|
||||
}
|
||||
if (FactorGraph::exportToUai()) {
|
||||
if (Horus::FactorGraph::exportToUai()) {
|
||||
fg.exportToUai ("model.uai");
|
||||
}
|
||||
if (FactorGraph::exportGraphViz()) {
|
||||
if (Horus::FactorGraph::exportGraphViz()) {
|
||||
fg.exportToGraphViz ("model.dot");
|
||||
}
|
||||
if (FactorGraph::printFactorGraph()) {
|
||||
if (Horus::FactorGraph::printFactorGraph()) {
|
||||
fg.print();
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "factor graph contains " ;
|
||||
cout << fg.nrVarNodes() << " variables and " ;
|
||||
cout << fg.nrFacNodes() << " factors " << endl;
|
||||
if (Horus::Globals::verbosity > 0) {
|
||||
std::cout << "factor graph contains " ;
|
||||
std::cout << fg.nrVarNodes() << " variables and " ;
|
||||
std::cout << fg.nrFacNodes() << " factors " << std::endl;
|
||||
}
|
||||
runSolver (fg, queryIds);
|
||||
return 0;
|
||||
@ -55,29 +63,31 @@ main (int argc, const char* argv[])
|
||||
|
||||
|
||||
|
||||
namespace {
|
||||
|
||||
int
|
||||
readHorusFlags (int argc, const char* argv[])
|
||||
{
|
||||
int i = 1;
|
||||
for (; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
const std::string& arg = argv[i];
|
||||
size_t pos = arg.find ('=');
|
||||
if (pos == std::string::npos) {
|
||||
return i;
|
||||
}
|
||||
string leftArg = arg.substr (0, pos);
|
||||
string rightArg = arg.substr (pos + 1);
|
||||
std::string leftArg = arg.substr (0, pos);
|
||||
std::string rightArg = arg.substr (pos + 1);
|
||||
if (leftArg.empty()) {
|
||||
cerr << "Error: missing left argument." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: missing left argument." << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (rightArg.empty()) {
|
||||
cerr << "Error: missing right argument." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: missing right argument." << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
Util::setHorusFlag (leftArg, rightArg);
|
||||
Horus::Util::setHorusFlag (leftArg, rightArg);
|
||||
}
|
||||
return i + 1;
|
||||
}
|
||||
@ -85,84 +95,84 @@ readHorusFlags (int argc, const char* argv[])
|
||||
|
||||
|
||||
void
|
||||
readFactorGraph (FactorGraph& fg, const char* s)
|
||||
readFactorGraph (Horus::FactorGraph& fg, const char* s)
|
||||
{
|
||||
string fileName (s);
|
||||
string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||
std::string fileName (s);
|
||||
std::string extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||
if (extension == "uai") {
|
||||
fg.readFromUaiFormat (fileName.c_str());
|
||||
fg = Horus::FactorGraph::readFromUaiFormat (fileName.c_str());
|
||||
} else if (extension == "fg") {
|
||||
fg.readFromLibDaiFormat (fileName.c_str());
|
||||
fg = Horus::FactorGraph::readFromLibDaiFormat (fileName.c_str());
|
||||
} else {
|
||||
cerr << "Error: the probabilistic graphical model must be " ;
|
||||
cerr << "defined either in a UAI or libDAI file." << endl;
|
||||
std::cerr << "Error: the probabilistic graphical model must be " ;
|
||||
std::cerr << "defined either in a UAI or libDAI file." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIds
|
||||
Horus::VarIds
|
||||
readQueryAndEvidence (
|
||||
FactorGraph& fg,
|
||||
Horus::FactorGraph& fg,
|
||||
int argc,
|
||||
const char* argv[],
|
||||
int start)
|
||||
{
|
||||
VarIds queryIds;
|
||||
Horus::VarIds queryIds;
|
||||
for (int i = start; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
const std::string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
if (Util::isInteger (arg) == false) {
|
||||
cerr << "Error: `" << arg << "' " ;
|
||||
cerr << "is not a variable id." ;
|
||||
cerr << endl;
|
||||
if (Horus::Util::isInteger (arg) == false) {
|
||||
std::cerr << "Error: `" << arg << "' " ;
|
||||
std::cerr << "is not a variable id." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
VarId vid = Util::stringToUnsigned (arg);
|
||||
VarNode* queryVar = fg.getVarNode (vid);
|
||||
Horus::VarId vid = Horus::Util::stringToUnsigned (arg);
|
||||
Horus::VarNode* queryVar = fg.getVarNode (vid);
|
||||
if (queryVar == false) {
|
||||
cerr << "Error: unknow variable with id " ;
|
||||
cerr << "`" << vid << "'." << endl;
|
||||
std::cerr << "Error: unknow variable with id " ;
|
||||
std::cerr << "`" << vid << "'." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
queryIds.push_back (vid);
|
||||
} else {
|
||||
size_t pos = arg.find ('=');
|
||||
string leftArg = arg.substr (0, pos);
|
||||
string rightArg = arg.substr (pos + 1);
|
||||
std::string leftArg = arg.substr (0, pos);
|
||||
std::string rightArg = arg.substr (pos + 1);
|
||||
if (leftArg.empty()) {
|
||||
cerr << "Error: missing left argument." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: missing left argument." << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (Util::isInteger (leftArg) == false) {
|
||||
cerr << "Error: `" << leftArg << "' " ;
|
||||
cerr << "is not a variable id." << endl ;
|
||||
if (Horus::Util::isInteger (leftArg) == false) {
|
||||
std::cerr << "Error: `" << leftArg << "' " ;
|
||||
std::cerr << "is not a variable id." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
VarId vid = Util::stringToUnsigned (leftArg);
|
||||
VarNode* observedVar = fg.getVarNode (vid);
|
||||
Horus::VarId vid = Horus::Util::stringToUnsigned (leftArg);
|
||||
Horus::VarNode* observedVar = fg.getVarNode (vid);
|
||||
if (observedVar == false) {
|
||||
cerr << "Error: unknow variable with id " ;
|
||||
cerr << "`" << vid << "'." << endl;
|
||||
std::cerr << "Error: unknow variable with id " ;
|
||||
std::cerr << "`" << vid << "'." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (rightArg.empty()) {
|
||||
cerr << "Error: missing right argument." << endl;
|
||||
cerr << USAGE << endl;
|
||||
std::cerr << "Error: missing right argument." << std::endl;
|
||||
std::cerr << usage << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
if (Util::isInteger (rightArg) == false) {
|
||||
cerr << "Error: `" << rightArg << "' " ;
|
||||
cerr << "is not a state index." << endl ;
|
||||
if (Horus::Util::isInteger (rightArg) == false) {
|
||||
std::cerr << "Error: `" << rightArg << "' " ;
|
||||
std::cerr << "is not a state index." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
unsigned stateIdx = Util::stringToUnsigned (rightArg);
|
||||
unsigned stateIdx = Horus::Util::stringToUnsigned (rightArg);
|
||||
if (observedVar->isValidState (stateIdx) == false) {
|
||||
cerr << "Error: `" << stateIdx << "' " ;
|
||||
cerr << "is not a valid state index for variable with id " ;
|
||||
cerr << "`" << vid << "'." << endl;
|
||||
std::cerr << "Error: `" << stateIdx << "' " ;
|
||||
std::cerr << "is not a valid state index for variable with id " ;
|
||||
std::cerr << "`" << vid << "'." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
observedVar->setEvidence (stateIdx);
|
||||
@ -174,25 +184,27 @@ readQueryAndEvidence (
|
||||
|
||||
|
||||
void
|
||||
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
runSolver (
|
||||
const Horus::FactorGraph& fg,
|
||||
const Horus::VarIds& queryIds)
|
||||
{
|
||||
GroundSolver* solver = 0;
|
||||
switch (Globals::groundSolver) {
|
||||
case GroundSolverType::VE:
|
||||
solver = new VarElim (fg);
|
||||
Horus::GroundSolver* solver = 0;
|
||||
switch (Horus::Globals::groundSolver) {
|
||||
case Horus::GroundSolverType::veSolver:
|
||||
solver = new Horus::VarElim (fg);
|
||||
break;
|
||||
case GroundSolverType::BP:
|
||||
solver = new BeliefProp (fg);
|
||||
case Horus::GroundSolverType::bpSolver:
|
||||
solver = new Horus::BeliefProp (fg);
|
||||
break;
|
||||
case GroundSolverType::CBP:
|
||||
solver = new CountingBp (fg);
|
||||
case Horus::GroundSolverType::CbpSolver:
|
||||
solver = new Horus::CountingBp (fg);
|
||||
break;
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
if (Horus::Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if (queryIds.empty()) {
|
||||
solver->printAllPosterioris();
|
||||
@ -202,3 +214,5 @@ runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||
delete solver;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@ -20,31 +21,35 @@
|
||||
#include "BayesBall.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
namespace {
|
||||
|
||||
Parfactor* readParfactor (YAP_Term);
|
||||
|
||||
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||
ObservedFormulas* readLiftedEvidence (YAP_Term);
|
||||
|
||||
vector<unsigned> readUnsignedList (YAP_Term list);
|
||||
std::vector<unsigned> readUnsignedList (YAP_Term);
|
||||
|
||||
Params readParameters (YAP_Term);
|
||||
|
||||
YAP_Term fillAnswersPrologList (vector<Params>& results);
|
||||
YAP_Term fillSolutionList (const std::vector<Params>&);
|
||||
|
||||
}
|
||||
|
||||
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||
|
||||
|
||||
|
||||
int
|
||||
createLiftedNetwork (void)
|
||||
createLiftedNetwork()
|
||||
{
|
||||
Parfactors parfactors;
|
||||
YAP_Term parfactorList = YAP_ARG1;
|
||||
while (parfactorList != YAP_TermNil()) {
|
||||
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
|
||||
parfactors.push_back (readParfactor (pfTerm));
|
||||
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||
}
|
||||
|
||||
// LiftedUtils::printSymbolDictionary();
|
||||
@ -52,7 +57,7 @@ createLiftedNetwork (void)
|
||||
Util::printHeader ("INITIAL PARFACTORS");
|
||||
for (size_t i = 0; i < parfactors.size(); i++) {
|
||||
parfactors[i]->print();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,21 +69,20 @@ createLiftedNetwork (void)
|
||||
}
|
||||
|
||||
// read evidence
|
||||
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
|
||||
ObservedFormulas* obsFormulas = readLiftedEvidence (YAP_ARG2);
|
||||
|
||||
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
|
||||
LiftedNetwork* network = new LiftedNetwork (pfList, obsFormulas);
|
||||
|
||||
YAP_Int p = (YAP_Int) (net);
|
||||
YAP_Int p = (YAP_Int) (network);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
createGroundNetwork (void)
|
||||
createGroundNetwork()
|
||||
{
|
||||
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
std::string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
if (factorsType == "bayes") {
|
||||
fg->setFactorsAsBayesian();
|
||||
@ -121,9 +125,9 @@ createGroundNetwork (void)
|
||||
fg->print();
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "factor graph contains " ;
|
||||
cout << fg->nrVarNodes() << " variables and " ;
|
||||
cout << fg->nrFacNodes() << " factors " << endl;
|
||||
std::cout << "factor graph contains " ;
|
||||
std::cout << fg->nrVarNodes() << " variables and " ;
|
||||
std::cout << fg->nrFacNodes() << " factors " << std::endl;
|
||||
}
|
||||
YAP_Int p = (YAP_Int) (fg);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
|
||||
@ -132,45 +136,46 @@ createGroundNetwork (void)
|
||||
|
||||
|
||||
int
|
||||
runLiftedSolver (void)
|
||||
runLiftedSolver()
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
ParfactorList pfListCopy (*network->first);
|
||||
LiftedOperations::absorveEvidence (pfListCopy, *network->second);
|
||||
ParfactorList copy (*network->first);
|
||||
LiftedOperations::absorveEvidence (copy, *network->second);
|
||||
|
||||
LiftedSolver* solver = 0;
|
||||
switch (Globals::liftedSolver) {
|
||||
case LiftedSolverType::LVE: solver = new LiftedVe (pfListCopy); break;
|
||||
case LiftedSolverType::LBP: solver = new LiftedBp (pfListCopy); break;
|
||||
case LiftedSolverType::LKC: solver = new LiftedKc (pfListCopy); break;
|
||||
case LiftedSolverType::lveSolver: solver = new LiftedVe (copy); break;
|
||||
case LiftedSolverType::lbpSolver: solver = new LiftedBp (copy); break;
|
||||
case LiftedSolverType::lkcSolver: solver = new LiftedKc (copy); break;
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<Params> results;
|
||||
std::vector<Params> results;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
Grounds queryVars;
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
queryVars.push_back (Ground (LiftedUtils::getSymbol (name)));
|
||||
} else {
|
||||
assert (YAP_IsApplTerm (ground));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||
std::string name ((char*) (YAP_AtomName (
|
||||
YAP_NameOfFunctor (yapFunctor))));
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
Symbols args;
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||
assert (YAP_IsAtomTerm (ti));
|
||||
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
std::string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
args.push_back (LiftedUtils::getSymbol (arg));
|
||||
}
|
||||
queryVars.push_back (Ground (functor, args));
|
||||
@ -183,17 +188,17 @@ runLiftedSolver (void)
|
||||
|
||||
delete solver;
|
||||
|
||||
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
|
||||
return YAP_Unify (fillSolutionList (results), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
runGroundSolver (void)
|
||||
runGroundSolver()
|
||||
{
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
|
||||
vector<VarIds> tasks;
|
||||
std::vector<VarIds> tasks;
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
|
||||
@ -213,17 +218,17 @@ runGroundSolver (void)
|
||||
GroundSolver* solver = 0;
|
||||
CountingBp::setFindIdenticalFactorsFlag (false);
|
||||
switch (Globals::groundSolver) {
|
||||
case GroundSolverType::VE: solver = new VarElim (*mfg); break;
|
||||
case GroundSolverType::BP: solver = new BeliefProp (*mfg); break;
|
||||
case GroundSolverType::CBP: solver = new CountingBp (*mfg); break;
|
||||
case GroundSolverType::veSolver: solver = new VarElim (*mfg); break;
|
||||
case GroundSolverType::bpSolver: solver = new BeliefProp (*mfg); break;
|
||||
case GroundSolverType::CbpSolver: solver = new CountingBp (*mfg); break;
|
||||
}
|
||||
|
||||
if (Globals::verbosity > 0) {
|
||||
solver->printSolverFlags();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
vector<Params> results;
|
||||
std::vector<Params> results;
|
||||
results.reserve (tasks.size());
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
results.push_back (solver->solveQuery (tasks[i]));
|
||||
@ -234,19 +239,19 @@ runGroundSolver (void)
|
||||
delete mfg;
|
||||
}
|
||||
|
||||
return YAP_Unify (fillAnswersPrologList (results), YAP_ARG3);
|
||||
return YAP_Unify (fillSolutionList (results), YAP_ARG3);
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
setParfactorsParams (void)
|
||||
setParfactorsParams()
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
ParfactorList* pfList = network->first;
|
||||
YAP_Term distIdsList = YAP_ARG2;
|
||||
YAP_Term paramsList = YAP_ARG3;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
std::unordered_map<unsigned, Params> paramsMap;
|
||||
while (distIdsList != YAP_TermNil()) {
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||
YAP_HeadOfTerm (distIdsList));
|
||||
@ -267,12 +272,12 @@ setParfactorsParams (void)
|
||||
|
||||
|
||||
int
|
||||
setFactorsParams (void)
|
||||
setFactorsParams()
|
||||
{
|
||||
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term distIdsList = YAP_ARG2;
|
||||
YAP_Term paramsList = YAP_ARG3;
|
||||
unordered_map<unsigned, Params> paramsMap;
|
||||
std::unordered_map<unsigned, Params> paramsMap;
|
||||
while (distIdsList != YAP_TermNil()) {
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (
|
||||
YAP_HeadOfTerm (distIdsList));
|
||||
@ -293,10 +298,10 @@ setFactorsParams (void)
|
||||
|
||||
|
||||
int
|
||||
setVarsInformation (void)
|
||||
setVarsInformation()
|
||||
{
|
||||
Var::clearVarsInfo();
|
||||
vector<string> labels;
|
||||
std::vector<std::string> labels;
|
||||
YAP_Term labelsL = YAP_ARG1;
|
||||
while (labelsL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
|
||||
@ -323,20 +328,20 @@ setVarsInformation (void)
|
||||
|
||||
|
||||
int
|
||||
setHorusFlag (void)
|
||||
setHorusFlag()
|
||||
{
|
||||
string option ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
string value;
|
||||
std::string option ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
std::string value;
|
||||
if (option == "verbosity") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << (int) YAP_IntOfTerm (YAP_ARG2);
|
||||
ss >> value;
|
||||
} else if (option == "bp_accuracy") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << (float) YAP_FloatOfTerm (YAP_ARG2);
|
||||
ss >> value;
|
||||
} else if (option == "bp_max_iter") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << (int) YAP_IntOfTerm (YAP_ARG2);
|
||||
ss >> value;
|
||||
} else {
|
||||
@ -348,7 +353,7 @@ setHorusFlag (void)
|
||||
|
||||
|
||||
int
|
||||
freeGroundNetwork (void)
|
||||
freeGroundNetwork()
|
||||
{
|
||||
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||
return TRUE;
|
||||
@ -357,7 +362,7 @@ freeGroundNetwork (void)
|
||||
|
||||
|
||||
int
|
||||
freeLiftedNetwork (void)
|
||||
freeLiftedNetwork()
|
||||
{
|
||||
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||
delete network->first;
|
||||
@ -368,6 +373,8 @@ freeLiftedNetwork (void)
|
||||
|
||||
|
||||
|
||||
namespace {
|
||||
|
||||
Parfactor*
|
||||
readParfactor (YAP_Term pfTerm)
|
||||
{
|
||||
@ -386,23 +393,24 @@ readParfactor (YAP_Term pfTerm)
|
||||
// read parametric random vars
|
||||
ProbFormulas formulas;
|
||||
unsigned count = 0;
|
||||
unordered_map<YAP_Term, LogVar> lvMap;
|
||||
std::unordered_map<YAP_Term, LogVar> lvMap;
|
||||
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||
while (pvList != YAP_TermNil()) {
|
||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||
} else {
|
||||
LogVars logVars;
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||
std::string name ((char*) YAP_AtomName (
|
||||
YAP_NameOfFunctor (yapFunctor)));
|
||||
Symbol functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
std::unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||
if (it != lvMap.end()) {
|
||||
logVars.push_back (it->second);
|
||||
} else {
|
||||
@ -418,7 +426,7 @@ readParfactor (YAP_Term pfTerm)
|
||||
}
|
||||
|
||||
// read the parameters
|
||||
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||
Params params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||
|
||||
// read the constraint
|
||||
Tuples tuples;
|
||||
@ -434,10 +442,11 @@ readParfactor (YAP_Term pfTerm)
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||
if (YAP_IsAtomTerm (ti) == false) {
|
||||
cerr << "Error: the constraint contains free variables." << endl;
|
||||
std::cerr << "Error: the constraint contains free variables." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||
}
|
||||
tuples.push_back (tuple);
|
||||
@ -449,55 +458,56 @@ readParfactor (YAP_Term pfTerm)
|
||||
|
||||
|
||||
|
||||
void
|
||||
readLiftedEvidence (
|
||||
YAP_Term observedList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
ObservedFormulas*
|
||||
readLiftedEvidence (YAP_Term observedList)
|
||||
{
|
||||
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||
while (observedList != YAP_TermNil()) {
|
||||
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||
Symbol functor;
|
||||
Symbols args;
|
||||
if (YAP_IsAtomTerm (ground)) {
|
||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
std::string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ground)));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
} else {
|
||||
assert (YAP_IsApplTerm (ground));
|
||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (ground);
|
||||
string name ((char*) (YAP_AtomName (YAP_NameOfFunctor (yapFunctor))));
|
||||
std::string name ((char*) (YAP_AtomName (
|
||||
YAP_NameOfFunctor (yapFunctor))));
|
||||
functor = LiftedUtils::getSymbol (name);
|
||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||
for (unsigned i = 1; i <= arity; i++) {
|
||||
YAP_Term ti = YAP_ArgOfTerm (i, ground);
|
||||
assert (YAP_IsAtomTerm (ti));
|
||||
string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
std::string arg ((char *) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||
args.push_back (LiftedUtils::getSymbol (arg));
|
||||
}
|
||||
}
|
||||
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||
bool found = false;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
if (obsFormulas[i].functor() == functor &&
|
||||
obsFormulas[i].arity() == args.size() &&
|
||||
obsFormulas[i].evidence() == evidence) {
|
||||
obsFormulas[i].addTuple (args);
|
||||
for (size_t i = 0; i < obsFormulas->size(); i++) {
|
||||
if ((*obsFormulas)[i].functor() == functor &&
|
||||
(*obsFormulas)[i].arity() == args.size() &&
|
||||
(*obsFormulas)[i].evidence() == evidence) {
|
||||
(*obsFormulas)[i].addTuple (args);
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||
obsFormulas->push_back (ObservedFormula (functor, evidence, args));
|
||||
}
|
||||
observedList = YAP_TailOfTerm (observedList);
|
||||
}
|
||||
return obsFormulas;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<unsigned>
|
||||
std::vector<unsigned>
|
||||
readUnsignedList (YAP_Term list)
|
||||
{
|
||||
vector<unsigned> vec;
|
||||
std::vector<unsigned> vec;
|
||||
while (list != YAP_TermNil()) {
|
||||
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
|
||||
list = YAP_TailOfTerm (list);
|
||||
@ -513,7 +523,12 @@ readParameters (YAP_Term paramL)
|
||||
Params params;
|
||||
assert (YAP_IsPairTerm (paramL));
|
||||
while (paramL != YAP_TermNil()) {
|
||||
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||
YAP_Term hd = YAP_HeadOfTerm (paramL);
|
||||
if (YAP_IsFloatTerm (hd)) {
|
||||
params.push_back ((double) YAP_FloatOfTerm (hd));
|
||||
} else {
|
||||
params.push_back ((double) YAP_IntOfTerm (hd));
|
||||
}
|
||||
paramL = YAP_TailOfTerm (paramL);
|
||||
}
|
||||
if (Globals::logDomain) {
|
||||
@ -525,17 +540,17 @@ readParameters (YAP_Term paramL)
|
||||
|
||||
|
||||
YAP_Term
|
||||
fillAnswersPrologList (vector<Params>& results)
|
||||
fillSolutionList (const std::vector<Params>& results)
|
||||
{
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (size_t i = results.size(); i-- > 0; ) {
|
||||
const Params& beliefs = results[i];
|
||||
YAP_Term queryBeliefsL = YAP_TermNil();
|
||||
for (size_t j = beliefs.size(); j-- > 0; ) {
|
||||
YAP_Int sl1 = YAP_InitSlot (list);
|
||||
YAP_Int sl = YAP_InitSlot (list);
|
||||
YAP_Term belief = YAP_MkFloatTerm (beliefs[j]);
|
||||
queryBeliefsL = YAP_MkPairTerm (belief, queryBeliefsL);
|
||||
list = YAP_GetFromSlot (sl1);
|
||||
list = YAP_GetFromSlot (sl);
|
||||
YAP_RecoverSlots (1);
|
||||
}
|
||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||
@ -543,10 +558,12 @@ fillAnswersPrologList (vector<Params>& results)
|
||||
return list;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
extern "C" void
|
||||
init_predicates (void)
|
||||
init_predicates()
|
||||
{
|
||||
YAP_UserCPredicate ("cpp_create_lifted_network",
|
||||
createLiftedNetwork, 3);
|
||||
@ -579,3 +596,5 @@ init_predicates (void)
|
||||
freeGroundNetwork, 1);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
32
packages/CLPBN/horus/Indexer.cpp
Normal file
32
packages/CLPBN/horus/Indexer.cpp
Normal file
@ -0,0 +1,32 @@
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Indexer& indexer)
|
||||
{
|
||||
os << "(" ;
|
||||
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||
os << ") " ;
|
||||
os << indexer.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream &os, const MapIndexer& indexer)
|
||||
{
|
||||
os << "(" ;
|
||||
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||
os << ") " ;
|
||||
os << indexer.indices_;
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
@ -1,262 +1,343 @@
|
||||
#ifndef HORUS_INDEXER_H
|
||||
#define HORUS_INDEXER_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
class Indexer
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class Indexer {
|
||||
public:
|
||||
Indexer (const Ranges& ranges, bool calcOffsets = true)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
size_(Util::sizeExpected (ranges))
|
||||
{
|
||||
if (calcOffsets) {
|
||||
calculateOffsets();
|
||||
}
|
||||
}
|
||||
Indexer (const Ranges& ranges, bool calcOffsets = true);
|
||||
|
||||
void increment (void)
|
||||
{
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
indices_[i] ++;
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
}
|
||||
}
|
||||
index_ ++;
|
||||
}
|
||||
void increment();
|
||||
|
||||
void incrementDimension (size_t dim)
|
||||
{
|
||||
assert (dim < ranges_.size());
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
assert (indices_[dim] < ranges_[dim]);
|
||||
indices_[dim] ++;
|
||||
index_ += offsets_[dim];
|
||||
}
|
||||
void incrementDimension (size_t dim);
|
||||
|
||||
void incrementExceptDimension (size_t dim)
|
||||
{
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
if (i != dim) {
|
||||
indices_[i] ++;
|
||||
index_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
index_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
index_ = size_;
|
||||
}
|
||||
void incrementExceptDimension (size_t dim);
|
||||
|
||||
Indexer& operator++ (void)
|
||||
{
|
||||
increment();
|
||||
return *this;
|
||||
}
|
||||
Indexer& operator++();
|
||||
|
||||
operator size_t (void) const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
operator size_t() const;
|
||||
|
||||
unsigned operator[] (size_t dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
unsigned operator[] (size_t dim) const;
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return index_ < size_;
|
||||
}
|
||||
bool valid() const;
|
||||
|
||||
void reset (void)
|
||||
{
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
index_ = 0;
|
||||
}
|
||||
void reset();
|
||||
|
||||
void resetDimension (size_t dim)
|
||||
{
|
||||
indices_[dim] = 0;
|
||||
index_ -= offsets_[dim] * ranges_[dim];
|
||||
}
|
||||
void resetDimension (size_t dim);
|
||||
|
||||
size_t size (void) const
|
||||
{
|
||||
return size_ ;
|
||||
}
|
||||
size_t size() const;
|
||||
|
||||
private:
|
||||
void calculateOffsets();
|
||||
|
||||
friend std::ostream& operator<< (std::ostream&, const Indexer&);
|
||||
|
||||
private:
|
||||
void calculateOffsets (void)
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
}
|
||||
|
||||
size_t index_;
|
||||
Ranges indices_;
|
||||
const Ranges& ranges_;
|
||||
size_t size_;
|
||||
vector<size_t> offsets_;
|
||||
size_t index_;
|
||||
Ranges indices_;
|
||||
const Ranges& ranges_;
|
||||
size_t size_;
|
||||
std::vector<size_t> offsets_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (Indexer);
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline std::ostream&
|
||||
operator<< (std::ostream& os, const Indexer& indexer)
|
||||
inline
|
||||
Indexer::Indexer (const Ranges& ranges, bool calcOffsets)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
size_(Util::sizeExpected (ranges))
|
||||
{
|
||||
os << "(" ;
|
||||
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||
os << ") " ;
|
||||
os << indexer.indices_;
|
||||
return os;
|
||||
if (calcOffsets) {
|
||||
calculateOffsets();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
class MapIndexer
|
||||
inline void
|
||||
Indexer::increment()
|
||||
{
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
indices_[i] ++;
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
break;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
}
|
||||
}
|
||||
index_ ++;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::incrementDimension (size_t dim)
|
||||
{
|
||||
assert (dim < ranges_.size());
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
assert (indices_[dim] < ranges_[dim]);
|
||||
indices_[dim] ++;
|
||||
index_ += offsets_[dim];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::incrementExceptDimension (size_t dim)
|
||||
{
|
||||
assert (ranges_.size() == offsets_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
if (i != dim) {
|
||||
indices_[i] ++;
|
||||
index_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
index_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
index_ = size_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline Indexer&
|
||||
Indexer::operator++()
|
||||
{
|
||||
increment();
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline
|
||||
Indexer::operator size_t() const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
Indexer::operator[] (size_t dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
Indexer::valid() const
|
||||
{
|
||||
return index_ < size_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::reset()
|
||||
{
|
||||
index_ = 0;
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::resetDimension (size_t dim)
|
||||
{
|
||||
indices_[dim] = 0;
|
||||
index_ -= offsets_[dim] * ranges_[dim];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline size_t
|
||||
Indexer::size() const
|
||||
{
|
||||
return size_ ;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Indexer::calculateOffsets()
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
class MapIndexer {
|
||||
public:
|
||||
MapIndexer (const Ranges& ranges, const vector<bool>& mask)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges.size(), 0);
|
||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||
if (mask[i]) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
assert (ranges.size() == mask.size());
|
||||
}
|
||||
MapIndexer (const Ranges& ranges, const std::vector<bool>& mask);
|
||||
|
||||
MapIndexer (const Ranges& ranges, size_t dim)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges.size(), 0);
|
||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||
if (i != dim) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
MapIndexer (const Ranges& ranges, size_t dim);
|
||||
|
||||
template <typename T>
|
||||
MapIndexer (
|
||||
const vector<T>& allArgs,
|
||||
const Ranges& allRanges,
|
||||
const vector<T>& wantedArgs,
|
||||
const Ranges& wantedRanges)
|
||||
: index_(0), indices_(allArgs.size(), 0), ranges_(allRanges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
vector<size_t> offsets (wantedRanges.size());
|
||||
for (size_t i = wantedRanges.size(); i-- > 0; ) {
|
||||
offsets[i] = prod;
|
||||
prod *= wantedRanges[i];
|
||||
}
|
||||
offsets_.reserve (allArgs.size());
|
||||
for (size_t i = 0; i < allArgs.size(); i++) {
|
||||
size_t idx = Util::indexOf (wantedArgs, allArgs[i]);
|
||||
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
MapIndexer (
|
||||
const std::vector<T>& allArgs,
|
||||
const Ranges& allRanges,
|
||||
const std::vector<T>& wantedArgs,
|
||||
const Ranges& wantedRanges);
|
||||
|
||||
MapIndexer& operator++ (void)
|
||||
{
|
||||
assert (valid_);
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
indices_[i] ++;
|
||||
index_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return *this;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
index_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
MapIndexer& operator++();
|
||||
|
||||
operator size_t (void) const
|
||||
{
|
||||
assert (valid());
|
||||
return index_;
|
||||
}
|
||||
operator size_t() const;
|
||||
|
||||
unsigned operator[] (size_t dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
unsigned operator[] (size_t dim) const;
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return valid_;
|
||||
}
|
||||
bool valid() const;
|
||||
|
||||
void reset (void)
|
||||
{
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
index_ = 0;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
|
||||
void reset();
|
||||
|
||||
private:
|
||||
size_t index_;
|
||||
Ranges indices_;
|
||||
const Ranges& ranges_;
|
||||
bool valid_;
|
||||
vector<size_t> offsets_;
|
||||
friend std::ostream& operator<< (std::ostream&, const MapIndexer&);
|
||||
|
||||
size_t index_;
|
||||
Ranges indices_;
|
||||
const Ranges& ranges_;
|
||||
bool valid_;
|
||||
std::vector<size_t> offsets_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (MapIndexer);
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline std::ostream&
|
||||
operator<< (std::ostream &os, const MapIndexer& indexer)
|
||||
inline
|
||||
MapIndexer::MapIndexer (
|
||||
const Ranges& ranges,
|
||||
const std::vector<bool>& mask)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
valid_(true)
|
||||
{
|
||||
os << "(" ;
|
||||
os << std::setw (2) << std::setfill('0') << indexer.index_;
|
||||
os << ") " ;
|
||||
os << indexer.indices_;
|
||||
return os;
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges.size(), 0);
|
||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||
if (mask[i]) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
assert (ranges.size() == mask.size());
|
||||
}
|
||||
|
||||
|
||||
#endif // HORUS_INDEXER_H
|
||||
|
||||
inline
|
||||
MapIndexer::MapIndexer (const Ranges& ranges, size_t dim)
|
||||
: index_(0), indices_(ranges.size(), 0), ranges_(ranges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
offsets_.resize (ranges.size(), 0);
|
||||
for (size_t i = ranges.size(); i-- > 0; ) {
|
||||
if (i != dim) {
|
||||
offsets_[i] = prod;
|
||||
prod *= ranges[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> inline
|
||||
MapIndexer::MapIndexer (
|
||||
const std::vector<T>& allArgs,
|
||||
const Ranges& allRanges,
|
||||
const std::vector<T>& wantedArgs,
|
||||
const Ranges& wantedRanges)
|
||||
: index_(0), indices_(allArgs.size(), 0), ranges_(allRanges),
|
||||
valid_(true)
|
||||
{
|
||||
size_t prod = 1;
|
||||
std::vector<size_t> offsets (wantedRanges.size());
|
||||
for (size_t i = wantedRanges.size(); i-- > 0; ) {
|
||||
offsets[i] = prod;
|
||||
prod *= wantedRanges[i];
|
||||
}
|
||||
offsets_.reserve (allArgs.size());
|
||||
for (size_t i = 0; i < allArgs.size(); i++) {
|
||||
size_t idx = Util::indexOf (wantedArgs, allArgs[i]);
|
||||
offsets_.push_back (idx != wantedArgs.size() ? offsets[idx] : 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline MapIndexer&
|
||||
MapIndexer::operator++()
|
||||
{
|
||||
assert (valid_);
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
indices_[i] ++;
|
||||
index_ += offsets_[i];
|
||||
if (indices_[i] != ranges_[i]) {
|
||||
return *this;
|
||||
} else {
|
||||
indices_[i] = 0;
|
||||
index_ -= offsets_[i] * ranges_[i];
|
||||
}
|
||||
}
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline
|
||||
MapIndexer::operator size_t() const
|
||||
{
|
||||
assert (valid());
|
||||
return index_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline unsigned
|
||||
MapIndexer::operator[] (size_t dim) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (dim < ranges_.size());
|
||||
return indices_[dim];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
MapIndexer::valid() const
|
||||
{
|
||||
return valid_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
MapIndexer::reset()
|
||||
{
|
||||
index_ = 0;
|
||||
std::fill (indices_.begin(), indices_.end(), 0);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_INDEXER_H_
|
||||
|
||||
|
@ -1,9 +1,15 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "LiftedBp.h"
|
||||
#include "LiftedOperations.h"
|
||||
#include "WeightedBp.h"
|
||||
#include "FactorGraph.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
LiftedBp::LiftedBp (const ParfactorList& parfactorList)
|
||||
: LiftedSolver (parfactorList)
|
||||
{
|
||||
@ -14,7 +20,7 @@ LiftedBp::LiftedBp (const ParfactorList& parfactorList)
|
||||
|
||||
|
||||
|
||||
LiftedBp::~LiftedBp (void)
|
||||
LiftedBp::~LiftedBp()
|
||||
{
|
||||
delete solver_;
|
||||
delete fg_;
|
||||
@ -27,7 +33,7 @@ LiftedBp::solveQuery (const Grounds& query)
|
||||
{
|
||||
assert (query.empty() == false);
|
||||
Params res;
|
||||
vector<PrvGroup> groups = getQueryGroups (query);
|
||||
std::vector<PrvGroup> groups = getQueryGroups (query);
|
||||
if (query.size() == 1) {
|
||||
res = solver_->getPosterioriOf (groups[0]);
|
||||
} else {
|
||||
@ -58,28 +64,29 @@ LiftedBp::solveQuery (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedBp::printSolverFlags (void) const
|
||||
LiftedBp::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "lifted bp [" ;
|
||||
ss << "bp_msg_schedule=" ;
|
||||
typedef WeightedBp::MsgSchedule MsgSchedule;
|
||||
switch (WeightedBp::msgSchedule()) {
|
||||
case MsgSchedule::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::SEQ_RANDOM: ss << "seq_random"; break;
|
||||
case MsgSchedule::PARALLEL: ss << "parallel"; break;
|
||||
case MsgSchedule::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||
case MsgSchedule::seqFixedSch: ss << "seq_fixed"; break;
|
||||
case MsgSchedule::seqRandomSch: ss << "seq_random"; break;
|
||||
case MsgSchedule::parallelSch: ss << "parallel"; break;
|
||||
case MsgSchedule::maxResidualSch: ss << "max_residual"; break;
|
||||
}
|
||||
ss << ",bp_max_iter=" << WeightedBp::maxIterations();
|
||||
ss << ",bp_accuracy=" << WeightedBp::accuracy();
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedBp::refineParfactors (void)
|
||||
LiftedBp::refineParfactors()
|
||||
{
|
||||
pfList_ = parfactorList;
|
||||
while (iterate() == false);
|
||||
@ -93,7 +100,7 @@ LiftedBp::refineParfactors (void)
|
||||
|
||||
|
||||
bool
|
||||
LiftedBp::iterate (void)
|
||||
LiftedBp::iterate()
|
||||
{
|
||||
ParfactorList::iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
@ -114,10 +121,10 @@ LiftedBp::iterate (void)
|
||||
|
||||
|
||||
|
||||
vector<PrvGroup>
|
||||
std::vector<PrvGroup>
|
||||
LiftedBp::getQueryGroups (const Grounds& query)
|
||||
{
|
||||
vector<PrvGroup> queryGroups;
|
||||
std::vector<PrvGroup> queryGroups;
|
||||
for (unsigned i = 0; i < query.size(); i++) {
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
@ -134,12 +141,12 @@ LiftedBp::getQueryGroups (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedBp::createFactorGraph (void)
|
||||
LiftedBp::createFactorGraph()
|
||||
{
|
||||
fg_ = new FactorGraph();
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
std::vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
VarIds varIds;
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
varIds.push_back (groups[i]);
|
||||
@ -150,10 +157,10 @@ LiftedBp::createFactorGraph (void)
|
||||
|
||||
|
||||
|
||||
vector<vector<unsigned>>
|
||||
LiftedBp::getWeights (void) const
|
||||
std::vector<std::vector<unsigned>>
|
||||
LiftedBp::getWeights() const
|
||||
{
|
||||
vector<vector<unsigned>> weights;
|
||||
std::vector<std::vector<unsigned>> weights;
|
||||
weights.reserve (pfList_.size());
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
for (; it != pfList_.end(); ++it) {
|
||||
@ -196,7 +203,7 @@ LiftedBp::getJointByConditioning (
|
||||
Grounds obsGrounds = {query[0]};
|
||||
for (size_t i = 1; i < query.size(); i++) {
|
||||
Params newBeliefs;
|
||||
vector<ObservedFormula> obsFs;
|
||||
std::vector<ObservedFormula> obsFs;
|
||||
Ranges obsRanges;
|
||||
for (size_t j = 0; j < obsGrounds.size(); j++) {
|
||||
obsFs.push_back (ObservedFormula (
|
||||
@ -231,3 +238,5 @@ LiftedBp::getJointByConditioning (
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,33 +1,38 @@
|
||||
#ifndef HORUS_LIFTEDBP_H
|
||||
#define HORUS_LIFTEDBP_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class FactorGraph;
|
||||
class WeightedBp;
|
||||
|
||||
class LiftedBp : public LiftedSolver
|
||||
{
|
||||
class LiftedBp : public LiftedSolver{
|
||||
public:
|
||||
LiftedBp (const ParfactorList& pfList);
|
||||
LiftedBp (const ParfactorList& pfList);
|
||||
|
||||
~LiftedBp (void);
|
||||
~LiftedBp();
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
void printSolverFlags() const;
|
||||
|
||||
private:
|
||||
void refineParfactors (void);
|
||||
void refineParfactors();
|
||||
|
||||
bool iterate (void);
|
||||
bool iterate();
|
||||
|
||||
vector<PrvGroup> getQueryGroups (const Grounds&);
|
||||
std::vector<PrvGroup> getQueryGroups (const Grounds&);
|
||||
|
||||
void createFactorGraph (void);
|
||||
void createFactorGraph();
|
||||
|
||||
vector<vector<unsigned>> getWeights (void) const;
|
||||
std::vector<std::vector<unsigned>> getWeights() const;
|
||||
|
||||
unsigned rangeOfGround (const Ground&);
|
||||
|
||||
@ -38,8 +43,9 @@ class LiftedBp : public LiftedSolver
|
||||
FactorGraph* fg_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedBp);
|
||||
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDBP_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDBP_H_
|
||||
|
||||
|
@ -1,11 +1,283 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "LiftedKc.h"
|
||||
#include "LiftedWCNF.h"
|
||||
#include "LiftedOperations.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
OrNode::~OrNode (void)
|
||||
namespace Horus {
|
||||
|
||||
enum class CircuitNodeType {
|
||||
orCnt,
|
||||
andCnt,
|
||||
setOrCnt,
|
||||
setAndCnt,
|
||||
incExcCnt,
|
||||
leafCnt,
|
||||
smoothCnt,
|
||||
trueCnt,
|
||||
compilationFailedCnt
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CircuitNode {
|
||||
public:
|
||||
CircuitNode() { }
|
||||
|
||||
virtual ~CircuitNode() { }
|
||||
|
||||
virtual double weight() const = 0;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class OrNode : public CircuitNode {
|
||||
public:
|
||||
OrNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
~OrNode();
|
||||
|
||||
CircuitNode** leftBranch () { return &leftBranch_; }
|
||||
CircuitNode** rightBranch() { return &rightBranch_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class AndNode : public CircuitNode {
|
||||
public:
|
||||
AndNode() : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
|
||||
: CircuitNode(), leftBranch_(leftBranch),
|
||||
rightBranch_(rightBranch) { }
|
||||
|
||||
~AndNode();
|
||||
|
||||
CircuitNode** leftBranch () { return &leftBranch_; }
|
||||
CircuitNode** rightBranch() { return &rightBranch_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetOrNode : public CircuitNode {
|
||||
public:
|
||||
SetOrNode (unsigned nrGroundings)
|
||||
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||
|
||||
~SetOrNode();
|
||||
|
||||
CircuitNode** follow() { return &follow_; }
|
||||
|
||||
static unsigned nrPositives() { return nrPos_; }
|
||||
|
||||
static unsigned nrNegatives() { return nrNeg_; }
|
||||
|
||||
static bool isSet() { return nrPos_ >= 0; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
static int nrPos_;
|
||||
static int nrNeg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetAndNode : public CircuitNode {
|
||||
public:
|
||||
SetAndNode (unsigned nrGroundings)
|
||||
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||
|
||||
~SetAndNode();
|
||||
|
||||
CircuitNode** follow() { return &follow_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class IncExcNode : public CircuitNode {
|
||||
public:
|
||||
IncExcNode()
|
||||
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
|
||||
|
||||
~IncExcNode();
|
||||
|
||||
CircuitNode** plus1Branch() { return &plus1Branch_; }
|
||||
CircuitNode** plus2Branch() { return &plus2Branch_; }
|
||||
CircuitNode** minusBranch() { return &minusBranch_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
CircuitNode* plus1Branch_;
|
||||
CircuitNode* plus2Branch_;
|
||||
CircuitNode* minusBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LeafNode : public CircuitNode {
|
||||
public:
|
||||
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
|
||||
|
||||
~LeafNode();
|
||||
|
||||
const Clause* clause() const { return clause_; }
|
||||
|
||||
Clause* clause() { return clause_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
Clause* clause_;
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SmoothNode : public CircuitNode {
|
||||
public:
|
||||
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
|
||||
|
||||
~SmoothNode();
|
||||
|
||||
const Clauses& clauses() const { return clauses_; }
|
||||
|
||||
Clauses clauses() { return clauses_; }
|
||||
|
||||
double weight() const;
|
||||
|
||||
private:
|
||||
Clauses clauses_;
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class TrueNode : public CircuitNode {
|
||||
public:
|
||||
TrueNode() : CircuitNode() { }
|
||||
|
||||
double weight() const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CompilationFailedNode : public CircuitNode {
|
||||
public:
|
||||
CompilationFailedNode() : CircuitNode() { }
|
||||
|
||||
double weight() const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedCircuit {
|
||||
public:
|
||||
LiftedCircuit (const LiftedWCNF* lwcnf);
|
||||
|
||||
~LiftedCircuit();
|
||||
|
||||
bool isCompilationSucceeded() const;
|
||||
|
||||
double getWeightedModelCount() const;
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
void compile (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
|
||||
LogVars& rootLogVars);
|
||||
|
||||
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
void shatterCountedLogVars (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses,
|
||||
size_t idx1, size_t idx2);
|
||||
|
||||
bool independentClause (Clause& clause, Clauses& otherClauses) const;
|
||||
|
||||
bool independentLiteral (const Literal& lit,
|
||||
const Literals& otherLits) const;
|
||||
|
||||
LitLvTypesSet smoothCircuit (CircuitNode* node);
|
||||
|
||||
void createSmoothNode (const LitLvTypesSet& lids,
|
||||
CircuitNode** prev);
|
||||
|
||||
std::vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
|
||||
|
||||
bool containsTypes (const LogVarTypes& typesA,
|
||||
const LogVarTypes& typesB) const;
|
||||
|
||||
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
|
||||
|
||||
void exportToGraphViz (CircuitNode* node, std::ofstream&);
|
||||
|
||||
void printClauses (CircuitNode* node, std::ofstream&,
|
||||
std::string extraOptions = "");
|
||||
|
||||
std::string escapeNode (const CircuitNode* node) const;
|
||||
|
||||
std::string getExplanationString (CircuitNode* node);
|
||||
|
||||
CircuitNode* root_;
|
||||
const LiftedWCNF* lwcnf_;
|
||||
bool compilationSucceeded_;
|
||||
Clauses backupClauses_;
|
||||
std::unordered_map<CircuitNode*, Clauses> originClausesMap_;
|
||||
std::unordered_map<CircuitNode*, std::string> explanationMap_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedCircuit);
|
||||
};
|
||||
|
||||
|
||||
|
||||
OrNode::~OrNode()
|
||||
{
|
||||
delete leftBranch_;
|
||||
delete rightBranch_;
|
||||
@ -14,7 +286,7 @@ OrNode::~OrNode (void)
|
||||
|
||||
|
||||
double
|
||||
OrNode::weight (void) const
|
||||
OrNode::weight() const
|
||||
{
|
||||
double lw = leftBranch_->weight();
|
||||
double rw = rightBranch_->weight();
|
||||
@ -23,7 +295,7 @@ OrNode::weight (void) const
|
||||
|
||||
|
||||
|
||||
AndNode::~AndNode (void)
|
||||
AndNode::~AndNode()
|
||||
{
|
||||
delete leftBranch_;
|
||||
delete rightBranch_;
|
||||
@ -32,7 +304,7 @@ AndNode::~AndNode (void)
|
||||
|
||||
|
||||
double
|
||||
AndNode::weight (void) const
|
||||
AndNode::weight() const
|
||||
{
|
||||
double lw = leftBranch_->weight();
|
||||
double rw = rightBranch_->weight();
|
||||
@ -46,7 +318,7 @@ int SetOrNode::nrNeg_ = -1;
|
||||
|
||||
|
||||
|
||||
SetOrNode::~SetOrNode (void)
|
||||
SetOrNode::~SetOrNode()
|
||||
{
|
||||
delete follow_;
|
||||
}
|
||||
@ -54,7 +326,7 @@ SetOrNode::~SetOrNode (void)
|
||||
|
||||
|
||||
double
|
||||
SetOrNode::weight (void) const
|
||||
SetOrNode::weight() const
|
||||
{
|
||||
double weightSum = LogAware::addIdenty();
|
||||
for (unsigned i = 0; i < nrGroundings_ + 1; i++) {
|
||||
@ -76,7 +348,7 @@ SetOrNode::weight (void) const
|
||||
|
||||
|
||||
|
||||
SetAndNode::~SetAndNode (void)
|
||||
SetAndNode::~SetAndNode()
|
||||
{
|
||||
delete follow_;
|
||||
}
|
||||
@ -84,14 +356,14 @@ SetAndNode::~SetAndNode (void)
|
||||
|
||||
|
||||
double
|
||||
SetAndNode::weight (void) const
|
||||
SetAndNode::weight() const
|
||||
{
|
||||
return LogAware::pow (follow_->weight(), nrGroundings_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
IncExcNode::~IncExcNode (void)
|
||||
IncExcNode::~IncExcNode()
|
||||
{
|
||||
delete plus1Branch_;
|
||||
delete plus2Branch_;
|
||||
@ -101,7 +373,7 @@ IncExcNode::~IncExcNode (void)
|
||||
|
||||
|
||||
double
|
||||
IncExcNode::weight (void) const
|
||||
IncExcNode::weight() const
|
||||
{
|
||||
double w = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
@ -116,7 +388,7 @@ IncExcNode::weight (void) const
|
||||
|
||||
|
||||
|
||||
LeafNode::~LeafNode (void)
|
||||
LeafNode::~LeafNode()
|
||||
{
|
||||
delete clause_;
|
||||
}
|
||||
@ -124,7 +396,7 @@ LeafNode::~LeafNode (void)
|
||||
|
||||
|
||||
double
|
||||
LeafNode::weight (void) const
|
||||
LeafNode::weight() const
|
||||
{
|
||||
assert (clause_->isUnit());
|
||||
if (clause_->posCountedLogVars().empty() == false
|
||||
@ -161,7 +433,7 @@ LeafNode::weight (void) const
|
||||
|
||||
|
||||
|
||||
SmoothNode::~SmoothNode (void)
|
||||
SmoothNode::~SmoothNode()
|
||||
{
|
||||
Clause::deleteClauses (clauses_);
|
||||
}
|
||||
@ -169,7 +441,7 @@ SmoothNode::~SmoothNode (void)
|
||||
|
||||
|
||||
double
|
||||
SmoothNode::weight (void) const
|
||||
SmoothNode::weight() const
|
||||
{
|
||||
Clauses cs = clauses();
|
||||
double totalWeight = LogAware::multIdenty();
|
||||
@ -204,7 +476,7 @@ SmoothNode::weight (void) const
|
||||
|
||||
|
||||
double
|
||||
TrueNode::weight (void) const
|
||||
TrueNode::weight() const
|
||||
{
|
||||
return LogAware::multIdenty();
|
||||
}
|
||||
@ -212,7 +484,7 @@ TrueNode::weight (void) const
|
||||
|
||||
|
||||
double
|
||||
CompilationFailedNode::weight (void) const
|
||||
CompilationFailedNode::weight() const
|
||||
{
|
||||
// weighted model counting in compilation
|
||||
// failed nodes should give NaN
|
||||
@ -234,21 +506,22 @@ LiftedCircuit::LiftedCircuit (const LiftedWCNF* lwcnf)
|
||||
if (Globals::verbosity > 1) {
|
||||
if (compilationSucceeded_) {
|
||||
double wmc = LogAware::exp (getWeightedModelCount());
|
||||
cout << "Weighted model count = " << wmc << endl << endl;
|
||||
std::cout << "Weighted model count = " << wmc;
|
||||
std::cout << std::endl << std::endl;
|
||||
}
|
||||
cout << "Exporting circuit to graphviz (circuit.dot)..." ;
|
||||
cout << endl << endl;
|
||||
std::cout << "Exporting circuit to graphviz (circuit.dot)..." ;
|
||||
std::cout << std::endl << std::endl;
|
||||
exportToGraphViz ("circuit.dot");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedCircuit::~LiftedCircuit (void)
|
||||
LiftedCircuit::~LiftedCircuit()
|
||||
{
|
||||
delete root_;
|
||||
unordered_map<CircuitNode*, Clauses>::iterator it;
|
||||
it = originClausesMap_.begin();
|
||||
std::unordered_map<CircuitNode*, Clauses>::iterator it
|
||||
= originClausesMap_.begin();
|
||||
while (it != originClausesMap_.end()) {
|
||||
Clause::deleteClauses (it->second);
|
||||
++ it;
|
||||
@ -258,7 +531,7 @@ LiftedCircuit::~LiftedCircuit (void)
|
||||
|
||||
|
||||
bool
|
||||
LiftedCircuit::isCompilationSucceeded (void) const
|
||||
LiftedCircuit::isCompilationSucceeded() const
|
||||
{
|
||||
return compilationSucceeded_;
|
||||
}
|
||||
@ -266,7 +539,7 @@ LiftedCircuit::isCompilationSucceeded (void) const
|
||||
|
||||
|
||||
double
|
||||
LiftedCircuit::getWeightedModelCount (void) const
|
||||
LiftedCircuit::getWeightedModelCount() const
|
||||
{
|
||||
assert (compilationSucceeded_);
|
||||
return root_->weight();
|
||||
@ -277,15 +550,16 @@ LiftedCircuit::getWeightedModelCount (void) const
|
||||
void
|
||||
LiftedCircuit::exportToGraphViz (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
std::ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << "Error: couldn't open file '" << fileName << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return;
|
||||
}
|
||||
out << "digraph {" << endl;
|
||||
out << "ranksep=1" << endl;
|
||||
out << "digraph {" << std::endl;
|
||||
out << "ranksep=1" << std::endl;
|
||||
exportToGraphViz (root_, out);
|
||||
out << "}" << endl;
|
||||
out << "}" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
@ -389,7 +663,7 @@ LiftedCircuit::tryUnitPropagation (
|
||||
AndNode* andNode = new AndNode();
|
||||
if (Globals::verbosity > 1) {
|
||||
originClausesMap_[andNode] = backupClauses_;
|
||||
stringstream explanation;
|
||||
std::stringstream explanation;
|
||||
explanation << " UP on " << clauses[i]->literals()[0];
|
||||
explanationMap_[andNode] = explanation.str();
|
||||
}
|
||||
@ -478,7 +752,7 @@ LiftedCircuit::tryShannonDecomp (
|
||||
OrNode* orNode = new OrNode();
|
||||
if (Globals::verbosity > 1) {
|
||||
originClausesMap_[orNode] = backupClauses_;
|
||||
stringstream explanation;
|
||||
std::stringstream explanation;
|
||||
explanation << " SD on " << literals[j];
|
||||
explanationMap_[orNode] = explanation.str();
|
||||
}
|
||||
@ -558,7 +832,7 @@ LiftedCircuit::tryInclusionExclusion (
|
||||
IncExcNode* ieNode = new IncExcNode();
|
||||
if (Globals::verbosity > 1) {
|
||||
originClausesMap_[ieNode] = backupClauses_;
|
||||
stringstream explanation;
|
||||
std::stringstream explanation;
|
||||
explanation << " IncExc on clause nº " << i + 1;
|
||||
explanationMap_[ieNode] = explanation.str();
|
||||
}
|
||||
@ -635,13 +909,13 @@ LiftedCircuit::tryIndepPartialGroundingAux (
|
||||
}
|
||||
}
|
||||
// verifies if the IPG logical vars appear in the same positions
|
||||
unordered_map<LiteralId, size_t> positions;
|
||||
std::unordered_map<LiteralId, size_t> positions;
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
const Literals& literals = clauses[i]->literals();
|
||||
for (size_t j = 0; j < literals.size(); j++) {
|
||||
size_t idx = literals[j].indexOfLogVar (rootLogVars[i]);
|
||||
assert (idx != literals[j].nrLogVars());
|
||||
unordered_map<LiteralId, size_t>::iterator it;
|
||||
std::unordered_map<LiteralId, size_t>::iterator it;
|
||||
it = positions.find (literals[j].lid());
|
||||
if (it != positions.end()) {
|
||||
if (it->second != idx) {
|
||||
@ -810,7 +1084,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
|
||||
switch (getCircuitNodeType (node)) {
|
||||
|
||||
case CircuitNodeType::OR_NODE: {
|
||||
case CircuitNodeType::orCnt: {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||
@ -823,7 +1097,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::AND_NODE: {
|
||||
case CircuitNodeType::andCnt: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->leftBranch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->rightBranch());
|
||||
@ -832,17 +1106,18 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::SET_OR_NODE: {
|
||||
case CircuitNodeType::setOrCnt: {
|
||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||
propagLits = smoothCircuit (*casted->follow());
|
||||
TinySet<pair<LiteralId,unsigned>> litSet;
|
||||
TinySet<std::pair<LiteralId,unsigned>> litSet;
|
||||
for (size_t i = 0; i < propagLits.size(); i++) {
|
||||
litSet.insert (make_pair (propagLits[i].lid(),
|
||||
litSet.insert (std::make_pair (propagLits[i].lid(),
|
||||
propagLits[i].logVarTypes().size()));
|
||||
}
|
||||
LitLvTypesSet missingLids;
|
||||
for (size_t i = 0; i < litSet.size(); i++) {
|
||||
vector<LogVarTypes> allTypes = getAllPossibleTypes (litSet[i].second);
|
||||
std::vector<LogVarTypes> allTypes
|
||||
= getAllPossibleTypes (litSet[i].second);
|
||||
for (size_t j = 0; j < allTypes.size(); j++) {
|
||||
bool typeFound = false;
|
||||
for (size_t k = 0; k < propagLits.size(); k++) {
|
||||
@ -869,13 +1144,13 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::SET_AND_NODE: {
|
||||
case CircuitNodeType::setAndCnt: {
|
||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||
propagLits = smoothCircuit (*casted->follow());
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::INC_EXC_NODE: {
|
||||
case CircuitNodeType::incExcCnt: {
|
||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||
LitLvTypesSet lids1 = smoothCircuit (*casted->plus1Branch());
|
||||
LitLvTypesSet lids2 = smoothCircuit (*casted->plus2Branch());
|
||||
@ -888,7 +1163,7 @@ LiftedCircuit::smoothCircuit (CircuitNode* node)
|
||||
break;
|
||||
}
|
||||
|
||||
case CircuitNodeType::LEAF_NODE: {
|
||||
case CircuitNodeType::leafCnt: {
|
||||
LeafNode* casted = dynamic_cast<LeafNode*>(node);
|
||||
propagLits.insert (LitLvTypes (
|
||||
casted->clause()->literals()[0].lid(),
|
||||
@ -911,8 +1186,8 @@ LiftedCircuit::createSmoothNode (
|
||||
{
|
||||
if (missingLits.empty() == false) {
|
||||
if (Globals::verbosity > 1) {
|
||||
unordered_map<CircuitNode*, Clauses>::iterator it;
|
||||
it = originClausesMap_.find (*prev);
|
||||
std::unordered_map<CircuitNode*, Clauses>::iterator it
|
||||
= originClausesMap_.find (*prev);
|
||||
if (it != originClausesMap_.end()) {
|
||||
backupClauses_ = it->second;
|
||||
} else {
|
||||
@ -927,9 +1202,9 @@ LiftedCircuit::createSmoothNode (
|
||||
Clause* c = lwcnf_->createClause (lid);
|
||||
for (size_t j = 0; j < types.size(); j++) {
|
||||
LogVar X = c->literals().front().logVars()[j];
|
||||
if (types[j] == LogVarType::POS_LV) {
|
||||
if (types[j] == LogVarType::posLvt) {
|
||||
c->addPosCountedLogVar (X);
|
||||
} else if (types[j] == LogVarType::NEG_LV) {
|
||||
} else if (types[j] == LogVarType::negLvt) {
|
||||
c->addNegCountedLogVar (X);
|
||||
}
|
||||
}
|
||||
@ -947,15 +1222,15 @@ LiftedCircuit::createSmoothNode (
|
||||
|
||||
|
||||
|
||||
vector<LogVarTypes>
|
||||
std::vector<LogVarTypes>
|
||||
LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
|
||||
{
|
||||
vector<LogVarTypes> res;
|
||||
std::vector<LogVarTypes> res;
|
||||
if (nrLogVars == 0) {
|
||||
// do nothing
|
||||
} else if (nrLogVars == 1) {
|
||||
res.push_back ({ LogVarType::POS_LV });
|
||||
res.push_back ({ LogVarType::NEG_LV });
|
||||
res.push_back ({ LogVarType::posLvt });
|
||||
res.push_back ({ LogVarType::negLvt });
|
||||
} else {
|
||||
Ranges ranges (nrLogVars, 2);
|
||||
Indexer indexer (ranges);
|
||||
@ -963,9 +1238,9 @@ LiftedCircuit::getAllPossibleTypes (unsigned nrLogVars) const
|
||||
LogVarTypes types;
|
||||
for (size_t i = 0; i < nrLogVars; i++) {
|
||||
if (indexer[i] == 0) {
|
||||
types.push_back (LogVarType::POS_LV);
|
||||
types.push_back (LogVarType::posLvt);
|
||||
} else {
|
||||
types.push_back (LogVarType::NEG_LV);
|
||||
types.push_back (LogVarType::negLvt);
|
||||
}
|
||||
}
|
||||
res.push_back (types);
|
||||
@ -983,13 +1258,13 @@ LiftedCircuit::containsTypes (
|
||||
const LogVarTypes& typesB) const
|
||||
{
|
||||
for (size_t i = 0; i < typesA.size(); i++) {
|
||||
if (typesA[i] == LogVarType::FULL_LV) {
|
||||
if (typesA[i] == LogVarType::fullLvt) {
|
||||
|
||||
} else if (typesA[i] == LogVarType::POS_LV
|
||||
&& typesB[i] == LogVarType::POS_LV) {
|
||||
} else if (typesA[i] == LogVarType::posLvt
|
||||
&& typesB[i] == LogVarType::posLvt) {
|
||||
|
||||
} else if (typesA[i] == LogVarType::NEG_LV
|
||||
&& typesB[i] == LogVarType::NEG_LV) {
|
||||
} else if (typesA[i] == LogVarType::negLvt
|
||||
&& typesB[i] == LogVarType::negLvt) {
|
||||
|
||||
} else {
|
||||
return false;
|
||||
@ -1003,25 +1278,25 @@ LiftedCircuit::containsTypes (
|
||||
CircuitNodeType
|
||||
LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
|
||||
{
|
||||
CircuitNodeType type = CircuitNodeType::OR_NODE;
|
||||
CircuitNodeType type = CircuitNodeType::orCnt;
|
||||
if (dynamic_cast<const OrNode*>(node)) {
|
||||
type = CircuitNodeType::OR_NODE;
|
||||
type = CircuitNodeType::orCnt;
|
||||
} else if (dynamic_cast<const AndNode*>(node)) {
|
||||
type = CircuitNodeType::AND_NODE;
|
||||
type = CircuitNodeType::andCnt;
|
||||
} else if (dynamic_cast<const SetOrNode*>(node)) {
|
||||
type = CircuitNodeType::SET_OR_NODE;
|
||||
type = CircuitNodeType::setOrCnt;
|
||||
} else if (dynamic_cast<const SetAndNode*>(node)) {
|
||||
type = CircuitNodeType::SET_AND_NODE;
|
||||
type = CircuitNodeType::setAndCnt;
|
||||
} else if (dynamic_cast<const IncExcNode*>(node)) {
|
||||
type = CircuitNodeType::INC_EXC_NODE;
|
||||
type = CircuitNodeType::incExcCnt;
|
||||
} else if (dynamic_cast<const LeafNode*>(node)) {
|
||||
type = CircuitNodeType::LEAF_NODE;
|
||||
type = CircuitNodeType::leafCnt;
|
||||
} else if (dynamic_cast<const SmoothNode*>(node)) {
|
||||
type = CircuitNodeType::SMOOTH_NODE;
|
||||
type = CircuitNodeType::smoothCnt;
|
||||
} else if (dynamic_cast<const TrueNode*>(node)) {
|
||||
type = CircuitNodeType::TRUE_NODE;
|
||||
type = CircuitNodeType::trueCnt;
|
||||
} else if (dynamic_cast<const CompilationFailedNode*>(node)) {
|
||||
type = CircuitNodeType::COMPILATION_FAILED_NODE;
|
||||
type = CircuitNodeType::compilationFailedCnt;
|
||||
} else {
|
||||
assert (false);
|
||||
}
|
||||
@ -1031,127 +1306,131 @@ LiftedCircuit::getCircuitNodeType (const CircuitNode* node) const
|
||||
|
||||
|
||||
void
|
||||
LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
LiftedCircuit::exportToGraphViz (CircuitNode* node, std::ofstream& os)
|
||||
{
|
||||
assert (node);
|
||||
|
||||
static unsigned nrAuxNodes = 0;
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "n" << nrAuxNodes;
|
||||
string auxNode = ss.str();
|
||||
std::string auxNode = ss.str();
|
||||
nrAuxNodes ++;
|
||||
string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ;
|
||||
std::string opStyle = "shape=circle,width=0.7,margin=\"0.0,0.0\"," ;
|
||||
|
||||
switch (getCircuitNodeType (node)) {
|
||||
|
||||
case OR_NODE: {
|
||||
case CircuitNodeType::orCnt: {
|
||||
OrNode* casted = dynamic_cast<OrNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"∨\"]" << endl;
|
||||
os << auxNode << " [" << opStyle << "label=\"∨\"]" ;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch());
|
||||
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch());
|
||||
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
case AND_NODE: {
|
||||
case CircuitNodeType::andCnt: {
|
||||
AndNode* casted = dynamic_cast<AndNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"∧\"]" << endl;
|
||||
os << auxNode << " [" << opStyle << "label=\"∧\"]" ;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->leftBranch());
|
||||
os << " [label=\" " << (*casted->leftBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->rightBranch()) << endl;
|
||||
os << escapeNode (*casted->rightBranch());
|
||||
os << " [label=\" " << (*casted->rightBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->leftBranch(), os);
|
||||
exportToGraphViz (*casted->rightBranch(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
case SET_OR_NODE: {
|
||||
case CircuitNodeType::setOrCnt: {
|
||||
SetOrNode* casted = dynamic_cast<SetOrNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"∨(X)\"]" << endl;
|
||||
os << auxNode << " [" << opStyle << "label=\"∨(X)\"]" ;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->follow());
|
||||
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->follow(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
case SET_AND_NODE: {
|
||||
case CircuitNodeType::setAndCnt: {
|
||||
SetAndNode* casted = dynamic_cast<SetAndNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" << endl;
|
||||
os << auxNode << " [" << opStyle << "label=\"∧(X)\"]" ;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->follow());
|
||||
os << " [label=\" " << (*casted->follow())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->follow(), os);
|
||||
break;
|
||||
}
|
||||
|
||||
case INC_EXC_NODE: {
|
||||
case CircuitNodeType::incExcCnt: {
|
||||
IncExcNode* casted = dynamic_cast<IncExcNode*>(node);
|
||||
printClauses (casted, os);
|
||||
|
||||
os << auxNode << " [" << opStyle << "label=\"+ - +\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
os << escapeNode (node) << " -> " << auxNode;
|
||||
os << " [label=\"" << getExplanationString (node) << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->plus1Branch());
|
||||
os << " [label=\" " << (*casted->plus1Branch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->minusBranch()) << endl;
|
||||
os << escapeNode (*casted->minusBranch()) << std::endl;
|
||||
os << " [label=\" " << (*casted->minusBranch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
os << auxNode << " -> " ;
|
||||
os << escapeNode (*casted->plus2Branch());
|
||||
os << " [label=\" " << (*casted->plus2Branch())->weight() << "\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
|
||||
exportToGraphViz (*casted->plus1Branch(), os);
|
||||
exportToGraphViz (*casted->plus2Branch(), os);
|
||||
@ -1159,24 +1438,24 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
break;
|
||||
}
|
||||
|
||||
case LEAF_NODE: {
|
||||
case CircuitNodeType::leafCnt: {
|
||||
printClauses (node, os, "style=filled,fillcolor=palegreen,");
|
||||
break;
|
||||
}
|
||||
|
||||
case SMOOTH_NODE: {
|
||||
case CircuitNodeType::smoothCnt: {
|
||||
printClauses (node, os, "style=filled,fillcolor=lightblue,");
|
||||
break;
|
||||
}
|
||||
|
||||
case TRUE_NODE: {
|
||||
case CircuitNodeType::trueCnt: {
|
||||
os << escapeNode (node);
|
||||
os << " [shape=box,label=\"⊤\"]" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
break;
|
||||
}
|
||||
|
||||
case COMPILATION_FAILED_NODE: {
|
||||
case CircuitNodeType::compilationFailedCnt: {
|
||||
printClauses (node, os, "style=filled,fillcolor=salmon,");
|
||||
break;
|
||||
}
|
||||
@ -1188,17 +1467,17 @@ LiftedCircuit::exportToGraphViz (CircuitNode* node, ofstream& os)
|
||||
|
||||
|
||||
|
||||
string
|
||||
std::string
|
||||
LiftedCircuit::escapeNode (const CircuitNode* node) const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "\"" << node << "\"" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
std::string
|
||||
LiftedCircuit::getExplanationString (CircuitNode* node)
|
||||
{
|
||||
return Util::contains (explanationMap_, node)
|
||||
@ -1211,15 +1490,15 @@ LiftedCircuit::getExplanationString (CircuitNode* node)
|
||||
void
|
||||
LiftedCircuit::printClauses (
|
||||
CircuitNode* node,
|
||||
ofstream& os,
|
||||
string extraOptions)
|
||||
std::ofstream& os,
|
||||
std::string extraOptions)
|
||||
{
|
||||
Clauses clauses;
|
||||
if (Util::contains (originClausesMap_, node)) {
|
||||
clauses = originClausesMap_[node];
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::LEAF_NODE) {
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::leafCnt) {
|
||||
clauses = { (dynamic_cast<LeafNode*>(node))->clause() } ;
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::SMOOTH_NODE) {
|
||||
} else if (getCircuitNodeType (node) == CircuitNodeType::smoothCnt) {
|
||||
clauses = (dynamic_cast<SmoothNode*>(node))->clauses();
|
||||
}
|
||||
assert (clauses.empty() == false);
|
||||
@ -1230,15 +1509,7 @@ LiftedCircuit::printClauses (
|
||||
os << *clauses[i];
|
||||
}
|
||||
os << "\"]" ;
|
||||
os << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedKc::~LiftedKc (void)
|
||||
{
|
||||
delete lwcnf_;
|
||||
delete circuit_;
|
||||
os << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -1246,20 +1517,21 @@ LiftedKc::~LiftedKc (void)
|
||||
Params
|
||||
LiftedKc::solveQuery (const Grounds& query)
|
||||
{
|
||||
pfList_ = parfactorList;
|
||||
LiftedOperations::shatterAgainstQuery (pfList_, query);
|
||||
LiftedOperations::runWeakBayesBall (pfList_, query);
|
||||
lwcnf_ = new LiftedWCNF (pfList_);
|
||||
circuit_ = new LiftedCircuit (lwcnf_);
|
||||
if (circuit_->isCompilationSucceeded() == false) {
|
||||
cerr << "Error: the circuit compilation has failed." << endl;
|
||||
ParfactorList pfList (parfactorList);
|
||||
LiftedOperations::shatterAgainstQuery (pfList, query);
|
||||
LiftedOperations::runWeakBayesBall (pfList, query);
|
||||
LiftedWCNF lwcnf (pfList);
|
||||
LiftedCircuit circuit (&lwcnf);
|
||||
if (circuit.isCompilationSucceeded() == false) {
|
||||
std::cerr << "Error: the circuit compilation has failed." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
vector<PrvGroup> groups;
|
||||
std::vector<PrvGroup> groups;
|
||||
Ranges ranges;
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
size_t idx = (*it)->indexOfGround (query[i]);
|
||||
if (idx != (*it)->nrArguments()) {
|
||||
groups.push_back ((*it)->argument (idx).group());
|
||||
@ -1274,18 +1546,18 @@ LiftedKc::solveQuery (const Grounds& query)
|
||||
Indexer indexer (ranges);
|
||||
while (indexer.valid()) {
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
vector<LiteralId> litIds = lwcnf_->prvGroupLiterals (groups[i]);
|
||||
std::vector<LiteralId> litIds = lwcnf.prvGroupLiterals (groups[i]);
|
||||
for (size_t j = 0; j < litIds.size(); j++) {
|
||||
if (indexer[i] == j) {
|
||||
lwcnf_->addWeight (litIds[j], LogAware::one(),
|
||||
lwcnf.addWeight (litIds[j], LogAware::one(),
|
||||
LogAware::one());
|
||||
} else {
|
||||
lwcnf_->addWeight (litIds[j], LogAware::zero(),
|
||||
lwcnf.addWeight (litIds[j], LogAware::zero(),
|
||||
LogAware::one());
|
||||
}
|
||||
}
|
||||
}
|
||||
params.push_back (circuit_->getWeightedModelCount());
|
||||
params.push_back (circuit.getWeightedModelCount());
|
||||
++ indexer;
|
||||
}
|
||||
LogAware::normalize (params);
|
||||
@ -1298,12 +1570,14 @@ LiftedKc::solveQuery (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedKc::printSolverFlags (void) const
|
||||
LiftedKc::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "lifted kc [" ;
|
||||
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,302 +1,26 @@
|
||||
#ifndef HORUS_LIFTEDKC_H
|
||||
#define HORUS_LIFTEDKC_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "LiftedWCNF.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
enum CircuitNodeType {
|
||||
OR_NODE,
|
||||
AND_NODE,
|
||||
SET_OR_NODE,
|
||||
SET_AND_NODE,
|
||||
INC_EXC_NODE,
|
||||
LEAF_NODE,
|
||||
SMOOTH_NODE,
|
||||
TRUE_NODE,
|
||||
COMPILATION_FAILED_NODE
|
||||
};
|
||||
namespace Horus {
|
||||
|
||||
|
||||
|
||||
class CircuitNode
|
||||
{
|
||||
public:
|
||||
CircuitNode (void) { }
|
||||
|
||||
virtual ~CircuitNode (void) { }
|
||||
|
||||
virtual double weight (void) const = 0;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class OrNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
OrNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
~OrNode (void);
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class AndNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
AndNode (void) : CircuitNode(), leftBranch_(0), rightBranch_(0) { }
|
||||
|
||||
AndNode (CircuitNode* leftBranch, CircuitNode* rightBranch)
|
||||
: CircuitNode(), leftBranch_(leftBranch), rightBranch_(rightBranch) { }
|
||||
|
||||
~AndNode (void);
|
||||
|
||||
CircuitNode** leftBranch (void) { return &leftBranch_; }
|
||||
CircuitNode** rightBranch (void) { return &rightBranch_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* leftBranch_;
|
||||
CircuitNode* rightBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetOrNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SetOrNode (unsigned nrGroundings)
|
||||
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||
|
||||
~SetOrNode (void);
|
||||
|
||||
CircuitNode** follow (void) { return &follow_; }
|
||||
|
||||
static unsigned nrPositives (void) { return nrPos_; }
|
||||
|
||||
static unsigned nrNegatives (void) { return nrNeg_; }
|
||||
|
||||
static bool isSet (void) { return nrPos_ >= 0; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
static int nrPos_;
|
||||
static int nrNeg_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SetAndNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SetAndNode (unsigned nrGroundings)
|
||||
: CircuitNode(), follow_(0), nrGroundings_(nrGroundings) { }
|
||||
|
||||
~SetAndNode (void);
|
||||
|
||||
CircuitNode** follow (void) { return &follow_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* follow_;
|
||||
unsigned nrGroundings_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class IncExcNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
IncExcNode (void)
|
||||
: CircuitNode(), plus1Branch_(0), plus2Branch_(0), minusBranch_(0) { }
|
||||
|
||||
~IncExcNode (void);
|
||||
|
||||
CircuitNode** plus1Branch (void) { return &plus1Branch_; }
|
||||
CircuitNode** plus2Branch (void) { return &plus2Branch_; }
|
||||
CircuitNode** minusBranch (void) { return &minusBranch_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
CircuitNode* plus1Branch_;
|
||||
CircuitNode* plus2Branch_;
|
||||
CircuitNode* minusBranch_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LeafNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
LeafNode (Clause* clause, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode(), clause_(clause), lwcnf_(lwcnf) { }
|
||||
|
||||
~LeafNode (void);
|
||||
|
||||
const Clause* clause (void) const { return clause_; }
|
||||
|
||||
Clause* clause (void) { return clause_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
Clause* clause_;
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SmoothNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
SmoothNode (const Clauses& clauses, const LiftedWCNF& lwcnf)
|
||||
: CircuitNode(), clauses_(clauses), lwcnf_(lwcnf) { }
|
||||
|
||||
~SmoothNode (void);
|
||||
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
|
||||
Clauses clauses (void) { return clauses_; }
|
||||
|
||||
double weight (void) const;
|
||||
|
||||
private:
|
||||
Clauses clauses_;
|
||||
const LiftedWCNF& lwcnf_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class TrueNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
TrueNode (void) : CircuitNode() { }
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CompilationFailedNode : public CircuitNode
|
||||
{
|
||||
public:
|
||||
CompilationFailedNode (void) : CircuitNode() { }
|
||||
|
||||
double weight (void) const;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedCircuit
|
||||
{
|
||||
public:
|
||||
LiftedCircuit (const LiftedWCNF* lwcnf);
|
||||
|
||||
~LiftedCircuit (void);
|
||||
|
||||
bool isCompilationSucceeded (void) const;
|
||||
|
||||
double getWeightedModelCount (void) const;
|
||||
|
||||
void exportToGraphViz (const char*);
|
||||
|
||||
private:
|
||||
void compile (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryUnitPropagation (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndependence (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryShannonDecomp (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryInclusionExclusion (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGrounding (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
bool tryIndepPartialGroundingAux (Clauses& clauses, ConstraintTree& ct,
|
||||
LogVars& rootLogVars);
|
||||
|
||||
bool tryAtomCounting (CircuitNode** follow, Clauses& clauses);
|
||||
|
||||
void shatterCountedLogVars (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses);
|
||||
|
||||
bool shatterCountedLogVarsAux (Clauses& clauses, size_t idx1, size_t idx2);
|
||||
|
||||
bool independentClause (Clause& clause, Clauses& otherClauses) const;
|
||||
|
||||
bool independentLiteral (const Literal& lit,
|
||||
const Literals& otherLits) const;
|
||||
|
||||
LitLvTypesSet smoothCircuit (CircuitNode* node);
|
||||
|
||||
void createSmoothNode (const LitLvTypesSet& lids,
|
||||
CircuitNode** prev);
|
||||
|
||||
vector<LogVarTypes> getAllPossibleTypes (unsigned nrLogVars) const;
|
||||
|
||||
bool containsTypes (const LogVarTypes& typesA,
|
||||
const LogVarTypes& typesB) const;
|
||||
|
||||
CircuitNodeType getCircuitNodeType (const CircuitNode* node) const;
|
||||
|
||||
void exportToGraphViz (CircuitNode* node, ofstream&);
|
||||
|
||||
void printClauses (CircuitNode* node, ofstream&,
|
||||
string extraOptions = "");
|
||||
|
||||
string escapeNode (const CircuitNode* node) const;
|
||||
|
||||
string getExplanationString (CircuitNode* node);
|
||||
|
||||
CircuitNode* root_;
|
||||
const LiftedWCNF* lwcnf_;
|
||||
bool compilationSucceeded_;
|
||||
Clauses backupClauses_;
|
||||
unordered_map<CircuitNode*, Clauses> originClausesMap_;
|
||||
unordered_map<CircuitNode*, string> explanationMap_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedCircuit);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedKc : public LiftedSolver
|
||||
{
|
||||
class LiftedKc : public LiftedSolver {
|
||||
public:
|
||||
LiftedKc (const ParfactorList& pfList)
|
||||
: LiftedSolver(pfList) { }
|
||||
|
||||
~LiftedKc (void);
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
void printSolverFlags() const;
|
||||
|
||||
private:
|
||||
LiftedWCNF* lwcnf_;
|
||||
LiftedCircuit* circuit_;
|
||||
ParfactorList pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedKc);
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDKC_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDKC_H_
|
||||
|
||||
|
@ -1,10 +1,22 @@
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <iostream>
|
||||
|
||||
#include "LiftedOperations.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace LiftedOperations {
|
||||
|
||||
namespace {
|
||||
|
||||
Parfactors absorve (ObservedFormula& obsFormula, Parfactor* g);
|
||||
|
||||
}
|
||||
|
||||
void
|
||||
LiftedOperations::shatterAgainstQuery (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
shatterAgainstQuery (ParfactorList& pfList, const Grounds& query)
|
||||
{
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
if (query[i].isAtom()) {
|
||||
@ -35,17 +47,17 @@ LiftedOperations::shatterAgainstQuery (
|
||||
}
|
||||
}
|
||||
if (found == false) {
|
||||
cerr << "Error: could not find a parfactor with ground " ;
|
||||
cerr << "`" << query[i] << "'." << endl;
|
||||
std::cerr << "Error: could not find a parfactor with ground " ;
|
||||
std::cerr << "`" << query[i] << "'." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
pfList.add (newPfs);
|
||||
}
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||
std::cout << "SHATTERED AGAINST THE QUERY" << std::endl;
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
cout << " -> " << query[i] << endl;
|
||||
std::cout << " -> " << query[i] << std::endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
@ -55,12 +67,10 @@ LiftedOperations::shatterAgainstQuery (
|
||||
|
||||
|
||||
void
|
||||
LiftedOperations::runWeakBayesBall (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
runWeakBayesBall (ParfactorList& pfList, const Grounds& query)
|
||||
{
|
||||
queue<PrvGroup> todo; // groups to process
|
||||
set<PrvGroup> done; // processed or in queue
|
||||
std::queue<PrvGroup> todo; // groups to process
|
||||
std::set<PrvGroup> done; // processed or in queue
|
||||
for (size_t i = 0; i < query.size(); i++) {
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
@ -74,14 +84,14 @@ LiftedOperations::runWeakBayesBall (
|
||||
}
|
||||
}
|
||||
|
||||
set<Parfactor*> requiredPfs;
|
||||
std::set<Parfactor*> requiredPfs;
|
||||
while (todo.empty() == false) {
|
||||
PrvGroup group = todo.front();
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
if (Util::contains (requiredPfs, *it) == false &&
|
||||
(*it)->containsGroup (group)) {
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
std::vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (Util::contains (done, groups[i]) == false) {
|
||||
todo.push (groups[i]);
|
||||
@ -116,9 +126,7 @@ LiftedOperations::runWeakBayesBall (
|
||||
|
||||
|
||||
void
|
||||
LiftedOperations::absorveEvidence (
|
||||
ParfactorList& pfList,
|
||||
ObservedFormulas& obsFormulas)
|
||||
absorveEvidence (ParfactorList& pfList, ObservedFormulas& obsFormulas)
|
||||
{
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
Parfactors newPfs;
|
||||
@ -143,9 +151,9 @@ LiftedOperations::absorveEvidence (
|
||||
}
|
||||
if (Globals::verbosity > 2 && obsFormulas.empty() == false) {
|
||||
Util::printAsteriskLine();
|
||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||
std::cout << "AFTER EVIDENCE ABSORVED" << std::endl;
|
||||
for (size_t i = 0; i < obsFormulas.size(); i++) {
|
||||
cout << " -> " << obsFormulas[i] << endl;
|
||||
std::cout << " -> " << obsFormulas[i] << std::endl;
|
||||
}
|
||||
Util::printAsteriskLine();
|
||||
pfList.print();
|
||||
@ -155,9 +163,7 @@ LiftedOperations::absorveEvidence (
|
||||
|
||||
|
||||
Parfactors
|
||||
LiftedOperations::countNormalize (
|
||||
Parfactor* g,
|
||||
const LogVarSet& set)
|
||||
countNormalize (Parfactor* g, const LogVarSet& set)
|
||||
{
|
||||
Parfactors normPfs;
|
||||
if (set.empty()) {
|
||||
@ -174,7 +180,7 @@ LiftedOperations::countNormalize (
|
||||
|
||||
|
||||
Parfactor
|
||||
LiftedOperations::calcGroundMultiplication (Parfactor pf)
|
||||
calcGroundMultiplication (Parfactor pf)
|
||||
{
|
||||
LogVarSet lvs = pf.constr()->logVarSet();
|
||||
lvs -= pf.constr()->singletons();
|
||||
@ -206,10 +212,10 @@ LiftedOperations::calcGroundMultiplication (Parfactor pf)
|
||||
|
||||
|
||||
|
||||
namespace {
|
||||
|
||||
Parfactors
|
||||
LiftedOperations::absorve (
|
||||
ObservedFormula& obsFormula,
|
||||
Parfactor* g)
|
||||
absorve (ObservedFormula& obsFormula, Parfactor* g)
|
||||
{
|
||||
Parfactors absorvedPfs;
|
||||
const ProbFormulas& formulas = g->arguments();
|
||||
@ -269,3 +275,9 @@ LiftedOperations::absorve (
|
||||
return absorvedPfs;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // namespace LiftedOperations
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,29 +1,26 @@
|
||||
#ifndef HORUS_LIFTEDOPERATIONS_H
|
||||
#define HORUS_LIFTEDOPERATIONS_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
|
||||
|
||||
#include "ParfactorList.h"
|
||||
|
||||
class LiftedOperations
|
||||
{
|
||||
public:
|
||||
static void shatterAgainstQuery (
|
||||
ParfactorList& pfList, const Grounds& query);
|
||||
|
||||
static void runWeakBayesBall (
|
||||
ParfactorList& pfList, const Grounds&);
|
||||
namespace Horus {
|
||||
|
||||
static void absorveEvidence (
|
||||
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||
namespace LiftedOperations {
|
||||
|
||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
void shatterAgainstQuery (ParfactorList& pfList, const Grounds& query);
|
||||
|
||||
static Parfactor calcGroundMultiplication (Parfactor pf);
|
||||
void runWeakBayesBall (ParfactorList& pfList, const Grounds& query);
|
||||
|
||||
private:
|
||||
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||
void absorveEvidence (ParfactorList& pfList, ObservedFormulas&);
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedOperations);
|
||||
};
|
||||
Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||
|
||||
#endif // HORUS_LIFTEDOPERATIONS_H
|
||||
Parfactor calcGroundMultiplication (Parfactor pf);
|
||||
|
||||
} // namespace LiftedOperations
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDOPERATIONS_H_
|
||||
|
||||
|
@ -1,14 +1,12 @@
|
||||
#ifndef HORUS_LIFTEDSOLVER_H
|
||||
#define HORUS_LIFTEDSOLVER_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
class LiftedSolver
|
||||
{
|
||||
class LiftedSolver {
|
||||
public:
|
||||
LiftedSolver (const ParfactorList& pfList)
|
||||
: parfactorList(pfList) { }
|
||||
@ -17,7 +15,7 @@ class LiftedSolver
|
||||
|
||||
virtual Params solveQuery (const Grounds& query) = 0;
|
||||
|
||||
virtual void printSolverFlags (void) const = 0;
|
||||
virtual void printSolverFlags() const = 0;
|
||||
|
||||
protected:
|
||||
const ParfactorList& parfactorList;
|
||||
@ -26,5 +24,7 @@ class LiftedSolver
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedSolver);
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDSOLVER_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDSOLVER_H_
|
||||
|
||||
|
@ -1,22 +1,21 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "LiftedUtils.h"
|
||||
#include "ConstraintTree.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace LiftedUtils {
|
||||
|
||||
|
||||
unordered_map<string, unsigned> symbolDict;
|
||||
std::unordered_map<std::string, unsigned> symbolDict;
|
||||
|
||||
|
||||
Symbol
|
||||
getSymbol (const string& symbolName)
|
||||
getSymbol (const std::string& symbolName)
|
||||
{
|
||||
unordered_map<string, unsigned>::iterator it
|
||||
std::unordered_map<std::string, unsigned>::iterator it
|
||||
= symbolDict.find (symbolName);
|
||||
if (it != symbolDict.end()) {
|
||||
return it->second;
|
||||
@ -29,12 +28,12 @@ getSymbol (const string& symbolName)
|
||||
|
||||
|
||||
void
|
||||
printSymbolDictionary (void)
|
||||
printSymbolDictionary()
|
||||
{
|
||||
unordered_map<string, unsigned>::const_iterator it
|
||||
std::unordered_map<std::string, unsigned>::const_iterator it
|
||||
= symbolDict.begin();
|
||||
while (it != symbolDict.end()) {
|
||||
cout << it->first << " -> " << it->second << endl;
|
||||
std::cout << it->first << " -> " << it->second << std::endl;
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
@ -43,9 +42,10 @@ printSymbolDictionary (void)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Symbol& s)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Symbol& s)
|
||||
{
|
||||
unordered_map<string, unsigned>::const_iterator it
|
||||
std::unordered_map<std::string, unsigned>::const_iterator it
|
||||
= LiftedUtils::symbolDict.begin();
|
||||
while (it != LiftedUtils::symbolDict.end() && it->second != s) {
|
||||
++ it;
|
||||
@ -57,9 +57,10 @@ ostream& operator<< (ostream &os, const Symbol& s)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const LogVar& X)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const LogVar& X)
|
||||
{
|
||||
const string labels[] = {
|
||||
const std::string labels[] = {
|
||||
"A", "B", "C", "D", "E", "F",
|
||||
"G", "H", "I", "J", "K", "M" };
|
||||
(X >= 12) ? os << "X_" << X.id_ : os << labels[X];
|
||||
@ -68,7 +69,8 @@ ostream& operator<< (ostream &os, const LogVar& X)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Tuple& t)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Tuple& t)
|
||||
{
|
||||
os << "(" ;
|
||||
for (size_t i = 0; i < t.size(); i++) {
|
||||
@ -80,7 +82,8 @@ ostream& operator<< (ostream &os, const Tuple& t)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Ground& gr)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Ground& gr)
|
||||
{
|
||||
os << gr.functor();
|
||||
os << "(" ;
|
||||
@ -95,12 +98,12 @@ ostream& operator<< (ostream &os, const Ground& gr)
|
||||
|
||||
|
||||
LogVars
|
||||
Substitution::getDiscardedLogVars (void) const
|
||||
Substitution::getDiscardedLogVars() const
|
||||
{
|
||||
LogVars discardedLvs;
|
||||
set<LogVar> doneLvs;
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.begin();
|
||||
std::set<LogVar> doneLvs;
|
||||
std::unordered_map<LogVar, LogVar>::const_iterator it
|
||||
= subs_.begin();
|
||||
while (it != subs_.end()) {
|
||||
if (Util::contains (doneLvs, it->second)) {
|
||||
discardedLvs.push_back (it->first);
|
||||
@ -114,9 +117,10 @@ Substitution::getDiscardedLogVars (void) const
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Substitution& theta)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const Substitution& theta)
|
||||
{
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
std::unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
os << "[" ;
|
||||
it = theta.subs_.begin();
|
||||
while (it != theta.subs_.end()) {
|
||||
@ -128,3 +132,5 @@ ostream& operator<< (ostream &os, const Substitution& theta)
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,164 +1,210 @@
|
||||
#ifndef HORUS_LIFTEDUTILS_H
|
||||
#define HORUS_LIFTEDUTILS_H
|
||||
|
||||
#include <string>
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <ostream>
|
||||
|
||||
#include "TinySet.h"
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
|
||||
class Symbol
|
||||
{
|
||||
class Symbol {
|
||||
public:
|
||||
Symbol (void) : id_(Util::maxUnsigned()) { }
|
||||
Symbol() : id_(Util::maxUnsigned()) { }
|
||||
|
||||
Symbol (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
operator unsigned() const { return id_; }
|
||||
|
||||
bool valid (void) const { return id_ != Util::maxUnsigned(); }
|
||||
bool valid() const { return id_ != Util::maxUnsigned(); }
|
||||
|
||||
static Symbol invalid (void) { return Symbol(); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Symbol& s);
|
||||
static Symbol invalid() { return Symbol(); }
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const Symbol&);
|
||||
|
||||
unsigned id_;
|
||||
};
|
||||
|
||||
|
||||
class LogVar
|
||||
{
|
||||
class LogVar {
|
||||
public:
|
||||
LogVar (void) : id_(Util::maxUnsigned()) { }
|
||||
LogVar() : id_(Util::maxUnsigned()) { }
|
||||
|
||||
LogVar (unsigned id) : id_(id) { }
|
||||
|
||||
operator unsigned (void) const { return id_; }
|
||||
operator unsigned() const { return id_; }
|
||||
|
||||
LogVar& operator++ (void)
|
||||
{
|
||||
assert (valid());
|
||||
id_ ++;
|
||||
return *this;
|
||||
}
|
||||
LogVar& operator++();
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return id_ != Util::maxUnsigned();
|
||||
}
|
||||
|
||||
friend ostream& operator<< (ostream &os, const LogVar& X);
|
||||
bool valid() const;
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const LogVar&);
|
||||
|
||||
unsigned id_;
|
||||
};
|
||||
|
||||
|
||||
namespace std {
|
||||
template <> struct hash<Symbol> {
|
||||
size_t operator() (const Symbol& s) const {
|
||||
return std::hash<unsigned>() (s);
|
||||
}};
|
||||
|
||||
template <> struct hash<LogVar> {
|
||||
size_t operator() (const LogVar& X) const {
|
||||
return std::hash<unsigned>() (X);
|
||||
}};
|
||||
};
|
||||
|
||||
|
||||
typedef vector<Symbol> Symbols;
|
||||
typedef vector<Symbol> Tuple;
|
||||
typedef vector<Tuple> Tuples;
|
||||
typedef vector<LogVar> LogVars;
|
||||
typedef TinySet<Symbol> SymbolSet;
|
||||
typedef TinySet<LogVar> LogVarSet;
|
||||
typedef TinySet<Tuple> TupleSet;
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const Tuple& t);
|
||||
|
||||
|
||||
namespace LiftedUtils {
|
||||
Symbol getSymbol (const string&);
|
||||
void printSymbolDictionary (void);
|
||||
inline LogVar&
|
||||
LogVar::operator++()
|
||||
{
|
||||
assert (valid());
|
||||
id_ ++;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
|
||||
class Ground
|
||||
inline bool
|
||||
LogVar::valid() const
|
||||
{
|
||||
return id_ != Util::maxUnsigned();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
||||
namespace std {
|
||||
|
||||
template <> struct hash<Horus::Symbol> {
|
||||
size_t operator() (const Horus::Symbol& s) const {
|
||||
return std::hash<unsigned>() (s);
|
||||
}};
|
||||
|
||||
template <> struct hash<Horus::LogVar> {
|
||||
size_t operator() (const Horus::LogVar& X) const {
|
||||
return std::hash<unsigned>() (X);
|
||||
}};
|
||||
|
||||
} // namespace std
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
typedef std::vector<Symbol> Symbols;
|
||||
typedef std::vector<Symbol> Tuple;
|
||||
typedef std::vector<Tuple> Tuples;
|
||||
typedef std::vector<LogVar> LogVars;
|
||||
typedef TinySet<Symbol> SymbolSet;
|
||||
typedef TinySet<LogVar> LogVarSet;
|
||||
typedef TinySet<Tuple> TupleSet;
|
||||
|
||||
|
||||
std::ostream& operator<< (std::ostream&, const Tuple&);
|
||||
|
||||
|
||||
namespace LiftedUtils {
|
||||
|
||||
Symbol getSymbol (const std::string&);
|
||||
|
||||
void printSymbolDictionary();
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
class Ground {
|
||||
public:
|
||||
Ground (Symbol f) : functor_(f) { }
|
||||
|
||||
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
||||
Ground (Symbol f, const Symbols& args)
|
||||
: functor_(f), args_(args) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
Symbol functor() const { return functor_; }
|
||||
|
||||
Symbols args (void) const { return args_; }
|
||||
Symbols args() const { return args_; }
|
||||
|
||||
size_t arity (void) const { return args_.size(); }
|
||||
size_t arity() const { return args_.size(); }
|
||||
|
||||
bool isAtom (void) const { return args_.empty(); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Ground& gr);
|
||||
bool isAtom() const { return args_.empty(); }
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const Ground&);
|
||||
|
||||
Symbol functor_;
|
||||
Symbols args_;
|
||||
};
|
||||
|
||||
typedef vector<Ground> Grounds;
|
||||
typedef std::vector<Ground> Grounds;
|
||||
|
||||
|
||||
|
||||
class Substitution
|
||||
{
|
||||
class Substitution {
|
||||
public:
|
||||
void add (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old) == false);
|
||||
subs_.insert (make_pair (X_old, X_new));
|
||||
}
|
||||
void add (LogVar X_old, LogVar X_new);
|
||||
|
||||
void rename (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old));
|
||||
subs_.find (X_old)->second = X_new;
|
||||
}
|
||||
void rename (LogVar X_old, LogVar X_new);
|
||||
|
||||
LogVar newNameFor (LogVar X) const
|
||||
{
|
||||
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.find (X);
|
||||
if (it != subs_.end()) {
|
||||
return subs_.find (X)->second;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
LogVar newNameFor (LogVar X) const;
|
||||
|
||||
bool containsReplacementFor (LogVar X) const
|
||||
{
|
||||
return Util::contains (subs_, X);
|
||||
}
|
||||
bool containsReplacementFor (LogVar X) const;
|
||||
|
||||
size_t nrReplacements (void) const { return subs_.size(); }
|
||||
size_t nrReplacements() const;
|
||||
|
||||
LogVars getDiscardedLogVars (void) const;
|
||||
|
||||
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||
LogVars getDiscardedLogVars() const;
|
||||
|
||||
private:
|
||||
unordered_map<LogVar, LogVar> subs_;
|
||||
friend std::ostream& operator<< (
|
||||
std::ostream&, const Substitution&);
|
||||
|
||||
std::unordered_map<LogVar, LogVar> subs_;
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDUTILS_H
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Substitution::add (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old) == false);
|
||||
subs_.insert (std::make_pair (X_old, X_new));
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Substitution::rename (LogVar X_old, LogVar X_new)
|
||||
{
|
||||
assert (Util::contains (subs_, X_old));
|
||||
subs_.find (X_old)->second = X_new;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline LogVar
|
||||
Substitution::newNameFor (LogVar X) const
|
||||
{
|
||||
std::unordered_map<LogVar, LogVar>::const_iterator it;
|
||||
it = subs_.find (X);
|
||||
if (it != subs_.end()) {
|
||||
return subs_.find (X)->second;
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
Substitution::containsReplacementFor (LogVar X) const
|
||||
{
|
||||
return Util::contains (subs_, X);
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline size_t
|
||||
Substitution::nrReplacements() const
|
||||
{
|
||||
return subs_.size();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDUTILS_H_
|
||||
|
||||
|
@ -1,6 +1,12 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "LiftedVe.h"
|
||||
#include "LiftedOperations.h"
|
||||
@ -8,21 +14,158 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
vector<LiftedOperator*>
|
||||
namespace Horus {
|
||||
|
||||
class LiftedOperator {
|
||||
public:
|
||||
virtual ~LiftedOperator() { }
|
||||
|
||||
virtual double getLogCost() = 0;
|
||||
|
||||
virtual void apply() = 0;
|
||||
|
||||
virtual std::string toString() = 0;
|
||||
|
||||
static std::vector<LiftedOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
|
||||
static std::vector<ParfactorList::iterator> getParfactorsWithGroup (
|
||||
ParfactorList&, PrvGroup group);
|
||||
|
||||
private:
|
||||
DISALLOW_ASSIGN (LiftedOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class ProductOperator : public LiftedOperator {
|
||||
public:
|
||||
ProductOperator (
|
||||
ParfactorList::iterator g1,
|
||||
ParfactorList::iterator g2,
|
||||
ParfactorList& pfList)
|
||||
: g1_(g1), g2_(g2), pfList_(pfList) { }
|
||||
|
||||
double getLogCost();
|
||||
|
||||
void apply();
|
||||
|
||||
static std::vector<ProductOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
std::string toString();
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, Parfactor*);
|
||||
|
||||
ParfactorList::iterator g1_;
|
||||
ParfactorList::iterator g2_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (ProductOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SumOutOperator : public LiftedOperator {
|
||||
public:
|
||||
SumOutOperator (PrvGroup group, ParfactorList& pfList)
|
||||
: group_(group), pfList_(pfList) { }
|
||||
|
||||
double getLogCost();
|
||||
|
||||
void apply();
|
||||
|
||||
static std::vector<SumOutOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
std::string toString();
|
||||
|
||||
private:
|
||||
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
|
||||
|
||||
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
|
||||
|
||||
PrvGroup group_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (SumOutOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CountingOperator : public LiftedOperator {
|
||||
public:
|
||||
CountingOperator (
|
||||
ParfactorList::iterator pfIter,
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
double getLogCost();
|
||||
|
||||
void apply();
|
||||
|
||||
static std::vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
std::string toString();
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, LogVar);
|
||||
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (CountingOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class GroundOperator : public LiftedOperator {
|
||||
public:
|
||||
GroundOperator (
|
||||
PrvGroup group,
|
||||
unsigned lvIndex,
|
||||
ParfactorList& pfList)
|
||||
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
|
||||
|
||||
double getLogCost();
|
||||
|
||||
void apply();
|
||||
|
||||
static std::vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
std::string toString();
|
||||
|
||||
private:
|
||||
std::vector<std::pair<PrvGroup, unsigned>> getAffectedFormulas();
|
||||
|
||||
PrvGroup group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (GroundOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
std::vector<LiftedOperator*>
|
||||
LiftedOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
vector<ProductOperator*> multOps;
|
||||
std::vector<LiftedOperator*> validOps;
|
||||
std::vector<ProductOperator*> multOps;
|
||||
|
||||
multOps = ProductOperator::getValidOps (pfList);
|
||||
validOps.insert (validOps.end(), multOps.begin(), multOps.end());
|
||||
|
||||
if (Globals::verbosity > 1 || multOps.empty()) {
|
||||
vector<SumOutOperator*> sumOutOps;
|
||||
vector<CountingOperator*> countOps;
|
||||
vector<GroundOperator*> groundOps;
|
||||
std::vector<SumOutOperator*> sumOutOps;
|
||||
std::vector<CountingOperator*> countOps;
|
||||
std::vector<GroundOperator*> groundOps;
|
||||
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
||||
countOps = CountingOperator::getValidOps (pfList);
|
||||
groundOps = GroundOperator::getValidOps (pfList);
|
||||
@ -41,21 +184,21 @@ LiftedOperator::printValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<LiftedOperator*> validOps;
|
||||
std::vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList, query);
|
||||
for (size_t i = 0; i < validOps.size(); i++) {
|
||||
cout << "-> " << validOps[i]->toString();
|
||||
std::cout << "-> " << validOps[i]->toString();
|
||||
delete validOps[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<ParfactorList::iterator>
|
||||
std::vector<ParfactorList::iterator>
|
||||
LiftedOperator::getParfactorsWithGroup (
|
||||
ParfactorList& pfList, PrvGroup group)
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
std::vector<ParfactorList::iterator> iters;
|
||||
ParfactorList::iterator pflIt = pfList.begin();
|
||||
while (pflIt != pfList.end()) {
|
||||
if ((*pflIt)->containsGroup (group)) {
|
||||
@ -69,7 +212,7 @@ LiftedOperator::getParfactorsWithGroup (
|
||||
|
||||
|
||||
double
|
||||
ProductOperator::getLogCost (void)
|
||||
ProductOperator::getLogCost()
|
||||
{
|
||||
return std::log (0.0);
|
||||
}
|
||||
@ -77,7 +220,7 @@ ProductOperator::getLogCost (void)
|
||||
|
||||
|
||||
void
|
||||
ProductOperator::apply (void)
|
||||
ProductOperator::apply()
|
||||
{
|
||||
Parfactor* g1 = *g1_;
|
||||
Parfactor* g2 = *g2_;
|
||||
@ -89,13 +232,13 @@ ProductOperator::apply (void)
|
||||
|
||||
|
||||
|
||||
vector<ProductOperator*>
|
||||
std::vector<ProductOperator*>
|
||||
ProductOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<ProductOperator*> validOps;
|
||||
std::vector<ProductOperator*> validOps;
|
||||
ParfactorList::iterator it1 = pfList.begin();
|
||||
ParfactorList::iterator penultimate = -- pfList.end();
|
||||
set<Parfactor*> pfs;
|
||||
std::set<Parfactor*> pfs;
|
||||
while (it1 != penultimate) {
|
||||
if (Util::contains (pfs, *it1)) {
|
||||
++ it1;
|
||||
@ -128,15 +271,15 @@ ProductOperator::getValidOps (ParfactorList& pfList)
|
||||
|
||||
|
||||
|
||||
string
|
||||
ProductOperator::toString (void)
|
||||
std::string
|
||||
ProductOperator::toString()
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "just multiplicate " ;
|
||||
ss << (*g1_)->getAllGroups();
|
||||
ss << " x " ;
|
||||
ss << (*g2_)->getAllGroups();
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -168,14 +311,14 @@ ProductOperator::validOp (Parfactor* g1, Parfactor* g2)
|
||||
|
||||
|
||||
double
|
||||
SumOutOperator::getLogCost (void)
|
||||
SumOutOperator::getLogCost()
|
||||
{
|
||||
TinySet<PrvGroup> groupSet;
|
||||
ParfactorList::const_iterator pfIter = pfList_.begin();
|
||||
unsigned nrProdFactors = 0;
|
||||
while (pfIter != pfList_.end()) {
|
||||
if ((*pfIter)->containsGroup (group_)) {
|
||||
vector<PrvGroup> groups = (*pfIter)->getAllGroups();
|
||||
std::vector<PrvGroup> groups = (*pfIter)->getAllGroups();
|
||||
groupSet |= TinySet<PrvGroup> (groups);
|
||||
++ nrProdFactors;
|
||||
}
|
||||
@ -203,9 +346,9 @@ SumOutOperator::getLogCost (void)
|
||||
|
||||
|
||||
void
|
||||
SumOutOperator::apply (void)
|
||||
SumOutOperator::apply()
|
||||
{
|
||||
vector<ParfactorList::iterator> iters;
|
||||
std::vector<ParfactorList::iterator> iters;
|
||||
iters = getParfactorsWithGroup (pfList_, group_);
|
||||
Parfactor* product = *(iters[0]);
|
||||
pfList_.remove (iters[0]);
|
||||
@ -234,13 +377,13 @@ SumOutOperator::apply (void)
|
||||
|
||||
|
||||
|
||||
vector<SumOutOperator*>
|
||||
std::vector<SumOutOperator*>
|
||||
SumOutOperator::getValidOps (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<SumOutOperator*> validOps;
|
||||
set<PrvGroup> allGroups;
|
||||
std::vector<SumOutOperator*> validOps;
|
||||
std::set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
@ -249,7 +392,7 @@ SumOutOperator::getValidOps (
|
||||
}
|
||||
++ it;
|
||||
}
|
||||
set<PrvGroup>::const_iterator groupIt = allGroups.begin();
|
||||
std::set<PrvGroup>::const_iterator groupIt = allGroups.begin();
|
||||
while (groupIt != allGroups.end()) {
|
||||
if (validOp (*groupIt, pfList, query)) {
|
||||
validOps.push_back (new SumOutOperator (*groupIt, pfList));
|
||||
@ -261,18 +404,18 @@ SumOutOperator::getValidOps (
|
||||
|
||||
|
||||
|
||||
string
|
||||
SumOutOperator::toString (void)
|
||||
std::string
|
||||
SumOutOperator::toString()
|
||||
{
|
||||
stringstream ss;
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
std::stringstream ss;
|
||||
std::vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||
size_t idx = (*pfIters[0])->indexOfGroup (group_);
|
||||
ProbFormula f = (*pfIters[0])->argument (idx);
|
||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -284,7 +427,7 @@ SumOutOperator::validOp (
|
||||
ParfactorList& pfList,
|
||||
const Grounds& query)
|
||||
{
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
std::vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = getParfactorsWithGroup (pfList, group);
|
||||
if (isToEliminate (*pfIters[0], group, query) == false) {
|
||||
return false;
|
||||
@ -335,7 +478,7 @@ SumOutOperator::isToEliminate (
|
||||
|
||||
|
||||
double
|
||||
CountingOperator::getLogCost (void)
|
||||
CountingOperator::getLogCost()
|
||||
{
|
||||
double cost = 0.0;
|
||||
size_t fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||
@ -370,7 +513,7 @@ CountingOperator::getLogCost (void)
|
||||
|
||||
|
||||
void
|
||||
CountingOperator::apply (void)
|
||||
CountingOperator::apply()
|
||||
{
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
||||
(*pfIter_)->countConvert (X_);
|
||||
@ -393,10 +536,10 @@ CountingOperator::apply (void)
|
||||
|
||||
|
||||
|
||||
vector<CountingOperator*>
|
||||
std::vector<CountingOperator*>
|
||||
CountingOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<CountingOperator*> validOps;
|
||||
std::vector<CountingOperator*> validOps;
|
||||
ParfactorList::iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
LogVarSet candidates = (*it)->uncountedLogVars();
|
||||
@ -414,17 +557,17 @@ CountingOperator::getValidOps (ParfactorList& pfList)
|
||||
|
||||
|
||||
|
||||
string
|
||||
CountingOperator::toString (void)
|
||||
std::string
|
||||
CountingOperator::toString()
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "count convert " << X_ << " in " ;
|
||||
ss << (*pfIter_)->getLabel();
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
|
||||
Parfactors pfs = LiftedOperations::countNormalize (*pfIter_, X_);
|
||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
ss << " º " << pfs[i]->getLabel() << endl;
|
||||
ss << " º " << pfs[i]->getLabel() << std::endl;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
@ -455,16 +598,16 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
||||
|
||||
|
||||
double
|
||||
GroundOperator::getLogCost (void)
|
||||
GroundOperator::getLogCost()
|
||||
{
|
||||
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas = getAffectedFormulas();
|
||||
// cout << "affected formulas: " ;
|
||||
// std::cout << "affected formulas: " ;
|
||||
// for (size_t i = 0; i < affectedFormulas.size(); i++) {
|
||||
// cout << affectedFormulas[i].first << ":" ;
|
||||
// cout << affectedFormulas[i].second << " " ;
|
||||
// std::cout << affectedFormulas[i].first << ":" ;
|
||||
// std::cout << affectedFormulas[i].second << " " ;
|
||||
// }
|
||||
// cout << "cost =" ;
|
||||
// std::cout << "cost =" ;
|
||||
double totalCost = std::log (0.0);
|
||||
ParfactorList::iterator pflIt = pfList_.begin();
|
||||
while (pflIt != pfList_.end()) {
|
||||
@ -495,20 +638,20 @@ GroundOperator::getLogCost (void)
|
||||
}
|
||||
}
|
||||
if (willBeAffected) {
|
||||
// cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
||||
// std::cout << " + " << std::exp (reps) << "x" << std::exp (pfSize);
|
||||
double pfCost = reps + pfSize;
|
||||
totalCost = Util::logSum (totalCost, pfCost);
|
||||
}
|
||||
++ pflIt;
|
||||
}
|
||||
// cout << endl;
|
||||
// std::cout << std::endl;
|
||||
return totalCost + 3;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
GroundOperator::apply (void)
|
||||
GroundOperator::apply()
|
||||
{
|
||||
ParfactorList::iterator pfIter;
|
||||
pfIter = getParfactorsWithGroup (pfList_, group_).front();
|
||||
@ -537,11 +680,11 @@ GroundOperator::apply (void)
|
||||
|
||||
|
||||
|
||||
vector<GroundOperator*>
|
||||
std::vector<GroundOperator*>
|
||||
GroundOperator::getValidOps (ParfactorList& pfList)
|
||||
{
|
||||
vector<GroundOperator*> validOps;
|
||||
set<PrvGroup> allGroups;
|
||||
std::vector<GroundOperator*> validOps;
|
||||
std::set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
@ -564,18 +707,18 @@ GroundOperator::getValidOps (ParfactorList& pfList)
|
||||
|
||||
|
||||
|
||||
string
|
||||
GroundOperator::toString (void)
|
||||
std::string
|
||||
GroundOperator::toString()
|
||||
{
|
||||
stringstream ss;
|
||||
vector<ParfactorList::iterator> pfIters;
|
||||
std::stringstream ss;
|
||||
std::vector<ParfactorList::iterator> pfIters;
|
||||
pfIters = getParfactorsWithGroup (pfList_, group_);
|
||||
Parfactor* pf = *(getParfactorsWithGroup (pfList_, group_).front());
|
||||
size_t idx = pf->indexOfGroup (group_);
|
||||
ProbFormula f = pf->argument (idx);
|
||||
LogVar lv = f.logVars()[lvIndex_];
|
||||
TupleSet tupleSet = pf->constr()->tupleSet ({lv});
|
||||
string pos = "th";
|
||||
std::string pos = "th";
|
||||
if (lvIndex_ == 0) {
|
||||
pos = "st" ;
|
||||
} else if (lvIndex_ == 1) {
|
||||
@ -586,21 +729,21 @@ GroundOperator::toString (void)
|
||||
ss << "grounding " << lvIndex_ + 1 << pos << " log var in " ;
|
||||
ss << f.functor() << "/" << f.arity();
|
||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << endl;
|
||||
ss << " [cost=" << std::exp (getLogCost()) << "]" << std::endl;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<pair<PrvGroup, unsigned>>
|
||||
GroundOperator::getAffectedFormulas (void)
|
||||
std::vector<std::pair<PrvGroup, unsigned>>
|
||||
GroundOperator::getAffectedFormulas()
|
||||
{
|
||||
vector<pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas.push_back (make_pair (group_, lvIndex_));
|
||||
queue<pair<PrvGroup, unsigned>> q;
|
||||
q.push (make_pair (group_, lvIndex_));
|
||||
std::vector<std::pair<PrvGroup, unsigned>> affectedFormulas;
|
||||
affectedFormulas.push_back (std::make_pair (group_, lvIndex_));
|
||||
std::queue<std::pair<PrvGroup, unsigned>> q;
|
||||
q.push (std::make_pair (group_, lvIndex_));
|
||||
while (q.empty() == false) {
|
||||
pair<PrvGroup, unsigned> front = q.front();
|
||||
std::pair<PrvGroup, unsigned> front = q.front();
|
||||
ParfactorList::iterator pflIt = pfList_.begin();
|
||||
while (pflIt != pfList_.end()) {
|
||||
size_t idx = (*pflIt)->indexOfGroup (front.first);
|
||||
@ -610,7 +753,7 @@ GroundOperator::getAffectedFormulas (void)
|
||||
const ProbFormulas& fs = (*pflIt)->arguments();
|
||||
for (size_t i = 0; i < fs.size(); i++) {
|
||||
if (i != idx && fs[i].contains (X)) {
|
||||
pair<PrvGroup, unsigned> pair = make_pair (
|
||||
std::pair<PrvGroup, unsigned> pair = std::make_pair (
|
||||
fs[i].group(), fs[i].indexOf (X));
|
||||
if (Util::contains (affectedFormulas, pair) == false) {
|
||||
q.push (pair);
|
||||
@ -645,13 +788,13 @@ LiftedVe::solveQuery (const Grounds& query)
|
||||
|
||||
|
||||
void
|
||||
LiftedVe::printSolverFlags (void) const
|
||||
LiftedVe::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "lve [" ;
|
||||
ss << "log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -675,9 +818,9 @@ LiftedVe::runSolver (const Grounds& query)
|
||||
break;
|
||||
}
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "best operation: " << op->toString();
|
||||
std::cout << "best operation: " << op->toString();
|
||||
if (Globals::verbosity > 2) {
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
op->apply();
|
||||
@ -693,8 +836,9 @@ LiftedVe::runSolver (const Grounds& query)
|
||||
}
|
||||
}
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "largest cost = " << std::exp (largestCost_) << endl;
|
||||
cout << endl;
|
||||
std::cout << "largest cost = " << std::exp (largestCost_);
|
||||
std::cout << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
(*pfList_.begin())->simplifyGrounds();
|
||||
(*pfList_.begin())->reorderAccordingGrounds (query);
|
||||
@ -707,7 +851,7 @@ LiftedVe::getBestOperation (const Grounds& query)
|
||||
{
|
||||
double bestCost = 0.0;
|
||||
LiftedOperator* bestOp = 0;
|
||||
vector<LiftedOperator*> validOps;
|
||||
std::vector<LiftedOperator*> validOps;
|
||||
validOps = LiftedOperator::getValidOps (pfList_, query);
|
||||
for (size_t i = 0; i < validOps.size(); i++) {
|
||||
double cost = validOps[i]->getLogCost();
|
||||
@ -727,3 +871,5 @@ LiftedVe::getBestOperation (const Grounds& query)
|
||||
return bestOp;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,157 +1,23 @@
|
||||
#ifndef HORUS_LIFTEDVE_H
|
||||
#define HORUS_LIFTEDVE_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
|
||||
|
||||
#include "LiftedSolver.h"
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
class LiftedOperator
|
||||
{
|
||||
public:
|
||||
virtual ~LiftedOperator (void) { }
|
||||
namespace Horus {
|
||||
|
||||
virtual double getLogCost (void) = 0;
|
||||
|
||||
virtual void apply (void) = 0;
|
||||
|
||||
virtual string toString (void) = 0;
|
||||
|
||||
static vector<LiftedOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
static void printValidOps (ParfactorList&, const Grounds&);
|
||||
|
||||
static vector<ParfactorList::iterator> getParfactorsWithGroup (
|
||||
ParfactorList&, PrvGroup group);
|
||||
|
||||
private:
|
||||
DISALLOW_ASSIGN (LiftedOperator);
|
||||
};
|
||||
class LiftedOperator;
|
||||
|
||||
|
||||
|
||||
class ProductOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
ProductOperator (
|
||||
ParfactorList::iterator g1, ParfactorList::iterator g2,
|
||||
ParfactorList& pfList) : g1_(g1), g2_(g2), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<ProductOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, Parfactor*);
|
||||
|
||||
ParfactorList::iterator g1_;
|
||||
ParfactorList::iterator g2_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (ProductOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SumOutOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
SumOutOperator (PrvGroup group, ParfactorList& pfList)
|
||||
: group_(group), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<SumOutOperator*> getValidOps (
|
||||
ParfactorList&, const Grounds&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (PrvGroup, ParfactorList&, const Grounds&);
|
||||
|
||||
static bool isToEliminate (Parfactor*, PrvGroup, const Grounds&);
|
||||
|
||||
PrvGroup group_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (SumOutOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CountingOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
CountingOperator (
|
||||
ParfactorList::iterator pfIter,
|
||||
LogVar X,
|
||||
ParfactorList& pfList)
|
||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
static bool validOp (Parfactor*, LogVar);
|
||||
|
||||
ParfactorList::iterator pfIter_;
|
||||
LogVar X_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (CountingOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class GroundOperator : public LiftedOperator
|
||||
{
|
||||
public:
|
||||
GroundOperator (
|
||||
PrvGroup group,
|
||||
unsigned lvIndex,
|
||||
ParfactorList& pfList)
|
||||
: group_(group), lvIndex_(lvIndex), pfList_(pfList) { }
|
||||
|
||||
double getLogCost (void);
|
||||
|
||||
void apply (void);
|
||||
|
||||
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||
|
||||
string toString (void);
|
||||
|
||||
private:
|
||||
vector<pair<PrvGroup, unsigned>> getAffectedFormulas (void);
|
||||
|
||||
PrvGroup group_;
|
||||
unsigned lvIndex_;
|
||||
ParfactorList& pfList_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (GroundOperator);
|
||||
};
|
||||
|
||||
|
||||
|
||||
class LiftedVe : public LiftedSolver
|
||||
{
|
||||
class LiftedVe : public LiftedSolver {
|
||||
public:
|
||||
LiftedVe (const ParfactorList& pfList)
|
||||
: LiftedSolver(pfList) { }
|
||||
|
||||
Params solveQuery (const Grounds&);
|
||||
|
||||
void printSolverFlags (void) const;
|
||||
void printSolverFlags() const;
|
||||
|
||||
private:
|
||||
void runSolver (const Grounds&);
|
||||
@ -164,5 +30,7 @@ class LiftedVe : public LiftedSolver
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedVe);
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDVE_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDVE_H_
|
||||
|
||||
|
@ -1,10 +1,20 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "LiftedWCNF.h"
|
||||
#include "ParfactorList.h"
|
||||
#include "ConstraintTree.h"
|
||||
#include "Indexer.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
bool
|
||||
Literal::isGround (ConstraintTree constr, LogVarSet ipgLogVars) const
|
||||
Literal::isGround (
|
||||
ConstraintTree constr,
|
||||
const LogVarSet& ipgLogVars) const
|
||||
{
|
||||
if (logVars_.empty()) {
|
||||
return true;
|
||||
@ -24,13 +34,13 @@ Literal::indexOfLogVar (LogVar X) const
|
||||
|
||||
|
||||
|
||||
string
|
||||
std::string
|
||||
Literal::toString (
|
||||
LogVarSet ipgLogVars,
|
||||
LogVarSet posCountedLvs,
|
||||
LogVarSet negCountedLvs) const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
negated_ ? ss << "¬" : ss << "" ;
|
||||
ss << "λ" ;
|
||||
ss << lid_ ;
|
||||
@ -44,7 +54,7 @@ Literal::toString (
|
||||
ss << "-" << logVars_[i];
|
||||
} else if (ipgLogVars.contains (logVars_[i])) {
|
||||
LogVar X = logVars_[i];
|
||||
const string labels[] = {
|
||||
const std::string labels[] = {
|
||||
"a", "b", "c", "d", "e", "f",
|
||||
"g", "h", "i", "j", "k", "m" };
|
||||
(X >= 12) ? ss << "x_" << X : ss << labels[X];
|
||||
@ -60,7 +70,7 @@ Literal::toString (
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (ostream &os, const Literal& lit)
|
||||
operator<< (std::ostream& os, const Literal& lit)
|
||||
{
|
||||
os << lit.toString();
|
||||
return os;
|
||||
@ -216,7 +226,7 @@ Clause::isIpgLogVar (LogVar X) const
|
||||
|
||||
|
||||
TinySet<LiteralId>
|
||||
Clause::lidSet (void) const
|
||||
Clause::lidSet() const
|
||||
{
|
||||
TinySet<LiteralId> lidSet;
|
||||
for (size_t i = 0; i < literals_.size(); i++) {
|
||||
@ -228,7 +238,7 @@ Clause::lidSet (void) const
|
||||
|
||||
|
||||
LogVarSet
|
||||
Clause::ipgCandidates (void) const
|
||||
Clause::ipgCandidates() const
|
||||
{
|
||||
LogVarSet candidates;
|
||||
LogVarSet allLvs = constr_.logVarSet();
|
||||
@ -259,11 +269,11 @@ Clause::logVarTypes (size_t litIdx) const
|
||||
const LogVars& lvs = literals_[litIdx].logVars();
|
||||
for (size_t i = 0; i < lvs.size(); i++) {
|
||||
if (posCountedLvs_.contains (lvs[i])) {
|
||||
types.push_back (LogVarType::POS_LV);
|
||||
types.push_back (LogVarType::posLvt);
|
||||
} else if (negCountedLvs_.contains (lvs[i])) {
|
||||
types.push_back (LogVarType::NEG_LV);
|
||||
types.push_back (LogVarType::negLvt);
|
||||
} else {
|
||||
types.push_back (LogVarType::FULL_LV);
|
||||
types.push_back (LogVarType::fullLvt);
|
||||
}
|
||||
}
|
||||
return types;
|
||||
@ -320,7 +330,7 @@ void
|
||||
Clause::printClauses (const Clauses& clauses)
|
||||
{
|
||||
for (size_t i = 0; i < clauses.size(); i++) {
|
||||
cout << *clauses[i] << endl;
|
||||
std::cout << *clauses[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -337,7 +347,7 @@ Clause::deleteClauses (Clauses& clauses)
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (ostream &os, const Clause& clause)
|
||||
operator<< (std::ostream& os, const Clause& clause)
|
||||
{
|
||||
for (unsigned i = 0; i < clause.literals_.size(); i++) {
|
||||
if (i != 0) os << " v " ;
|
||||
@ -369,14 +379,14 @@ Clause::getLogVarSetExcluding (size_t idx) const
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream &os, const LitLvTypes& lit)
|
||||
operator<< (std::ostream& os, const LitLvTypes& lit)
|
||||
{
|
||||
os << lit.lid_ << "<" ;
|
||||
for (size_t i = 0; i < lit.lvTypes_.size(); i++) {
|
||||
switch (lit.lvTypes_[i]) {
|
||||
case LogVarType::FULL_LV: os << "F" ; break;
|
||||
case LogVarType::POS_LV: os << "P" ; break;
|
||||
case LogVarType::NEG_LV: os << "N" ; break;
|
||||
case LogVarType::fullLvt: os << "F" ; break;
|
||||
case LogVarType::posLvt: os << "P" ; break;
|
||||
case LogVarType::negLvt: os << "N" ; break;
|
||||
}
|
||||
}
|
||||
os << ">" ;
|
||||
@ -385,6 +395,14 @@ operator<< (std::ostream &os, const LitLvTypes& lit)
|
||||
|
||||
|
||||
|
||||
void
|
||||
LitLvTypes::setAllFullLogVars()
|
||||
{
|
||||
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::fullLvt);
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
: freeLiteralId_(0), pfList_(pfList)
|
||||
{
|
||||
@ -394,7 +412,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
/*
|
||||
// INCLUSION-EXCLUSION TEST
|
||||
clauses_.clear();
|
||||
vector<vector<string>> names = {
|
||||
std::vector<std::vector<string>> names = {
|
||||
{"a1","b1"},{"a2","b2"}
|
||||
};
|
||||
Clause* c1 = new Clause (names);
|
||||
@ -406,7 +424,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
/*
|
||||
// INDEPENDENT PARTIAL GROUND TEST
|
||||
clauses_.clear();
|
||||
vector<vector<string>> names = {
|
||||
std::vector<std::vector<string>> names = {
|
||||
{"a1","b1"},{"a2","b2"}
|
||||
};
|
||||
Clause* c1 = new Clause (names);
|
||||
@ -422,7 +440,7 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
/*
|
||||
// ATOM-COUNTING TEST
|
||||
clauses_.clear();
|
||||
vector<vector<string>> names = {
|
||||
std::vector<std::vector<string>> names = {
|
||||
{"p1","p1"},{"p1","p2"},{"p1","p3"},
|
||||
{"p2","p1"},{"p2","p2"},{"p2","p3"},
|
||||
{"p3","p1"},{"p3","p2"},{"p3","p3"}
|
||||
@ -438,21 +456,21 @@ LiftedWCNF::LiftedWCNF (const ParfactorList& pfList)
|
||||
*/
|
||||
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "FORMULA INDICATORS:" << endl;
|
||||
std::cout << "FORMULA INDICATORS:" << std::endl;
|
||||
printFormulaIndicators();
|
||||
cout << endl;
|
||||
cout << "WEIGHTED INDICATORS:" << endl;
|
||||
std::cout << std::endl;
|
||||
std::cout << "WEIGHTED INDICATORS:" << std::endl;
|
||||
printWeights();
|
||||
cout << endl;
|
||||
cout << "CLAUSES:" << endl;
|
||||
std::cout << std::endl;
|
||||
std::cout << "CLAUSES:" << std::endl;
|
||||
printClauses();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedWCNF::~LiftedWCNF (void)
|
||||
LiftedWCNF::~LiftedWCNF()
|
||||
{
|
||||
Clause::deleteClauses (clauses_);
|
||||
}
|
||||
@ -462,7 +480,7 @@ LiftedWCNF::~LiftedWCNF (void)
|
||||
void
|
||||
LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||
{
|
||||
weights_[lid] = make_pair (posW, negW);
|
||||
weights_[lid] = std::make_pair (posW, negW);
|
||||
}
|
||||
|
||||
|
||||
@ -470,8 +488,8 @@ LiftedWCNF::addWeight (LiteralId lid, double posW, double negW)
|
||||
double
|
||||
LiftedWCNF::posWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
||||
= weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.first : LogAware::one();
|
||||
}
|
||||
|
||||
@ -480,14 +498,14 @@ LiftedWCNF::posWeight (LiteralId lid) const
|
||||
double
|
||||
LiftedWCNF::negWeight (LiteralId lid) const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.find (lid);
|
||||
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
||||
= weights_.find (lid);
|
||||
return it != weights_.end() ? it->second.second : LogAware::one();
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<LiteralId>
|
||||
std::vector<LiteralId>
|
||||
LiftedWCNF::prvGroupLiterals (PrvGroup prvGroup)
|
||||
{
|
||||
assert (Util::contains (map_, prvGroup));
|
||||
@ -536,9 +554,10 @@ LiftedWCNF::addIndicatorClauses (const ParfactorList& pfList)
|
||||
ConstraintTree tempConstr = (*it)->constr()->projectedCopy(
|
||||
formulas[i].logVars());
|
||||
Clause* clause = new Clause (tempConstr);
|
||||
vector<LiteralId> lids;
|
||||
std::vector<LiteralId> lids;
|
||||
for (size_t j = 0; j < formulas[i].range(); j++) {
|
||||
clause->addLiteral (Literal (freeLiteralId_, formulas[i].logVars()));
|
||||
clause->addLiteral (Literal (
|
||||
freeLiteralId_, formulas[i].logVars()));
|
||||
lids.push_back (freeLiteralId_);
|
||||
freeLiteralId_ ++;
|
||||
}
|
||||
@ -568,7 +587,7 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
while (it != pfList.end()) {
|
||||
Indexer indexer ((*it)->ranges());
|
||||
vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
std::vector<PrvGroup> groups = (*it)->getAllGroups();
|
||||
while (indexer.valid()) {
|
||||
LiteralId paramVarLid = freeLiteralId_;
|
||||
// λu1 ∧ ... ∧ λun ∧ λxi <=> θxi|u1,...,un
|
||||
@ -606,26 +625,26 @@ LiftedWCNF::addParameterClauses (const ParfactorList& pfList)
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printFormulaIndicators (void) const
|
||||
LiftedWCNF::printFormulaIndicators() const
|
||||
{
|
||||
if (map_.empty()) {
|
||||
return;
|
||||
}
|
||||
set<PrvGroup> allGroups;
|
||||
std::set<PrvGroup> allGroups;
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
const ProbFormulas& formulas = (*it)->arguments();
|
||||
for (size_t i = 0; i < formulas.size(); i++) {
|
||||
if (Util::contains (allGroups, formulas[i].group()) == false) {
|
||||
allGroups.insert (formulas[i].group());
|
||||
cout << formulas[i] << " | " ;
|
||||
std::cout << formulas[i] << " | " ;
|
||||
ConstraintTree tempCt = (*it)->constr()->projectedCopy (
|
||||
formulas[i].logVars());
|
||||
cout << tempCt.tupleSet();
|
||||
cout << " indicators => " ;
|
||||
vector<LiteralId> indicators =
|
||||
std::cout << tempCt.tupleSet();
|
||||
std::cout << " indicators => " ;
|
||||
std::vector<LiteralId> indicators =
|
||||
(map_.find (formulas[i].group()))->second;
|
||||
cout << indicators << endl;
|
||||
std::cout << indicators << std::endl;
|
||||
}
|
||||
}
|
||||
++ it;
|
||||
@ -635,14 +654,14 @@ LiftedWCNF::printFormulaIndicators (void) const
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printWeights (void) const
|
||||
LiftedWCNF::printWeights() const
|
||||
{
|
||||
unordered_map<LiteralId, std::pair<double,double>>::const_iterator it;
|
||||
it = weights_.begin();
|
||||
std::unordered_map<LiteralId, std::pair<double,double>>::const_iterator it
|
||||
= weights_.begin();
|
||||
while (it != weights_.end()) {
|
||||
cout << "λ" << it->first << " weights: " ;
|
||||
cout << it->second.first << " " << it->second.second;
|
||||
cout << endl;
|
||||
std::cout << "λ" << it->first << " weights: " ;
|
||||
std::cout << it->second.first << " " << it->second.second;
|
||||
std::cout << std::endl;
|
||||
++ it;
|
||||
}
|
||||
}
|
||||
@ -650,8 +669,10 @@ LiftedWCNF::printWeights (void) const
|
||||
|
||||
|
||||
void
|
||||
LiftedWCNF::printClauses (void) const
|
||||
LiftedWCNF::printClauses() const
|
||||
{
|
||||
Clause::printClauses (clauses_);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -1,90 +1,95 @@
|
||||
#ifndef HORUS_LIFTEDWCNF_H
|
||||
#define HORUS_LIFTEDWCNF_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
|
||||
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <ostream>
|
||||
|
||||
#include "ParfactorList.h"
|
||||
#include "ConstraintTree.h"
|
||||
#include "ProbFormula.h"
|
||||
#include "LiftedUtils.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class ConstraintTree;
|
||||
namespace Horus {
|
||||
|
||||
enum LogVarType
|
||||
{
|
||||
FULL_LV,
|
||||
POS_LV,
|
||||
NEG_LV
|
||||
class ParfactorList;
|
||||
|
||||
enum class LogVarType {
|
||||
fullLvt,
|
||||
posLvt,
|
||||
negLvt
|
||||
};
|
||||
|
||||
typedef long LiteralId;
|
||||
typedef vector<LogVarType> LogVarTypes;
|
||||
typedef long LiteralId;
|
||||
typedef std::vector<LogVarType> LogVarTypes;
|
||||
|
||||
|
||||
class Literal
|
||||
{
|
||||
class Literal {
|
||||
public:
|
||||
Literal (LiteralId lid, const LogVars& lvs) :
|
||||
lid_(lid), logVars_(lvs), negated_(false) { }
|
||||
Literal (LiteralId lid, const LogVars& lvs)
|
||||
: lid_(lid), logVars_(lvs), negated_(false) { }
|
||||
|
||||
Literal (const Literal& lit, bool negated) :
|
||||
lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
|
||||
Literal (const Literal& lit, bool negated)
|
||||
: lid_(lit.lid_), logVars_(lit.logVars_), negated_(negated) { }
|
||||
|
||||
LiteralId lid (void) const { return lid_; }
|
||||
LiteralId lid() const { return lid_; }
|
||||
|
||||
LogVars logVars (void) const { return logVars_; }
|
||||
LogVars logVars() const { return logVars_; }
|
||||
|
||||
size_t nrLogVars (void) const { return logVars_.size(); }
|
||||
size_t nrLogVars() const { return logVars_.size(); }
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
LogVarSet logVarSet() const { return LogVarSet (logVars_); }
|
||||
|
||||
void complement (void) { negated_ = !negated_; }
|
||||
void complement() { negated_ = !negated_; }
|
||||
|
||||
bool isPositive (void) const { return negated_ == false; }
|
||||
bool isPositive() const { return negated_ == false; }
|
||||
|
||||
bool isNegative (void) const { return negated_; }
|
||||
bool isNegative() const { return negated_; }
|
||||
|
||||
bool isGround (ConstraintTree constr, LogVarSet ipgLogVars) const;
|
||||
bool isGround (ConstraintTree constr, const LogVarSet& ipgLogVars) const;
|
||||
|
||||
size_t indexOfLogVar (LogVar X) const;
|
||||
|
||||
string toString (LogVarSet ipgLogVars = LogVarSet(),
|
||||
LogVarSet posCountedLvs = LogVarSet(),
|
||||
LogVarSet negCountedLvs = LogVarSet()) const;
|
||||
|
||||
friend std::ostream& operator<< (std::ostream &os, const Literal& lit);
|
||||
std::string toString (
|
||||
LogVarSet ipgLogVars = LogVarSet(),
|
||||
LogVarSet posCountedLvs = LogVarSet(),
|
||||
LogVarSet negCountedLvs = LogVarSet()) const;
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const Literal&);
|
||||
|
||||
LiteralId lid_;
|
||||
LogVars logVars_;
|
||||
bool negated_;
|
||||
};
|
||||
|
||||
typedef vector<Literal> Literals;
|
||||
typedef std::vector<Literal> Literals;
|
||||
|
||||
|
||||
|
||||
class Clause
|
||||
{
|
||||
class Clause {
|
||||
public:
|
||||
Clause (const ConstraintTree& ct = ConstraintTree({})) : constr_(ct) { }
|
||||
|
||||
Clause (vector<vector<string>> names) : constr_(ConstraintTree (names)) { }
|
||||
Clause (std::vector<std::vector<std::string>> names) :
|
||||
constr_(ConstraintTree (names)) { }
|
||||
|
||||
void addLiteral (const Literal& l) { literals_.push_back (l); }
|
||||
|
||||
const Literals& literals (void) const { return literals_; }
|
||||
const Literals& literals() const { return literals_; }
|
||||
|
||||
Literals& literals (void) { return literals_; }
|
||||
Literals& literals() { return literals_; }
|
||||
|
||||
size_t nrLiterals (void) const { return literals_.size(); }
|
||||
size_t nrLiterals() const { return literals_.size(); }
|
||||
|
||||
const ConstraintTree& constr (void) const { return constr_; }
|
||||
const ConstraintTree& constr() const { return constr_; }
|
||||
|
||||
ConstraintTree constr (void) { return constr_; }
|
||||
ConstraintTree constr() { return constr_; }
|
||||
|
||||
bool isUnit (void) const { return literals_.size() == 1; }
|
||||
bool isUnit() const { return literals_.size() == 1; }
|
||||
|
||||
LogVarSet ipgLogVars (void) const { return ipgLvs_; }
|
||||
LogVarSet ipgLogVars() const { return ipgLvs_; }
|
||||
|
||||
void addIpgLogVar (LogVar X) { ipgLvs_.insert (X); }
|
||||
|
||||
@ -92,13 +97,13 @@ class Clause
|
||||
|
||||
void addNegCountedLogVar (LogVar X) { negCountedLvs_.insert (X); }
|
||||
|
||||
LogVarSet posCountedLogVars (void) const { return posCountedLvs_; }
|
||||
LogVarSet posCountedLogVars() const { return posCountedLvs_; }
|
||||
|
||||
LogVarSet negCountedLogVars (void) const { return negCountedLvs_; }
|
||||
LogVarSet negCountedLogVars() const { return negCountedLvs_; }
|
||||
|
||||
unsigned nrPosCountedLogVars (void) const { return posCountedLvs_.size(); }
|
||||
unsigned nrPosCountedLogVars() const { return posCountedLvs_.size(); }
|
||||
|
||||
unsigned nrNegCountedLogVars (void) const { return negCountedLvs_.size(); }
|
||||
unsigned nrNegCountedLogVars() const { return negCountedLvs_.size(); }
|
||||
|
||||
void addLiteralComplemented (const Literal& lit);
|
||||
|
||||
@ -122,9 +127,9 @@ class Clause
|
||||
|
||||
bool isIpgLogVar (LogVar X) const;
|
||||
|
||||
TinySet<LiteralId> lidSet (void) const;
|
||||
TinySet<LiteralId> lidSet() const;
|
||||
|
||||
LogVarSet ipgCandidates (void) const;
|
||||
LogVarSet ipgCandidates() const;
|
||||
|
||||
LogVarTypes logVarTypes (size_t litIdx) const;
|
||||
|
||||
@ -132,78 +137,78 @@ class Clause
|
||||
|
||||
static bool independentClauses (Clause& c1, Clause& c2);
|
||||
|
||||
static vector<Clause*> copyClauses (const vector<Clause*>& clauses);
|
||||
static std::vector<Clause*> copyClauses (
|
||||
const std::vector<Clause*>& clauses);
|
||||
|
||||
static void printClauses (const vector<Clause*>& clauses);
|
||||
static void printClauses (const std::vector<Clause*>& clauses);
|
||||
|
||||
static void deleteClauses (vector<Clause*>& clauses);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const Clause& clause);
|
||||
static void deleteClauses (std::vector<Clause*>& clauses);
|
||||
|
||||
private:
|
||||
LogVarSet getLogVarSetExcluding (size_t idx) const;
|
||||
|
||||
Literals literals_;
|
||||
LogVarSet ipgLvs_;
|
||||
LogVarSet posCountedLvs_;
|
||||
LogVarSet negCountedLvs_;
|
||||
ConstraintTree constr_;
|
||||
friend std::ostream& operator<< (std::ostream&, const Clause&);
|
||||
|
||||
Literals literals_;
|
||||
LogVarSet ipgLvs_;
|
||||
LogVarSet posCountedLvs_;
|
||||
LogVarSet negCountedLvs_;
|
||||
ConstraintTree constr_;
|
||||
|
||||
DISALLOW_ASSIGN (Clause);
|
||||
};
|
||||
|
||||
typedef vector<Clause*> Clauses;
|
||||
typedef std::vector<Clause*> Clauses;
|
||||
|
||||
|
||||
|
||||
class LitLvTypes
|
||||
{
|
||||
class LitLvTypes {
|
||||
public:
|
||||
struct CompareLitLvTypes
|
||||
{
|
||||
bool operator() (
|
||||
const LitLvTypes& types1,
|
||||
const LitLvTypes& types2) const
|
||||
{
|
||||
if (types1.lid_ < types2.lid_) {
|
||||
return true;
|
||||
}
|
||||
if (types1.lid_ == types2.lid_) {
|
||||
return types1.lvTypes_ < types2.lvTypes_;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
LitLvTypes (LiteralId lid, const LogVarTypes& lvTypes) :
|
||||
lid_(lid), lvTypes_(lvTypes) { }
|
||||
|
||||
LiteralId lid (void) const { return lid_; }
|
||||
LiteralId lid() const { return lid_; }
|
||||
|
||||
const LogVarTypes& logVarTypes (void) const { return lvTypes_; }
|
||||
const LogVarTypes& logVarTypes() const { return lvTypes_; }
|
||||
|
||||
void setAllFullLogVars (void) {
|
||||
std::fill (lvTypes_.begin(), lvTypes_.end(), LogVarType::FULL_LV); }
|
||||
|
||||
friend std::ostream& operator<< (std::ostream &os, const LitLvTypes& lit);
|
||||
void setAllFullLogVars();
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (std::ostream&, const LitLvTypes&);
|
||||
|
||||
LiteralId lid_;
|
||||
LogVarTypes lvTypes_;
|
||||
};
|
||||
|
||||
typedef TinySet<LitLvTypes,LitLvTypes::CompareLitLvTypes> LitLvTypesSet;
|
||||
|
||||
|
||||
|
||||
class LiftedWCNF
|
||||
struct CmpLitLvTypes
|
||||
{
|
||||
bool operator() (
|
||||
const LitLvTypes& types1,
|
||||
const LitLvTypes& types2) const
|
||||
{
|
||||
if (types1.lid() < types2.lid()) {
|
||||
return true;
|
||||
}
|
||||
// vsc if (types1.lid() == types2.lid()){
|
||||
// return types1.logVarTypes() < types2.logVarTypes();
|
||||
//}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
typedef TinySet<LitLvTypes, CmpLitLvTypes> LitLvTypesSet;
|
||||
|
||||
|
||||
|
||||
class LiftedWCNF {
|
||||
public:
|
||||
LiftedWCNF (const ParfactorList& pfList);
|
||||
|
||||
~LiftedWCNF (void);
|
||||
~LiftedWCNF();
|
||||
|
||||
const Clauses& clauses (void) const { return clauses_; }
|
||||
const Clauses& clauses() const { return clauses_; }
|
||||
|
||||
void addWeight (LiteralId lid, double posW, double negW);
|
||||
|
||||
@ -211,15 +216,15 @@ class LiftedWCNF
|
||||
|
||||
double negWeight (LiteralId lid) const;
|
||||
|
||||
vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
|
||||
std::vector<LiteralId> prvGroupLiterals (PrvGroup prvGroup);
|
||||
|
||||
Clause* createClause (LiteralId lid) const;
|
||||
|
||||
void printFormulaIndicators (void) const;
|
||||
void printFormulaIndicators() const;
|
||||
|
||||
void printWeights (void) const;
|
||||
void printWeights() const;
|
||||
|
||||
void printClauses (void) const;
|
||||
void printClauses() const;
|
||||
|
||||
private:
|
||||
LiteralId getLiteralId (PrvGroup prvGroup, unsigned range);
|
||||
@ -228,14 +233,16 @@ class LiftedWCNF
|
||||
|
||||
void addParameterClauses (const ParfactorList& pfList);
|
||||
|
||||
Clauses clauses_;
|
||||
LiteralId freeLiteralId_;
|
||||
const ParfactorList& pfList_;
|
||||
unordered_map<PrvGroup, vector<LiteralId>> map_;
|
||||
unordered_map<LiteralId, std::pair<double,double>> weights_;
|
||||
Clauses clauses_;
|
||||
LiteralId freeLiteralId_;
|
||||
const ParfactorList& pfList_;
|
||||
std::unordered_map<PrvGroup, std::vector<LiteralId>> map_;
|
||||
std::unordered_map<LiteralId, std::pair<double,double>> weights_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (LiftedWCNF);
|
||||
};
|
||||
|
||||
#endif // HORUS_LIFTEDWCNF_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_LIFTEDWCNF_H_
|
||||
|
||||
|
@ -43,9 +43,9 @@ SO=@SO@
|
||||
#4.1VPATH=@srcdir@:@srcdir@/OPTYap
|
||||
CWD=$(PWD)
|
||||
|
||||
HCLI = $(srcdir)/hcli
|
||||
utestsdir=@srcdir@/unit_tests
|
||||
|
||||
HEADERS = \
|
||||
MAIN_HEADERS = \
|
||||
$(srcdir)/BayesBall.h \
|
||||
$(srcdir)/BayesBallGraph.h \
|
||||
$(srcdir)/BeliefProp.h \
|
||||
@ -54,6 +54,8 @@ HEADERS = \
|
||||
$(srcdir)/ElimGraph.h \
|
||||
$(srcdir)/Factor.h \
|
||||
$(srcdir)/FactorGraph.h \
|
||||
$(srcdir)/GenericFactor.h \
|
||||
$(srcdir)/GroundSolver.h \
|
||||
$(srcdir)/Histogram.h \
|
||||
$(srcdir)/Horus.h \
|
||||
$(srcdir)/Indexer.h \
|
||||
@ -67,14 +69,20 @@ HEADERS = \
|
||||
$(srcdir)/Parfactor.h \
|
||||
$(srcdir)/ParfactorList.h \
|
||||
$(srcdir)/ProbFormula.h \
|
||||
$(srcdir)/GroundSolver.h \
|
||||
$(srcdir)/TinySet.h \
|
||||
$(srcdir)/Util.h \
|
||||
$(srcdir)/Var.h \
|
||||
$(srcdir)/VarElim.h \
|
||||
$(srcdir)/WeightedBp.h
|
||||
|
||||
CPP_SOURCES = \
|
||||
UTESTS_HEADERS = \
|
||||
$(utestsdir)/Common.h
|
||||
|
||||
HEADERS = \
|
||||
$(MAIN_HEADERS) \
|
||||
$(UTESTS_HEADERS)
|
||||
|
||||
MAIN_SOURCES = \
|
||||
$(srcdir)/BayesBall.cpp \
|
||||
$(srcdir)/BayesBallGraph.cpp \
|
||||
$(srcdir)/BeliefProp.cpp \
|
||||
@ -83,9 +91,12 @@ CPP_SOURCES = \
|
||||
$(srcdir)/ElimGraph.cpp \
|
||||
$(srcdir)/Factor.cpp \
|
||||
$(srcdir)/FactorGraph.cpp \
|
||||
$(srcdir)/GenericFactor.cpp \
|
||||
$(srcdir)/GroundSolver.cpp \
|
||||
$(srcdir)/Histogram.cpp \
|
||||
$(srcdir)/HorusCli.cpp \
|
||||
$(srcdir)/HorusYap.cpp \
|
||||
$(srcdir)/Indexer.cpp \
|
||||
$(srcdir)/LiftedBp.cpp \
|
||||
$(srcdir)/LiftedKc.cpp \
|
||||
$(srcdir)/LiftedOperations.cpp \
|
||||
@ -95,12 +106,23 @@ CPP_SOURCES = \
|
||||
$(srcdir)/Parfactor.cpp \
|
||||
$(srcdir)/ParfactorList.cpp \
|
||||
$(srcdir)/ProbFormula.cpp \
|
||||
$(srcdir)/GroundSolver.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/Var.cpp \
|
||||
$(srcdir)/VarElim.cpp \
|
||||
$(srcdir)/WeightedBp.cpp
|
||||
|
||||
UTESTS_SOURCES = \
|
||||
$(utestsdir)/BeliefPropTest.cpp \
|
||||
$(utestsdir)/Common.cpp \
|
||||
$(utestsdir)/CountingBpTest.cpp \
|
||||
$(utestsdir)/FactorTest.cpp \
|
||||
$(utestsdir)/VarElimTest.cpp \
|
||||
$(utestsdir)/UnitTesting.cpp
|
||||
|
||||
SOURCES = \
|
||||
$(MAIN_SOURCES) \
|
||||
$(UTESTS_SOURCES)
|
||||
|
||||
OBJS = \
|
||||
BayesBall.o \
|
||||
BayesBallGraph.o \
|
||||
@ -110,8 +132,10 @@ OBJS = \
|
||||
ElimGraph.o \
|
||||
Factor.o \
|
||||
FactorGraph.o \
|
||||
GenericFactor.o \
|
||||
GroundSolver.o \
|
||||
Histogram.o \
|
||||
HorusYap.o \
|
||||
Indexer.o \
|
||||
LiftedBp.o \
|
||||
LiftedKc.o \
|
||||
LiftedOperations.o \
|
||||
@ -121,12 +145,15 @@ OBJS = \
|
||||
ProbFormula.o \
|
||||
Parfactor.o \
|
||||
ParfactorList.o \
|
||||
GroundSolver.o \
|
||||
Util.o \
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
WeightedBp.o
|
||||
|
||||
LIB_OBJS = \
|
||||
$(OBJS) \
|
||||
HorusYap.o
|
||||
|
||||
HCLI_OBJS = \
|
||||
BayesBall.o \
|
||||
BayesBallGraph.o \
|
||||
@ -135,51 +162,82 @@ HCLI_OBJS = \
|
||||
ElimGraph.o \
|
||||
Factor.o \
|
||||
FactorGraph.o \
|
||||
HorusCli.o \
|
||||
GenericFactor.o \
|
||||
GroundSolver.o \
|
||||
HorusCli.o \
|
||||
Indexer.o \
|
||||
Util.o \
|
||||
Var.o \
|
||||
VarElim.o \
|
||||
WeightedBp.o
|
||||
|
||||
SOBJS=horus.@SO@
|
||||
UTESTS_OBJS = \
|
||||
$(OBJS) \
|
||||
$(utestsdir)/BeliefPropTest.o \
|
||||
$(utestsdir)/Common.o \
|
||||
$(utestsdir)/CountingBpTest.o \
|
||||
$(utestsdir)/FactorTest.o \
|
||||
$(utestsdir)/VarElimTest.o \
|
||||
$(utestsdir)/UnitTesting.o
|
||||
|
||||
|
||||
all: $(SOBJS) hcli
|
||||
LIB = $(srcdir)/horus.@SO@
|
||||
HCLI = $(srcdir)/hcli
|
||||
UTESTING = $(srcdir)/run_tests
|
||||
|
||||
|
||||
all: $(LIB) $(HCLI)
|
||||
|
||||
|
||||
# Don't require $(UTESTING) by default as we
|
||||
# don't want a hard dependency on CppUnit
|
||||
with_tests: $(LIB) $(HCLI) $(UTESTING)
|
||||
|
||||
|
||||
@DO_SECOND_LD@$(LIB): $(LIB_OBJS)
|
||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o $@ $(LIB_OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||
|
||||
|
||||
$(HCLI): $(HCLI_OBJS)
|
||||
$(CXX) -o $@ $(HCLI_OBJS)
|
||||
|
||||
|
||||
$(UTESTING): $(UTESTS_OBJS)
|
||||
$(CXX) -o $@ $(UTESTS_OBJS) -lcppunit
|
||||
|
||||
|
||||
# default rule
|
||||
%.o : $(srcdir)/%.cpp
|
||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||
|
||||
|
||||
@DO_SECOND_LD@horus.@SO@: $(OBJS)
|
||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||
|
||||
|
||||
hcli: $(HCLI_OBJS)
|
||||
$(CXX) -o $(HCLI) $(HCLI_OBJS)
|
||||
$(CXX) -o $@ -c $(CXXFLAGS) $<
|
||||
|
||||
|
||||
install: all
|
||||
$(INSTALL_PROGRAM) $(SOBJS) $(DESTDIR)$(YAPLIBDIR)
|
||||
$(INSTALL_PROGRAM) $(LIB) $(DESTDIR)$(YAPLIBDIR)
|
||||
$(INSTALL_PROGRAM) $(HCLI) $(DESTDIR)$(BINDIR)
|
||||
|
||||
|
||||
clean:
|
||||
rm -f *.o *~ $(OBJS) $(SOBJS) $(HCLI) *.BAK
|
||||
rm -f $(LIB) $(HCLI) $(UTESTING) *.o *~ $(utestsdir)/*.o $(utestsdir)/*~
|
||||
|
||||
|
||||
erase_dots:
|
||||
rm -f *.dot *.png
|
||||
remove_dots:
|
||||
rm -f *.dot *.png *.svg
|
||||
|
||||
|
||||
depend: $(HEADERS) $(CPP_SOURCES)
|
||||
depend: $(SOURCES) $(HEADERS)
|
||||
-@if test "$(GCC)" = yes; then\
|
||||
$(CC) -std=c++0x -MM -MG $(CFLAGS) -I$(srcdir) -I$(srcdir)/../../../../include -I$(srcdir)/../../../../H $(CPP_SOURCES) >> Makefile;\
|
||||
for F in $(SOURCES); do \
|
||||
D=`dirname $$F`; \
|
||||
B=`basename $$F .cpp`; \
|
||||
$(CXX) $(CXXFLAGS) -MM -MG -MT "$$D/$$B.o" -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include $$F >> Makefile; \
|
||||
done; \
|
||||
else\
|
||||
makedepend -f - -- $(CFLAGS) -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include -- $(CPP_SOURCES) |\
|
||||
sed 's|.*/\([^:]*\):|\1:|' >> Makefile ;\
|
||||
makedepend -- $(CXXFLAGS) -- -I$(srcdir)/../../../../H -I$(srcdir)/../../../../include $(SOURCES); \
|
||||
fi
|
||||
|
||||
|
||||
.PHONY: default all install clean remove_dots depend
|
||||
|
||||
|
||||
# DO NOT DELETE THIS LINE -- make depend depends on it.
|
||||
|
||||
|
@ -1,3 +1,8 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "Parfactor.h"
|
||||
#include "Histogram.h"
|
||||
#include "Indexer.h"
|
||||
@ -5,6 +10,8 @@
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
Parfactor::Parfactor (
|
||||
const ProbFormulas& formulas,
|
||||
const Params& params,
|
||||
@ -84,7 +91,7 @@ Parfactor::Parfactor (const Parfactor& g)
|
||||
|
||||
|
||||
|
||||
Parfactor::~Parfactor (void)
|
||||
Parfactor::~Parfactor()
|
||||
{
|
||||
delete constr_;
|
||||
}
|
||||
@ -92,7 +99,7 @@ Parfactor::~Parfactor (void)
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::countedLogVars (void) const
|
||||
Parfactor::countedLogVars() const
|
||||
{
|
||||
LogVarSet set;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
@ -106,7 +113,7 @@ Parfactor::countedLogVars (void) const
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::uncountedLogVars (void) const
|
||||
Parfactor::uncountedLogVars() const
|
||||
{
|
||||
return constr_->logVarSet() - countedLogVars();
|
||||
}
|
||||
@ -114,7 +121,7 @@ Parfactor::uncountedLogVars (void) const
|
||||
|
||||
|
||||
LogVarSet
|
||||
Parfactor::elimLogVars (void) const
|
||||
Parfactor::elimLogVars() const
|
||||
{
|
||||
LogVarSet requiredToElim = constr_->logVarSet();
|
||||
requiredToElim -= constr_->singletons();
|
||||
@ -149,7 +156,7 @@ Parfactor::sumOutIndex (size_t fIdx)
|
||||
unsigned N = constr_->getConditionalCount (
|
||||
args_[fIdx].countedLogVar());
|
||||
unsigned R = args_[fIdx].range();
|
||||
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||
std::vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||
Indexer indexer (ranges_, fIdx);
|
||||
while (indexer.valid()) {
|
||||
if (Globals::logDomain) {
|
||||
@ -171,7 +178,7 @@ Parfactor::sumOutIndex (size_t fIdx)
|
||||
}
|
||||
constr_->remove (excl);
|
||||
|
||||
TFactor<ProbFormula>::sumOutIndex (fIdx);
|
||||
GenericFactor<ProbFormula>::sumOutIndex (fIdx);
|
||||
LogAware::pow (params_, exp);
|
||||
}
|
||||
|
||||
@ -181,7 +188,7 @@ void
|
||||
Parfactor::multiply (Parfactor& g)
|
||||
{
|
||||
alignAndExponentiate (this, &g);
|
||||
TFactor<ProbFormula>::multiply (g);
|
||||
GenericFactor<ProbFormula>::multiply (g);
|
||||
constr_->join (g.constr(), true);
|
||||
simplifyGrounds();
|
||||
assert (constr_->isCartesianProduct (countedLogVars()));
|
||||
@ -224,10 +231,10 @@ Parfactor::countConvert (LogVar X)
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = ranges_[fIdx];
|
||||
unsigned H = HistogramSet::nrHistograms (N, R);
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
std::vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
|
||||
Indexer indexer (ranges_);
|
||||
vector<Params> sumout (params_.size() / R);
|
||||
std::vector<Params> sumout (params_.size() / R);
|
||||
unsigned count = 0;
|
||||
while (indexer.valid()) {
|
||||
sumout[count].reserve (R);
|
||||
@ -279,11 +286,11 @@ Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
|
||||
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
||||
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
||||
|
||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||
vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
|
||||
std::vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||
std::vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||
std::vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
|
||||
|
||||
vector<unsigned> sumIndexes;
|
||||
std::vector<unsigned> sumIndexes;
|
||||
sumIndexes.reserve (H1 * H2);
|
||||
for (unsigned i = 0; i < H1; i++) {
|
||||
for (unsigned j = 0; j < H2; j++) {
|
||||
@ -319,16 +326,16 @@ Parfactor::fullExpand (LogVar X)
|
||||
|
||||
unsigned N = constr_->getConditionalCount (X);
|
||||
unsigned R = args_[fIdx].range();
|
||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||
std::vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||
std::vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||
assert (ranges_[fIdx] == originHists.size());
|
||||
vector<unsigned> sumIndexes;
|
||||
std::vector<unsigned> sumIndexes;
|
||||
sumIndexes.reserve (N * R);
|
||||
|
||||
Ranges expandRanges (N, R);
|
||||
Indexer indexer (expandRanges);
|
||||
while (indexer.valid()) {
|
||||
vector<unsigned> hist (R, 0);
|
||||
std::vector<unsigned> hist (R, 0);
|
||||
for (unsigned n = 0; n < N; n++) {
|
||||
hist += expandHists[indexer[n]];
|
||||
}
|
||||
@ -384,14 +391,14 @@ Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
||||
assert (args_[fIdx].isCounting() == false);
|
||||
assert (constr_->isCountNormalized (excl));
|
||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||
GenericFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||
constr_->remove (excl);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::setNewGroups (void)
|
||||
Parfactor::setNewGroups()
|
||||
{
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
args_[i].setGroup (ProbFormula::getNewGroup());
|
||||
@ -494,7 +501,7 @@ Parfactor::containsGroup (PrvGroup group) const
|
||||
|
||||
|
||||
bool
|
||||
Parfactor::containsGroups (vector<PrvGroup> groups) const
|
||||
Parfactor::containsGroups (std::vector<PrvGroup> groups) const
|
||||
{
|
||||
for (size_t i = 0; i < groups.size(); i++) {
|
||||
if (containsGroup (groups[i]) == false) {
|
||||
@ -565,10 +572,10 @@ Parfactor::nrFormulasWithGroup (PrvGroup group) const
|
||||
|
||||
|
||||
|
||||
vector<PrvGroup>
|
||||
Parfactor::getAllGroups (void) const
|
||||
std::vector<PrvGroup>
|
||||
Parfactor::getAllGroups() const
|
||||
{
|
||||
vector<PrvGroup> groups (args_.size());
|
||||
std::vector<PrvGroup> groups (args_.size());
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
groups[i] = args_[i].group();
|
||||
}
|
||||
@ -577,10 +584,10 @@ Parfactor::getAllGroups (void) const
|
||||
|
||||
|
||||
|
||||
string
|
||||
Parfactor::getLabel (void) const
|
||||
std::string
|
||||
Parfactor::getLabel() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "phi(" ;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
@ -598,6 +605,8 @@ Parfactor::getLabel (void) const
|
||||
void
|
||||
Parfactor::print (bool printParams) const
|
||||
{
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
cout << "Formulas: " ;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
@ -605,9 +614,10 @@ Parfactor::print (bool printParams) const
|
||||
}
|
||||
cout << endl;
|
||||
if (args_[0].group() != Util::maxUnsigned()) {
|
||||
vector<string> groups;
|
||||
std::vector<std::string> groups;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
||||
groups.push_back (std::string ("g")
|
||||
+ Util::toString (args_[i].group()));
|
||||
}
|
||||
cout << "Groups: " << groups << endl;
|
||||
}
|
||||
@ -633,12 +643,12 @@ Parfactor::print (bool printParams) const
|
||||
|
||||
|
||||
void
|
||||
Parfactor::printParameters (void) const
|
||||
Parfactor::printParameters() const
|
||||
{
|
||||
vector<string> jointStrings;
|
||||
std::vector<std::string> jointStrings;
|
||||
Indexer indexer (ranges_);
|
||||
while (indexer.valid()) {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; i < args_.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
if (args_[i].isCounting()) {
|
||||
@ -659,22 +669,22 @@ Parfactor::printParameters (void) const
|
||||
++ indexer;
|
||||
}
|
||||
for (size_t i = 0; i < params_.size(); i++) {
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << params_[i] << endl;
|
||||
std::cout << "f(" << jointStrings[i] << ")" ;
|
||||
std::cout << " = " << params_[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Parfactor::printProjections (void) const
|
||||
Parfactor::printProjections() const
|
||||
{
|
||||
ConstraintTree copy (*constr_);
|
||||
|
||||
LogVarSet Xs = copy.logVarSet();
|
||||
for (size_t i = 0; i < Xs.size(); i++) {
|
||||
cout << "-> projection of " << Xs[i] << ": " ;
|
||||
cout << copy.tupleSet ({Xs[i]}) << endl;
|
||||
std::cout << "-> projection of " << Xs[i] << ": " ;
|
||||
std::cout << copy.tupleSet ({Xs[i]}) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -684,12 +694,12 @@ void
|
||||
Parfactor::expandPotential (
|
||||
size_t fIdx,
|
||||
unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes)
|
||||
const std::vector<unsigned>& sumIndexes)
|
||||
{
|
||||
ullong newSize = (params_.size() / ranges_[fIdx]) * newRange;
|
||||
if (newSize > params_.max_size()) {
|
||||
cerr << "Error: an overflow occurred when performing expansion." ;
|
||||
cerr << endl;
|
||||
std::cerr << "Error: an overflow occurred when performing expansion." ;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
|
||||
@ -698,7 +708,7 @@ Parfactor::expandPotential (
|
||||
params_.reserve (newSize);
|
||||
|
||||
size_t prod = 1;
|
||||
vector<size_t> offsets (ranges_.size());
|
||||
std::vector<size_t> offsets (ranges_.size());
|
||||
for (size_t i = ranges_.size(); i-- > 0; ) {
|
||||
offsets[i] = prod;
|
||||
prod *= ranges_[i];
|
||||
@ -706,7 +716,7 @@ Parfactor::expandPotential (
|
||||
|
||||
size_t index = 0;
|
||||
ranges_[fIdx] = newRange;
|
||||
vector<unsigned> indices (ranges_.size(), 0);
|
||||
std::vector<unsigned> indices (ranges_.size(), 0);
|
||||
for (size_t k = 0; k < newSize; k++) {
|
||||
assert (index < backup.size());
|
||||
params_.push_back (backup[index]);
|
||||
@ -759,7 +769,7 @@ Parfactor::simplifyCountingFormulas (size_t fIdx)
|
||||
|
||||
|
||||
void
|
||||
Parfactor::simplifyGrounds (void)
|
||||
Parfactor::simplifyGrounds()
|
||||
{
|
||||
if (args_.size() == 1) {
|
||||
return;
|
||||
@ -872,12 +882,12 @@ Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
|
||||
std::pair<LogVars, LogVars> res = getAlignLogVars (g1, g2);
|
||||
const LogVars& alignLvs1 = res.first;
|
||||
const LogVars& alignLvs2 = res.second;
|
||||
// cout << "ALIGNING :::::::::::::::::" << endl;
|
||||
// std::cout << "ALIGNING :::::::::::::::::" << std::endl;
|
||||
// g1->print();
|
||||
// cout << "AND" << endl;
|
||||
// g2->print();
|
||||
// cout << "-> align lvs1 = " << alignLvs1 << endl;
|
||||
// cout << "-> align lvs2 = " << alignLvs2 << endl;
|
||||
// std::cout << "-> align lvs1 = " << alignLvs1 << std::endl;
|
||||
// std::cout << "-> align lvs2 = " << alignLvs2 << std::endl;
|
||||
LogVar freeLogVar (0);
|
||||
Substitution theta1, theta2;
|
||||
for (size_t i = 0; i < alignLvs1.size(); i++) {
|
||||
@ -933,9 +943,11 @@ Parfactor::alignLogicalVars (Parfactor* g1, Parfactor* g2)
|
||||
}
|
||||
}
|
||||
|
||||
// cout << "theta1: " << theta1 << endl;
|
||||
// cout << "theta2: " << theta2 << endl;
|
||||
// std::cout << "theta1: " << theta1 << std::endl;
|
||||
// std::cout << "theta2: " << theta2 << std::endl;
|
||||
g1->applySubstitution (theta1);
|
||||
g2->applySubstitution (theta2);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,20 +1,21 @@
|
||||
#ifndef HORUS_PARFACTOR_H
|
||||
#define HORUS_PARFACTOR_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
|
||||
|
||||
#include "Factor.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "GenericFactor.h"
|
||||
#include "ProbFormula.h"
|
||||
#include "ConstraintTree.h"
|
||||
#include "LiftedUtils.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
class Parfactor : public TFactor<ProbFormula>
|
||||
{
|
||||
namespace Horus {
|
||||
|
||||
class Parfactor : public GenericFactor<ProbFormula> {
|
||||
public:
|
||||
Parfactor (
|
||||
const ProbFormulas&,
|
||||
const Params&,
|
||||
const Tuples&,
|
||||
Parfactor (const ProbFormulas&, const Params&, const Tuples&,
|
||||
unsigned distId);
|
||||
|
||||
Parfactor (const Parfactor*, const Tuple&);
|
||||
@ -23,21 +24,21 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
Parfactor (const Parfactor&);
|
||||
|
||||
~Parfactor (void);
|
||||
~Parfactor();
|
||||
|
||||
ConstraintTree* constr (void) { return constr_; }
|
||||
ConstraintTree* constr() { return constr_; }
|
||||
|
||||
const ConstraintTree* constr (void) const { return constr_; }
|
||||
const ConstraintTree* constr() const { return constr_; }
|
||||
|
||||
const LogVars& logVars (void) const { return constr_->logVars(); }
|
||||
const LogVars& logVars() const { return constr_->logVars(); }
|
||||
|
||||
const LogVarSet& logVarSet (void) const { return constr_->logVarSet(); }
|
||||
const LogVarSet& logVarSet() const { return constr_->logVarSet(); }
|
||||
|
||||
LogVarSet countedLogVars (void) const;
|
||||
LogVarSet countedLogVars() const;
|
||||
|
||||
LogVarSet uncountedLogVars (void) const;
|
||||
LogVarSet uncountedLogVars() const;
|
||||
|
||||
LogVarSet elimLogVars (void) const;
|
||||
LogVarSet elimLogVars() const;
|
||||
|
||||
LogVarSet exclusiveLogVars (size_t fIdx) const;
|
||||
|
||||
@ -57,7 +58,7 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
void absorveEvidence (const ProbFormula&, unsigned);
|
||||
|
||||
void setNewGroups (void);
|
||||
void setNewGroups();
|
||||
|
||||
void applySubstitution (const Substitution&);
|
||||
|
||||
@ -71,7 +72,7 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
bool containsGroup (PrvGroup) const;
|
||||
|
||||
bool containsGroups (vector<PrvGroup>) const;
|
||||
bool containsGroups (std::vector<PrvGroup>) const;
|
||||
|
||||
unsigned nrFormulas (LogVar) const;
|
||||
|
||||
@ -81,17 +82,17 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
|
||||
unsigned nrFormulasWithGroup (PrvGroup) const;
|
||||
|
||||
vector<PrvGroup> getAllGroups (void) const;
|
||||
std::vector<PrvGroup> getAllGroups() const;
|
||||
|
||||
void print (bool = false) const;
|
||||
|
||||
void printParameters (void) const;
|
||||
void printParameters() const;
|
||||
|
||||
void printProjections (void) const;
|
||||
void printProjections() const;
|
||||
|
||||
string getLabel (void) const;
|
||||
std::string getLabel() const;
|
||||
|
||||
void simplifyGrounds (void);
|
||||
void simplifyGrounds();
|
||||
|
||||
static bool canMultiply (Parfactor*, Parfactor*);
|
||||
|
||||
@ -104,18 +105,20 @@ class Parfactor : public TFactor<ProbFormula>
|
||||
Parfactor* g1, Parfactor* g2);
|
||||
|
||||
void expandPotential (size_t fIdx, unsigned newRange,
|
||||
const vector<unsigned>& sumIndexes);
|
||||
const std::vector<unsigned>& sumIndexes);
|
||||
|
||||
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||
|
||||
static void alignLogicalVars (Parfactor*, Parfactor*);
|
||||
|
||||
ConstraintTree* constr_;
|
||||
ConstraintTree* constr_;
|
||||
|
||||
DISALLOW_ASSIGN (Parfactor);
|
||||
};
|
||||
|
||||
typedef vector<Parfactor*> Parfactors;
|
||||
typedef std::vector<Parfactor*> Parfactors;
|
||||
|
||||
#endif // HORUS_PARFACTOR_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_PARFACTOR_H_
|
||||
|
||||
|
@ -1,10 +1,14 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <queue>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ParfactorList.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
ParfactorList::ParfactorList (const ParfactorList& pfList)
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList.begin();
|
||||
@ -23,7 +27,7 @@ ParfactorList::ParfactorList (const Parfactors& pfs)
|
||||
|
||||
|
||||
|
||||
ParfactorList::~ParfactorList (void)
|
||||
ParfactorList::~ParfactorList()
|
||||
{
|
||||
ParfactorList::const_iterator it = pfList_.begin();
|
||||
while (it != pfList_.end()) {
|
||||
@ -64,27 +68,27 @@ ParfactorList::addShattered (Parfactor* pf)
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
std::list<Parfactor*>::iterator
|
||||
ParfactorList::insertShattered (
|
||||
list<Parfactor*>::iterator it,
|
||||
std::list<Parfactor*>::iterator it,
|
||||
Parfactor* pf)
|
||||
{
|
||||
return pfList_.insert (it, pf);
|
||||
assert (isAllShattered());
|
||||
return pfList_.insert (it, pf);
|
||||
}
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::remove (list<Parfactor*>::iterator it)
|
||||
std::list<Parfactor*>::iterator
|
||||
ParfactorList::remove (std::list<Parfactor*>::iterator it)
|
||||
{
|
||||
return pfList_.erase (it);
|
||||
}
|
||||
|
||||
|
||||
|
||||
list<Parfactor*>::iterator
|
||||
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||
std::list<Parfactor*>::iterator
|
||||
ParfactorList::removeAndDelete (std::list<Parfactor*>::iterator it)
|
||||
{
|
||||
delete *it;
|
||||
return pfList_.erase (it);
|
||||
@ -93,12 +97,12 @@ ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||
|
||||
|
||||
bool
|
||||
ParfactorList::isAllShattered (void) const
|
||||
ParfactorList::isAllShattered() const
|
||||
{
|
||||
if (pfList_.size() <= 1) {
|
||||
return true;
|
||||
}
|
||||
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
|
||||
Parfactors pfs (pfList_.begin(), pfList_.end());
|
||||
for (size_t i = 0; i < pfs.size(); i++) {
|
||||
assert (isShattered (pfs[i]));
|
||||
}
|
||||
@ -115,13 +119,25 @@ ParfactorList::isAllShattered (void) const
|
||||
|
||||
|
||||
void
|
||||
ParfactorList::print (void) const
|
||||
ParfactorList::print() const
|
||||
{
|
||||
struct sortByParams {
|
||||
bool operator() (const Parfactor* pf1, const Parfactor* pf2)
|
||||
{
|
||||
if (pf1->params().size() < pf2->params().size()) {
|
||||
return true;
|
||||
} else if (pf1->params().size() == pf2->params().size() &&
|
||||
pf1->params() < pf2->params()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
Parfactors pfVec (pfList_.begin(), pfList_.end());
|
||||
std::sort (pfVec.begin(), pfVec.end(), sortByParams());
|
||||
// vsc std::sort (pfVec.begin(), pfVec.end(), sortByParams());
|
||||
for (size_t i = 0; i < pfVec.size(); i++) {
|
||||
pfVec[i]->print();
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,8 +179,8 @@ ParfactorList::isShattered (const Parfactor* g) const
|
||||
formulas[i], *(g->constr()),
|
||||
formulas[j], *(g->constr())) == false) {
|
||||
g->print();
|
||||
cout << "-> not identical on positions " ;
|
||||
cout << i << " and " << j << endl;
|
||||
std::cout << "-> not identical on positions " ;
|
||||
std::cout << i << " and " << j << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
@ -172,8 +188,8 @@ ParfactorList::isShattered (const Parfactor* g) const
|
||||
formulas[i], *(g->constr()),
|
||||
formulas[j], *(g->constr())) == false) {
|
||||
g->print();
|
||||
cout << "-> not disjoint on positions " ;
|
||||
cout << i << " and " << j << endl;
|
||||
std::cout << "-> not disjoint on positions " ;
|
||||
std::cout << i << " and " << j << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -200,9 +216,10 @@ ParfactorList::isShattered (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
g1->print();
|
||||
cout << "^" << endl;
|
||||
std::cout << "^" << std::endl;
|
||||
g2->print();
|
||||
cout << "-> not identical on group " << fms1[i].group() << endl;
|
||||
std::cout << "-> not identical on group " ;
|
||||
std::cout << fms1[i].group() << std::endl;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
@ -210,10 +227,10 @@ ParfactorList::isShattered (
|
||||
fms1[i], *(g1->constr()),
|
||||
fms2[j], *(g2->constr())) == false) {
|
||||
g1->print();
|
||||
cout << "^" << endl;
|
||||
std::cout << "^" << std::endl;
|
||||
g2->print();
|
||||
cout << "-> not disjoint on groups " << fms1[i].group();
|
||||
cout << " and " << fms2[j].group() << endl;
|
||||
std::cout << "-> not disjoint on groups " << fms1[i].group();
|
||||
std::cout << " and " << fms2[j].group() << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -227,12 +244,12 @@ ParfactorList::isShattered (
|
||||
void
|
||||
ParfactorList::addToShatteredList (Parfactor* g)
|
||||
{
|
||||
queue<Parfactor*> residuals;
|
||||
std::queue<Parfactor*> residuals;
|
||||
residuals.push (g);
|
||||
while (residuals.empty() == false) {
|
||||
Parfactor* pf = residuals.front();
|
||||
bool pfSplitted = false;
|
||||
list<Parfactor*>::iterator pfIter;
|
||||
std::list<Parfactor*>::iterator pfIter;
|
||||
pfIter = pfList_.begin();
|
||||
while (pfIter != pfList_.end()) {
|
||||
std::pair<Parfactors, Parfactors> shattRes;
|
||||
@ -269,7 +286,7 @@ Parfactors
|
||||
ParfactorList::shatterAgainstMySelf (Parfactor* g)
|
||||
{
|
||||
Parfactors pfs;
|
||||
queue<Parfactor*> residuals;
|
||||
std::queue<Parfactor*> residuals;
|
||||
residuals.push (g);
|
||||
bool shattered = true;
|
||||
while (residuals.empty() == false) {
|
||||
@ -325,19 +342,22 @@ ParfactorList::shatterAgainstMySelf (
|
||||
{
|
||||
/*
|
||||
Util::printDashedLine();
|
||||
cout << "-> SHATTERING" << endl;
|
||||
std::cout << "-> SHATTERING" << std::endl;
|
||||
g->print();
|
||||
cout << "-> ON: " << g->argument (fIdx1) << "|" ;
|
||||
cout << g->constr()->tupleSet (g->argument (fIdx1).logVars()) << endl;
|
||||
cout << "-> ON: " << g->argument (fIdx2) << "|" ;
|
||||
cout << g->constr()->tupleSet (g->argument (fIdx2).logVars()) << endl;
|
||||
std::cout << "-> ON: " << g->argument (fIdx1) << "|" ;
|
||||
std::cout << g->constr()->tupleSet (g->argument (fIdx1).logVars());
|
||||
std::cout << std::endl;
|
||||
std::cout << "-> ON: " << g->argument (fIdx2) << "|" ;
|
||||
std::cout << g->constr()->tupleSet (g->argument (fIdx2).logVars())
|
||||
std::cout << std::endl;
|
||||
Util::printDashedLine();
|
||||
*/
|
||||
ProbFormula& f1 = g->argument (fIdx1);
|
||||
ProbFormula& f2 = g->argument (fIdx2);
|
||||
if (f1.isAtom()) {
|
||||
cerr << "Error: a ground occurs twice in the same parfactor." << endl;
|
||||
cerr << endl;
|
||||
std::cerr << "Error: a ground occurs twice in the same parfactor." ;
|
||||
std::cerr << std::endl;
|
||||
std::cerr << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
assert (g->constr()->empty() == false);
|
||||
@ -441,14 +461,14 @@ ParfactorList::shatter (
|
||||
ProbFormula& f2 = g2->argument (fIdx2);
|
||||
/*
|
||||
Util::printDashedLine();
|
||||
cout << "-> SHATTERING" << endl;
|
||||
std::cout << "-> SHATTERING" << std::endl;
|
||||
g1->print();
|
||||
cout << "-> WITH" << endl;
|
||||
std::cout << "-> WITH" << std::endl;
|
||||
g2->print();
|
||||
cout << "-> ON: " << f1 << "|" ;
|
||||
cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
||||
cout << "-> ON: " << f2 << "|" ;
|
||||
cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
||||
std::cout << "-> ON: " << f1 << "|" ;
|
||||
std::cout << g1->constr()->tupleSet (f1.logVars()) << std::endl;
|
||||
std::cout << "-> ON: " << f2 << "|" ;
|
||||
std::cout << g2->constr()->tupleSet (f2.logVars()) << std::endl;
|
||||
Util::printDashedLine();
|
||||
*/
|
||||
if (f1.isAtom()) {
|
||||
@ -486,12 +506,12 @@ ParfactorList::shatter (
|
||||
assert (commCt1->tupleSet (f1.logVars()) ==
|
||||
commCt2->tupleSet (f2.logVars()));
|
||||
|
||||
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||
// std::stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||
// std::stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||
// std::stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||
// std::stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||
// std::stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||
// std::stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
|
||||
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
|
||||
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
||||
@ -638,3 +658,5 @@ ParfactorList::disjoint (
|
||||
return (ts1 & ts2).empty();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef HORUS_PARFACTORLIST_H
|
||||
#define HORUS_PARFACTORLIST_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
|
||||
|
||||
#include <list>
|
||||
|
||||
@ -7,39 +7,38 @@
|
||||
#include "ProbFormula.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
class Parfactor;
|
||||
|
||||
class ParfactorList
|
||||
{
|
||||
|
||||
class ParfactorList {
|
||||
public:
|
||||
ParfactorList (void) { }
|
||||
ParfactorList() { }
|
||||
|
||||
ParfactorList (const ParfactorList&);
|
||||
|
||||
ParfactorList (const Parfactors&);
|
||||
|
||||
~ParfactorList (void);
|
||||
~ParfactorList();
|
||||
|
||||
const list<Parfactor*>& parfactors (void) const { return pfList_; }
|
||||
const std::list<Parfactor*>& parfactors() const { return pfList_; }
|
||||
|
||||
void clear (void) { pfList_.clear(); }
|
||||
void clear() { pfList_.clear(); }
|
||||
|
||||
size_t size (void) const { return pfList_.size(); }
|
||||
size_t size() const { return pfList_.size(); }
|
||||
|
||||
typedef std::list<Parfactor*>::iterator iterator;
|
||||
|
||||
iterator begin (void) { return pfList_.begin(); }
|
||||
iterator begin() { return pfList_.begin(); }
|
||||
|
||||
iterator end (void) { return pfList_.end(); }
|
||||
iterator end() { return pfList_.end(); }
|
||||
|
||||
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
||||
|
||||
const_iterator begin (void) const { return pfList_.begin(); }
|
||||
const_iterator begin() const { return pfList_.begin(); }
|
||||
|
||||
const_iterator end (void) const { return pfList_.end(); }
|
||||
const_iterator end() const { return pfList_.end(); }
|
||||
|
||||
void add (Parfactor* pf);
|
||||
|
||||
@ -47,16 +46,18 @@ class ParfactorList
|
||||
|
||||
void addShattered (Parfactor* pf);
|
||||
|
||||
list<Parfactor*>::iterator insertShattered (
|
||||
list<Parfactor*>::iterator, Parfactor*);
|
||||
std::list<Parfactor*>::iterator insertShattered (
|
||||
std::list<Parfactor*>::iterator, Parfactor*);
|
||||
|
||||
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||
std::list<Parfactor*>::iterator remove (
|
||||
std::list<Parfactor*>::iterator);
|
||||
|
||||
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
|
||||
std::list<Parfactor*>::iterator removeAndDelete (
|
||||
std::list<Parfactor*>::iterator);
|
||||
|
||||
bool isAllShattered (void) const;
|
||||
bool isAllShattered() const;
|
||||
|
||||
void print (void) const;
|
||||
void print() const;
|
||||
|
||||
ParfactorList& operator= (const ParfactorList& pfList);
|
||||
|
||||
@ -101,22 +102,10 @@ class ParfactorList
|
||||
const ProbFormula&, ConstraintTree,
|
||||
const ProbFormula&, ConstraintTree) const;
|
||||
|
||||
struct sortByParams
|
||||
{
|
||||
inline bool operator() (const Parfactor* pf1, const Parfactor* pf2)
|
||||
{
|
||||
if (pf1->params().size() < pf2->params().size()) {
|
||||
return true;
|
||||
} else if (pf1->params().size() == pf2->params().size() &&
|
||||
pf1->params() < pf2->params()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
list<Parfactor*> pfList_;
|
||||
std::list<Parfactor*> pfList_;
|
||||
};
|
||||
|
||||
#endif // HORUS_PARFACTORLIST_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_PARFACTORLIST_H_
|
||||
|
||||
|
@ -1,6 +1,13 @@
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ProbFormula.h"
|
||||
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
PrvGroup ProbFormula::freeGroup_ = 0;
|
||||
|
||||
|
||||
@ -38,7 +45,7 @@ ProbFormula::indexOf (LogVar X) const
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::isAtom (void) const
|
||||
ProbFormula::isAtom() const
|
||||
{
|
||||
return logVars_.empty();
|
||||
}
|
||||
@ -46,7 +53,7 @@ ProbFormula::isAtom (void) const
|
||||
|
||||
|
||||
bool
|
||||
ProbFormula::isCounting (void) const
|
||||
ProbFormula::isCounting() const
|
||||
{
|
||||
return countedLogVar_.valid();
|
||||
}
|
||||
@ -54,7 +61,7 @@ ProbFormula::isCounting (void) const
|
||||
|
||||
|
||||
LogVar
|
||||
ProbFormula::countedLogVar (void) const
|
||||
ProbFormula::countedLogVar() const
|
||||
{
|
||||
assert (isCounting());
|
||||
return countedLogVar_;
|
||||
@ -71,7 +78,7 @@ ProbFormula::setCountedLogVar (LogVar lv)
|
||||
|
||||
|
||||
void
|
||||
ProbFormula::clearCountedLogVar (void)
|
||||
ProbFormula::clearCountedLogVar()
|
||||
{
|
||||
countedLogVar_ = LogVar();
|
||||
}
|
||||
@ -93,15 +100,8 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
|
||||
|
||||
|
||||
|
||||
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||
{
|
||||
return f1.group_ == f2.group_ &&
|
||||
f1.logVars_ == f2.logVars_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const ProbFormula& f)
|
||||
{
|
||||
os << f.functor_;
|
||||
if (f.isAtom() == false) {
|
||||
@ -122,7 +122,7 @@ std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||
|
||||
|
||||
PrvGroup
|
||||
ProbFormula::getNewGroup (void)
|
||||
ProbFormula::getNewGroup()
|
||||
{
|
||||
freeGroup_ ++;
|
||||
assert (freeGroup_ != std::numeric_limits<PrvGroup>::max());
|
||||
@ -131,7 +131,24 @@ ProbFormula::getNewGroup (void)
|
||||
|
||||
|
||||
|
||||
ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||
ObservedFormula::ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||
: functor_(f), arity_(a), evidence_(ev), constr_(a)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
ObservedFormula::ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||
{
|
||||
constr_.addTuple (tuple);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::ostream&
|
||||
operator<< (std::ostream& os, const ObservedFormula& of)
|
||||
{
|
||||
os << of.functor_ << "/" << of.arity_;
|
||||
os << "|" << of.constr_.tupleSet();
|
||||
@ -139,3 +156,5 @@ ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,16 +1,20 @@
|
||||
#ifndef HORUS_PROBFORMULA_H
|
||||
#define HORUS_PROBFORMULA_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
|
||||
|
||||
#include <vector>
|
||||
#include <ostream>
|
||||
#include <limits>
|
||||
|
||||
#include "ConstraintTree.h"
|
||||
#include "LiftedUtils.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
typedef unsigned long PrvGroup;
|
||||
|
||||
class ProbFormula
|
||||
{
|
||||
class ProbFormula {
|
||||
public:
|
||||
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
||||
: functor_(f), logVars_(lvs), range_(range),
|
||||
@ -20,19 +24,19 @@ class ProbFormula
|
||||
: functor_(f), range_(r),
|
||||
group_(std::numeric_limits<PrvGroup>::max()) { }
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
Symbol functor() const { return functor_; }
|
||||
|
||||
unsigned arity (void) const { return logVars_.size(); }
|
||||
unsigned arity() const { return logVars_.size(); }
|
||||
|
||||
unsigned range (void) const { return range_; }
|
||||
unsigned range() const { return range_; }
|
||||
|
||||
LogVars& logVars (void) { return logVars_; }
|
||||
LogVars& logVars() { return logVars_; }
|
||||
|
||||
const LogVars& logVars (void) const { return logVars_; }
|
||||
const LogVars& logVars() const { return logVars_; }
|
||||
|
||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||
LogVarSet logVarSet() const { return LogVarSet (logVars_); }
|
||||
|
||||
PrvGroup group (void) const { return group_; }
|
||||
PrvGroup group() const { return group_; }
|
||||
|
||||
void setGroup (PrvGroup g) { group_ = g; }
|
||||
|
||||
@ -44,25 +48,28 @@ class ProbFormula
|
||||
|
||||
size_t indexOf (LogVar) const;
|
||||
|
||||
bool isAtom (void) const;
|
||||
bool isAtom() const;
|
||||
|
||||
bool isCounting (void) const;
|
||||
bool isCounting() const;
|
||||
|
||||
LogVar countedLogVar (void) const;
|
||||
LogVar countedLogVar() const;
|
||||
|
||||
void setCountedLogVar (LogVar);
|
||||
|
||||
void clearCountedLogVar (void);
|
||||
void clearCountedLogVar();
|
||||
|
||||
void rename (LogVar, LogVar);
|
||||
|
||||
static PrvGroup getNewGroup (void);
|
||||
|
||||
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
|
||||
|
||||
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
|
||||
static PrvGroup getNewGroup();
|
||||
|
||||
private:
|
||||
|
||||
friend bool operator== (
|
||||
const ProbFormula& f1, const ProbFormula& f2);
|
||||
|
||||
friend std::ostream& operator<< (
|
||||
std::ostream&, const ProbFormula&);
|
||||
|
||||
Symbol functor_;
|
||||
LogVars logVars_;
|
||||
unsigned range_;
|
||||
@ -71,45 +78,50 @@ class ProbFormula
|
||||
static PrvGroup freeGroup_;
|
||||
};
|
||||
|
||||
typedef vector<ProbFormula> ProbFormulas;
|
||||
typedef std::vector<ProbFormula> ProbFormulas;
|
||||
|
||||
|
||||
class ObservedFormula
|
||||
inline bool
|
||||
operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||
{
|
||||
return f1.group_ == f2.group_ && f1.logVars_ == f2.logVars_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
class ObservedFormula {
|
||||
public:
|
||||
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
|
||||
ObservedFormula (Symbol f, unsigned a, unsigned ev);
|
||||
|
||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||
{
|
||||
constr_.addTuple (tuple);
|
||||
}
|
||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple);
|
||||
|
||||
Symbol functor (void) const { return functor_; }
|
||||
Symbol functor() const { return functor_; }
|
||||
|
||||
unsigned arity (void) const { return arity_; }
|
||||
unsigned arity() const { return arity_; }
|
||||
|
||||
unsigned evidence (void) const { return evidence_; }
|
||||
unsigned evidence() const { return evidence_; }
|
||||
|
||||
void setEvidence (unsigned ev) { evidence_ = ev; }
|
||||
|
||||
ConstraintTree& constr (void) { return constr_; }
|
||||
ConstraintTree& constr() { return constr_; }
|
||||
|
||||
bool isAtom (void) const { return arity_ == 0; }
|
||||
bool isAtom() const { return arity_ == 0; }
|
||||
|
||||
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
|
||||
|
||||
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
|
||||
|
||||
private:
|
||||
friend std::ostream& operator<< (
|
||||
std::ostream&, const ObservedFormula&);
|
||||
|
||||
Symbol functor_;
|
||||
unsigned arity_;
|
||||
unsigned evidence_;
|
||||
ConstraintTree constr_;
|
||||
};
|
||||
|
||||
typedef vector<ObservedFormula> ObservedFormulas;
|
||||
typedef std::vector<ObservedFormula> ObservedFormulas;
|
||||
|
||||
#endif // HORUS_PROBFORMULA_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_PROBFORMULA_H_
|
||||
|
||||
|
@ -1,20 +1,18 @@
|
||||
#ifndef HORUS_TINYSET_H
|
||||
#define HORUS_TINYSET_H
|
||||
|
||||
#include <algorithm>
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <ostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
template <typename T, typename Compare = std::less<T>>
|
||||
class TinySet
|
||||
{
|
||||
class TinySet {
|
||||
public:
|
||||
|
||||
typedef typename vector<T>::iterator iterator;
|
||||
typedef typename vector<T>::const_iterator const_iterator;
|
||||
typedef typename std::vector<T>::iterator iterator;
|
||||
typedef typename std::vector<T>::const_iterator const_iterator;
|
||||
|
||||
TinySet (const TinySet& s)
|
||||
: vec_(s.vec_), cmp_(s.cmp_) { }
|
||||
@ -25,190 +23,72 @@ class TinySet
|
||||
TinySet (const T& t, const Compare& cmp = Compare())
|
||||
: vec_(1, t), cmp_(cmp) { }
|
||||
|
||||
TinySet (const vector<T>& elements, const Compare& cmp = Compare())
|
||||
: vec_(elements), cmp_(cmp)
|
||||
{
|
||||
std::sort (begin(), end(), cmp_);
|
||||
iterator it = unique_cmp (begin(), end());
|
||||
vec_.resize (it - begin());
|
||||
}
|
||||
TinySet (const std::vector<T>& elements, const Compare& cmp = Compare());
|
||||
|
||||
iterator insert (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
if (it == end() || cmp_(t, *it)) {
|
||||
vec_.insert (it, t);
|
||||
}
|
||||
return it;
|
||||
}
|
||||
iterator insert (const T& t);
|
||||
|
||||
void insert_sorted (const T& t)
|
||||
{
|
||||
vec_.push_back (t);
|
||||
assert (consistent());
|
||||
}
|
||||
void insert_sorted (const T& t);
|
||||
|
||||
void remove (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
if (it != end()) {
|
||||
vec_.erase (it);
|
||||
}
|
||||
}
|
||||
void remove (const T& t);
|
||||
|
||||
const_iterator find (const T& t) const
|
||||
{
|
||||
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
return it == end() || cmp_(t, *it) ? end() : it;
|
||||
}
|
||||
|
||||
iterator find (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
return it == end() || cmp_(t, *it) ? end() : it;
|
||||
}
|
||||
const_iterator find (const T& t) const;
|
||||
|
||||
iterator find (const T& t);
|
||||
|
||||
/* set union */
|
||||
TinySet operator| (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_union (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
TinySet operator| (const TinySet& s) const;
|
||||
|
||||
/* set intersection */
|
||||
TinySet operator& (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_intersection (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
TinySet operator& (const TinySet& s) const;
|
||||
|
||||
/* set difference */
|
||||
TinySet operator- (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_difference (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
TinySet operator- (const TinySet& s) const;
|
||||
|
||||
TinySet& operator|= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this | s);
|
||||
}
|
||||
TinySet& operator|= (const TinySet& s);
|
||||
|
||||
TinySet& operator&= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this & s);
|
||||
}
|
||||
TinySet& operator&= (const TinySet& s);
|
||||
|
||||
TinySet& operator-= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this - s);
|
||||
}
|
||||
TinySet& operator-= (const TinySet& s);
|
||||
|
||||
bool contains (const T& t) const
|
||||
{
|
||||
return std::binary_search (
|
||||
vec_.begin(), vec_.end(), t, cmp_);
|
||||
}
|
||||
bool contains (const T& t) const;
|
||||
|
||||
bool contains (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
vec_.begin(),
|
||||
vec_.end(),
|
||||
s.vec_.begin(),
|
||||
s.vec_.end(),
|
||||
cmp_);
|
||||
}
|
||||
bool contains (const TinySet& s) const;
|
||||
|
||||
bool in (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
s.vec_.begin(),
|
||||
s.vec_.end(),
|
||||
vec_.begin(),
|
||||
vec_.end(),
|
||||
cmp_);
|
||||
}
|
||||
bool in (const TinySet& s) const;
|
||||
|
||||
bool intersects (const TinySet& s) const
|
||||
{
|
||||
return (*this & s).size() > 0;
|
||||
}
|
||||
bool intersects (const TinySet& s) const;
|
||||
|
||||
const T& operator[] (typename vector<T>::size_type i) const
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
const T& operator[] (typename std::vector<T>::size_type i) const;
|
||||
|
||||
T& operator[] (typename vector<T>::size_type i)
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
T& operator[] (typename std::vector<T>::size_type i);
|
||||
|
||||
T front (void) const
|
||||
{
|
||||
return vec_.front();
|
||||
}
|
||||
T front() const;
|
||||
|
||||
T& front (void)
|
||||
{
|
||||
return vec_.front();
|
||||
}
|
||||
T& front();
|
||||
|
||||
T back (void) const
|
||||
{
|
||||
return vec_.back();
|
||||
}
|
||||
T back() const;
|
||||
|
||||
T& back (void)
|
||||
{
|
||||
return vec_.back();
|
||||
}
|
||||
T& back();
|
||||
|
||||
const vector<T>& elements (void) const
|
||||
{
|
||||
return vec_;
|
||||
}
|
||||
const std::vector<T>& elements() const;
|
||||
|
||||
bool empty (void) const
|
||||
{
|
||||
return vec_.empty();
|
||||
}
|
||||
bool empty() const;
|
||||
|
||||
typename vector<T>::size_type size (void) const
|
||||
{
|
||||
return vec_.size();
|
||||
}
|
||||
typename std::vector<T>::size_type size() const;
|
||||
|
||||
void clear (void)
|
||||
{
|
||||
vec_.clear();
|
||||
}
|
||||
void clear();
|
||||
|
||||
void reserve (typename vector<T>::size_type size)
|
||||
{
|
||||
vec_.reserve (size);
|
||||
}
|
||||
void reserve (typename std::vector<T>::size_type size);
|
||||
|
||||
iterator begin (void) { return vec_.begin(); }
|
||||
iterator end (void) { return vec_.end(); }
|
||||
const_iterator begin (void) const { return vec_.begin(); }
|
||||
const_iterator end (void) const { return vec_.end(); }
|
||||
iterator begin() { return vec_.begin(); }
|
||||
iterator end () { return vec_.end(); }
|
||||
const_iterator begin() const { return vec_.begin(); }
|
||||
const_iterator end () const { return vec_.end(); }
|
||||
|
||||
private:
|
||||
iterator unique_cmp (iterator first, iterator last);
|
||||
|
||||
bool consistent() const;
|
||||
|
||||
friend bool operator== (const TinySet& s1, const TinySet& s2)
|
||||
{
|
||||
@ -223,7 +103,7 @@ class TinySet
|
||||
friend std::ostream& operator<< (std::ostream& out, const TinySet& s)
|
||||
{
|
||||
out << "{" ;
|
||||
typename vector<T>::size_type i;
|
||||
typename std::vector<T>::size_type i;
|
||||
for (i = 0; i < s.size(); i++) {
|
||||
out << ((i != 0) ? "," : "") << s.vec_[i];
|
||||
}
|
||||
@ -231,35 +111,299 @@ class TinySet
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
iterator unique_cmp (iterator first, iterator last)
|
||||
{
|
||||
if (first == last) {
|
||||
return last;
|
||||
}
|
||||
iterator result = first;
|
||||
while (++first != last) {
|
||||
if (cmp_(*result, *first)) {
|
||||
*(++result) = *first;
|
||||
}
|
||||
}
|
||||
return ++result;
|
||||
}
|
||||
|
||||
bool consistent (void) const
|
||||
{
|
||||
typename vector<T>::size_type i;
|
||||
for (i = 0; i < vec_.size() - 1; i++) {
|
||||
if ( ! cmp_(vec_[i], vec_[i + 1])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
vector<T> vec_;
|
||||
Compare cmp_;
|
||||
std::vector<T> vec_;
|
||||
Compare cmp_;
|
||||
};
|
||||
|
||||
#endif // HORUS_TINYSET_H
|
||||
|
||||
|
||||
template <typename T, typename C> inline
|
||||
TinySet<T,C>::TinySet (const std::vector<T>& elements, const C& cmp)
|
||||
: vec_(elements), cmp_(cmp)
|
||||
{
|
||||
std::sort (begin(), end(), cmp_);
|
||||
iterator it = unique_cmp (begin(), end());
|
||||
vec_.resize (it - begin());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline typename TinySet<T,C>::iterator
|
||||
TinySet<T,C>::insert (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
if (it == end() || cmp_(t, *it)) {
|
||||
vec_.insert (it, t);
|
||||
}
|
||||
return it;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline void
|
||||
TinySet<T,C>::insert_sorted (const T& t)
|
||||
{
|
||||
vec_.push_back (t);
|
||||
assert (consistent());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline void
|
||||
TinySet<T,C>::remove (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
if (it != end()) {
|
||||
vec_.erase (it);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline typename TinySet<T,C>::const_iterator
|
||||
TinySet<T,C>::find (const T& t) const
|
||||
{
|
||||
const_iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
return it == end() || cmp_(t, *it) ? end() : it;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline typename TinySet<T,C>::iterator
|
||||
TinySet<T,C>::find (const T& t)
|
||||
{
|
||||
iterator it = std::lower_bound (begin(), end(), t, cmp_);
|
||||
return it == end() || cmp_(t, *it) ? end() : it;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* set union */
|
||||
template <typename T, typename C> inline TinySet<T,C>
|
||||
TinySet<T,C>::operator| (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_union (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* set intersection */
|
||||
template <typename T, typename C> inline TinySet<T,C>
|
||||
TinySet<T,C>::operator& (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_intersection (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* set difference */
|
||||
template <typename T, typename C> inline TinySet<T,C>
|
||||
TinySet<T,C>::operator- (const TinySet& s) const
|
||||
{
|
||||
TinySet res;
|
||||
std::set_difference (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
std::back_inserter (res.vec_),
|
||||
cmp_);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline TinySet<T,C>&
|
||||
TinySet<T,C>::operator|= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this | s);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline TinySet<T,C>&
|
||||
TinySet<T,C>::operator&= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this & s);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline TinySet<T,C>&
|
||||
TinySet<T,C>::operator-= (const TinySet& s)
|
||||
{
|
||||
return *this = (*this - s);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::contains (const T& t) const
|
||||
{
|
||||
return std::binary_search (
|
||||
vec_.begin(), vec_.end(), t, cmp_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::contains (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
vec_.begin(), vec_.end(),
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
cmp_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::in (const TinySet& s) const
|
||||
{
|
||||
return std::includes (
|
||||
s.vec_.begin(), s.vec_.end(),
|
||||
vec_.begin(), vec_.end(),
|
||||
cmp_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::intersects (const TinySet& s) const
|
||||
{
|
||||
return (*this & s).size() > 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline const T&
|
||||
TinySet<T,C>::operator[] (typename std::vector<T>::size_type i) const
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T&
|
||||
TinySet<T,C>::operator[] (typename std::vector<T>::size_type i)
|
||||
{
|
||||
return vec_[i];
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T
|
||||
TinySet<T,C>::front() const
|
||||
{
|
||||
return vec_.front();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T&
|
||||
TinySet<T,C>::front()
|
||||
{
|
||||
return vec_.front();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T
|
||||
TinySet<T,C>::back() const
|
||||
{
|
||||
return vec_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline T&
|
||||
TinySet<T,C>::back()
|
||||
{
|
||||
return vec_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline const std::vector<T>&
|
||||
TinySet<T,C>::elements() const
|
||||
{
|
||||
return vec_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::empty() const
|
||||
{
|
||||
return vec_.empty();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline typename std::vector<T>::size_type
|
||||
TinySet<T,C>::size() const
|
||||
{
|
||||
return vec_.size();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline void
|
||||
TinySet<T,C>::clear()
|
||||
{
|
||||
vec_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline void
|
||||
TinySet<T,C>::reserve (typename std::vector<T>::size_type size)
|
||||
{
|
||||
vec_.reserve (size);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> typename TinySet<T,C>::iterator
|
||||
TinySet<T,C>::unique_cmp (iterator first, iterator last)
|
||||
{
|
||||
if (first == last) {
|
||||
return last;
|
||||
}
|
||||
iterator result = first;
|
||||
while (++first != last) {
|
||||
if (cmp_(*result, *first)) {
|
||||
*(++result) = *first;
|
||||
}
|
||||
}
|
||||
return ++result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T, typename C> inline bool
|
||||
TinySet<T,C>::consistent() const
|
||||
{
|
||||
typename std::vector<T>::size_type i;
|
||||
for (i = 0; i < vec_.size() - 1; i++) {
|
||||
if ( ! cmp_(vec_[i], vec_[i + 1])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_TINYSET_H_
|
||||
|
||||
|
@ -1,27 +1,27 @@
|
||||
#include <fstream>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Indexer.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "BeliefProp.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace Globals {
|
||||
|
||||
bool logDomain = false;
|
||||
|
||||
unsigned verbosity = 0;
|
||||
|
||||
LiftedSolverType liftedSolver = LiftedSolverType::LVE;
|
||||
LiftedSolverType liftedSolver = LiftedSolverType::lveSolver;
|
||||
|
||||
GroundSolverType groundSolver = GroundSolverType::VE;
|
||||
GroundSolverType groundSolver = GroundSolverType::veSolver;
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
|
||||
template <> std::string
|
||||
toString (const bool& b)
|
||||
{
|
||||
@ -33,14 +33,14 @@ toString (const bool& b)
|
||||
|
||||
|
||||
unsigned
|
||||
stringToUnsigned (string str)
|
||||
stringToUnsigned (std::string str)
|
||||
{
|
||||
int val;
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << str;
|
||||
ss >> val;
|
||||
if (val < 0) {
|
||||
cerr << "Error: the number readed is negative." << endl;
|
||||
std::cerr << "Error: the number readed is negative." << std::endl;
|
||||
exit (EXIT_FAILURE);
|
||||
}
|
||||
return static_cast<unsigned> (val);
|
||||
@ -49,10 +49,10 @@ stringToUnsigned (string str)
|
||||
|
||||
|
||||
double
|
||||
stringToDouble (string str)
|
||||
stringToDouble (std::string str)
|
||||
{
|
||||
double val;
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << str;
|
||||
ss >> val;
|
||||
return val;
|
||||
@ -117,7 +117,7 @@ size_t
|
||||
sizeExpected (const Ranges& ranges)
|
||||
{
|
||||
return std::accumulate (ranges.begin(),
|
||||
ranges.end(), 1, multiplies<unsigned>());
|
||||
ranges.end(), 1, std::multiplies<unsigned>());
|
||||
}
|
||||
|
||||
|
||||
@ -136,10 +136,10 @@ nrDigits (int num)
|
||||
|
||||
|
||||
bool
|
||||
isInteger (const string& s)
|
||||
isInteger (const std::string& s)
|
||||
{
|
||||
stringstream ss1 (s);
|
||||
stringstream ss2;
|
||||
std::stringstream ss1 (s);
|
||||
std::stringstream ss2;
|
||||
int integer;
|
||||
ss1 >> integer;
|
||||
ss2 << integer;
|
||||
@ -148,10 +148,10 @@ isInteger (const string& s)
|
||||
|
||||
|
||||
|
||||
string
|
||||
std::string
|
||||
parametersToString (const Params& v, unsigned precision)
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss.precision (precision);
|
||||
ss << "[" ;
|
||||
for (size_t i = 0; i < v.size(); i++) {
|
||||
@ -164,7 +164,7 @@ parametersToString (const Params& v, unsigned precision)
|
||||
|
||||
|
||||
|
||||
vector<string>
|
||||
std::vector<std::string>
|
||||
getStateLines (const Vars& vars)
|
||||
{
|
||||
Ranges ranges;
|
||||
@ -172,9 +172,9 @@ getStateLines (const Vars& vars)
|
||||
ranges.push_back (vars[i]->range());
|
||||
}
|
||||
Indexer indexer (ranges);
|
||||
vector<string> jointStrings;
|
||||
std::vector<std::string> jointStrings;
|
||||
while (indexer.valid()) {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; i < vars.size(); i++) {
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << vars[i]->label() << "=" ;
|
||||
@ -188,34 +188,42 @@ getStateLines (const Vars& vars)
|
||||
|
||||
|
||||
|
||||
bool invalidValue (string option, string value)
|
||||
bool invalidValue (std::string option, std::string value)
|
||||
{
|
||||
cerr << "Warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << option << "'." ;
|
||||
cerr << endl;
|
||||
std::cerr << "Warning: invalid value `" << value << "' " ;
|
||||
std::cerr << "for `" << option << "'." ;
|
||||
std::cerr << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
setHorusFlag (string option, string value)
|
||||
setHorusFlag (std::string option, std::string value)
|
||||
{
|
||||
bool returnVal = true;
|
||||
if (option == "lifted_solver") {
|
||||
if (value == "lve") Globals::liftedSolver = LiftedSolverType::LVE;
|
||||
else if (value == "lbp") Globals::liftedSolver = LiftedSolverType::LBP;
|
||||
else if (value == "lkc") Globals::liftedSolver = LiftedSolverType::LKC;
|
||||
else returnVal = invalidValue (option, value);
|
||||
if (value == "lve")
|
||||
Globals::liftedSolver = LiftedSolverType::lveSolver;
|
||||
else if (value == "lbp")
|
||||
Globals::liftedSolver = LiftedSolverType::lbpSolver;
|
||||
else if (value == "lkc")
|
||||
Globals::liftedSolver = LiftedSolverType::lkcSolver;
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "ground_solver" || option == "solver") {
|
||||
if (value == "hve") Globals::groundSolver = GroundSolverType::VE;
|
||||
else if (value == "bp") Globals::groundSolver = GroundSolverType::BP;
|
||||
else if (value == "cbp") Globals::groundSolver = GroundSolverType::CBP;
|
||||
else returnVal = invalidValue (option, value);
|
||||
if (value == "hve")
|
||||
Globals::groundSolver = GroundSolverType::veSolver;
|
||||
else if (value == "bp")
|
||||
Globals::groundSolver = GroundSolverType::bpSolver;
|
||||
else if (value == "cbp")
|
||||
Globals::groundSolver = GroundSolverType::CbpSolver;
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "verbosity") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << value;
|
||||
ss >> Globals::verbosity;
|
||||
|
||||
@ -225,40 +233,42 @@ setHorusFlag (string option, string value)
|
||||
else returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "hve_elim_heuristic") {
|
||||
typedef ElimGraph::ElimHeuristic ElimHeuristic;
|
||||
if (value == "sequential")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::SEQUENTIAL);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::sequentialEh);
|
||||
else if (value == "min_neighbors")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_NEIGHBORS);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::minNeighborsEh);
|
||||
else if (value == "min_weight")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_WEIGHT);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::minWeightEh);
|
||||
else if (value == "min_fill")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::MIN_FILL);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::minFillEh);
|
||||
else if (value == "weighted_min_fill")
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
|
||||
ElimGraph::setElimHeuristic (ElimHeuristic::weightedMinFillEh);
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "bp_msg_schedule") {
|
||||
typedef BeliefProp::MsgSchedule MsgSchedule;
|
||||
if (value == "seq_fixed")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_FIXED);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::seqFixedSch);
|
||||
else if (value == "seq_random")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::SEQ_RANDOM);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::seqRandomSch);
|
||||
else if (value == "parallel")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::PARALLEL);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::parallelSch);
|
||||
else if (value == "max_residual")
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::MAX_RESIDUAL);
|
||||
BeliefProp::setMsgSchedule (MsgSchedule::maxResidualSch);
|
||||
else
|
||||
returnVal = invalidValue (option, value);
|
||||
|
||||
} else if (option == "bp_accuracy") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
double acc;
|
||||
ss << value;
|
||||
ss >> acc;
|
||||
BeliefProp::setAccuracy (acc);
|
||||
|
||||
} else if (option == "bp_max_iter") {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
unsigned mi;
|
||||
ss << value;
|
||||
ss >> mi;
|
||||
@ -285,7 +295,7 @@ setHorusFlag (string option, string value)
|
||||
else returnVal = invalidValue (option, value);
|
||||
|
||||
} else {
|
||||
cerr << "Warning: invalid option `" << option << "'" << endl;
|
||||
std::cerr << "Warning: invalid option `" << option << "'" << std::endl;
|
||||
returnVal = false;
|
||||
}
|
||||
return returnVal;
|
||||
@ -294,20 +304,20 @@ setHorusFlag (string option, string value)
|
||||
|
||||
|
||||
void
|
||||
printHeader (string header, std::ostream& os)
|
||||
printHeader (std::string header, std::ostream& os)
|
||||
{
|
||||
printAsteriskLine (os);
|
||||
os << header << endl;
|
||||
os << header << std::endl;
|
||||
printAsteriskLine (os);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
printSubHeader (string header, std::ostream& os)
|
||||
printSubHeader (std::string header, std::ostream& os)
|
||||
{
|
||||
printDashedLine (os);
|
||||
os << header << endl;
|
||||
os << header << std::endl;
|
||||
printDashedLine (os);
|
||||
}
|
||||
|
||||
@ -318,7 +328,7 @@ printAsteriskLine (std::ostream& os)
|
||||
{
|
||||
os << "********************************" ;
|
||||
os << "********************************" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@ -328,11 +338,10 @@ printDashedLine (std::ostream& os)
|
||||
{
|
||||
os << "--------------------------------" ;
|
||||
os << "--------------------------------" ;
|
||||
os << endl;
|
||||
os << std::endl;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
} // namespace Util
|
||||
|
||||
|
||||
|
||||
@ -362,10 +371,10 @@ getL1Distance (const Params& v1, const Params& v2)
|
||||
double dist = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||
std::plus<double>(), FuncObject::abs_diff_exp<double>());
|
||||
std::plus<double>(), FuncObj::abs_diff_exp<double>());
|
||||
} else {
|
||||
dist = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||
std::plus<double>(), FuncObject::abs_diff<double>());
|
||||
std::plus<double>(), FuncObj::abs_diff<double>());
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
@ -379,10 +388,10 @@ getMaxNorm (const Params& v1, const Params& v2)
|
||||
double max = 0.0;
|
||||
if (Globals::logDomain) {
|
||||
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||
FuncObject::max<double>(), FuncObject::abs_diff_exp<double>());
|
||||
FuncObj::max<double>(), FuncObj::abs_diff_exp<double>());
|
||||
} else {
|
||||
max = std::inner_product (v1.begin(), v1.end(), v2.begin(), 0.0,
|
||||
FuncObject::max<double>(), FuncObject::abs_diff<double>());
|
||||
FuncObj::max<double>(), FuncObj::abs_diff<double>());
|
||||
}
|
||||
return max;
|
||||
}
|
||||
@ -428,5 +437,8 @@ pow (Params& v, double exp)
|
||||
Globals::logDomain ? v *= exp : v ^= exp;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace LogAware
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
||||
|
@ -1,69 +1,78 @@
|
||||
#ifndef HORUS_UTIL_H
|
||||
#define HORUS_UTIL_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
|
||||
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "Horus.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Horus {
|
||||
|
||||
namespace {
|
||||
|
||||
const double NEG_INF = -std::numeric_limits<double>::infinity();
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||
template <typename T> void
|
||||
addToVector (std::vector<T>&, const std::vector<T>&);
|
||||
|
||||
template <typename T> void addToSet (set<T>&, const vector<T>&);
|
||||
template <typename T> void
|
||||
addToSet (std::set<T>&, const std::vector<T>&);
|
||||
|
||||
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||
template <typename T> void
|
||||
addToQueue (std::queue<T>&, const std::vector<T>&);
|
||||
|
||||
template <typename T> bool contains (const vector<T>&, const T&);
|
||||
template <typename T> bool
|
||||
contains (const std::vector<T>&, const T&);
|
||||
|
||||
template <typename T> bool contains (const set<T>&, const T&);
|
||||
template <typename T> bool contains
|
||||
(const std::set<T>&, const T&);
|
||||
|
||||
template <typename K, typename V> bool contains (
|
||||
const unordered_map<K, V>&, const K&);
|
||||
template <typename K, typename V> bool
|
||||
contains (const std::unordered_map<K, V>&, const K&);
|
||||
|
||||
template <typename T> size_t indexOf (const vector<T>&, const T&);
|
||||
template <typename T> size_t
|
||||
indexOf (const std::vector<T>&, const T&);
|
||||
|
||||
template <class Operation>
|
||||
void apply_n_times (Params& v1, const Params& v2,
|
||||
unsigned repetitions, Operation);
|
||||
template <class Operation> void
|
||||
apply_n_times (Params& v1, const Params& v2, unsigned reps, Operation);
|
||||
|
||||
template <typename T> void log (vector<T>&);
|
||||
template <typename T> void
|
||||
log (std::vector<T>&);
|
||||
|
||||
template <typename T> void exp (vector<T>&);
|
||||
template <typename T> void
|
||||
exp (std::vector<T>&);
|
||||
|
||||
template <typename T> string elementsToString (
|
||||
const vector<T>& v, string sep = " ");
|
||||
template <typename T> std::string
|
||||
elementsToString (const std::vector<T>& v, std::string sep = " ");
|
||||
|
||||
template <typename T> std::string toString (const T&);
|
||||
template <typename T> std::string
|
||||
toString (const T&);
|
||||
|
||||
template <> std::string toString (const bool&);
|
||||
template <> std::string
|
||||
toString (const bool&);
|
||||
|
||||
double logSum (double, double);
|
||||
|
||||
unsigned maxUnsigned (void);
|
||||
unsigned maxUnsigned();
|
||||
|
||||
unsigned stringToUnsigned (string);
|
||||
unsigned stringToUnsigned (std::string);
|
||||
|
||||
double stringToDouble (string);
|
||||
double stringToDouble (std::string);
|
||||
|
||||
double factorial (unsigned);
|
||||
|
||||
@ -75,28 +84,29 @@ size_t sizeExpected (const Ranges&);
|
||||
|
||||
unsigned nrDigits (int);
|
||||
|
||||
bool isInteger (const string&);
|
||||
bool isInteger (const std::string&);
|
||||
|
||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||
std::string parametersToString (
|
||||
const Params&, unsigned = Constants::precision);
|
||||
|
||||
vector<string> getStateLines (const Vars&);
|
||||
std::vector<std::string> getStateLines (const Vars&);
|
||||
|
||||
bool setHorusFlag (string option, string value);
|
||||
bool setHorusFlag (std::string option, std::string value);
|
||||
|
||||
void printHeader (string, std::ostream& os = std::cout);
|
||||
void printHeader (std::string, std::ostream& os = std::cout);
|
||||
|
||||
void printSubHeader (string, std::ostream& os = std::cout);
|
||||
void printSubHeader (std::string, std::ostream& os = std::cout);
|
||||
|
||||
void printAsteriskLine (std::ostream& os = std::cout);
|
||||
|
||||
void printDashedLine (std::ostream& os = std::cout);
|
||||
|
||||
};
|
||||
} // namespace Util
|
||||
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||
Util::addToVector (std::vector<T>& v, const std::vector<T>& elements)
|
||||
{
|
||||
v.insert (v.end(), elements.begin(), elements.end());
|
||||
}
|
||||
@ -104,7 +114,7 @@ Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||
Util::addToSet (std::set<T>& s, const std::vector<T>& elements)
|
||||
{
|
||||
s.insert (elements.begin(), elements.end());
|
||||
}
|
||||
@ -112,7 +122,7 @@ Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||
Util::addToQueue (std::queue<T>& q, const std::vector<T>& elements)
|
||||
{
|
||||
for (size_t i = 0; i < elements.size(); i++) {
|
||||
q.push (elements[i]);
|
||||
@ -122,7 +132,7 @@ Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const vector<T>& v, const T& e)
|
||||
Util::contains (const std::vector<T>& v, const T& e)
|
||||
{
|
||||
return std::find (v.begin(), v.end(), e) != v.end();
|
||||
}
|
||||
@ -130,7 +140,7 @@ Util::contains (const vector<T>& v, const T& e)
|
||||
|
||||
|
||||
template <typename T> bool
|
||||
Util::contains (const set<T>& s, const T& e)
|
||||
Util::contains (const std::set<T>& s, const T& e)
|
||||
{
|
||||
return s.find (e) != s.end();
|
||||
}
|
||||
@ -138,7 +148,7 @@ Util::contains (const set<T>& s, const T& e)
|
||||
|
||||
|
||||
template <typename K, typename V> bool
|
||||
Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||
Util::contains (const std::unordered_map<K, V>& m, const K& k)
|
||||
{
|
||||
return m.find (k) != m.end();
|
||||
}
|
||||
@ -146,7 +156,7 @@ Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||
|
||||
|
||||
template <typename T> size_t
|
||||
Util::indexOf (const vector<T>& v, const T& e)
|
||||
Util::indexOf (const std::vector<T>& v, const T& e)
|
||||
{
|
||||
return std::distance (v.begin(),
|
||||
std::find (v.begin(), v.end(), e));
|
||||
@ -155,7 +165,10 @@ Util::indexOf (const vector<T>& v, const T& e)
|
||||
|
||||
|
||||
template <class Operation> void
|
||||
Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
|
||||
Util::apply_n_times (
|
||||
Params& v1,
|
||||
const Params& v2,
|
||||
unsigned repetitions,
|
||||
Operation unary_op)
|
||||
{
|
||||
Params::iterator first = v1.begin();
|
||||
@ -174,7 +187,7 @@ Util::apply_n_times (Params& v1, const Params& v2, unsigned repetitions,
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::log (vector<T>& v)
|
||||
Util::log (std::vector<T>& v)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(), ::log);
|
||||
}
|
||||
@ -182,17 +195,17 @@ Util::log (vector<T>& v)
|
||||
|
||||
|
||||
template <typename T> void
|
||||
Util::exp (vector<T>& v)
|
||||
Util::exp (std::vector<T>& v)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(), ::exp);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T> string
|
||||
Util::elementsToString (const vector<T>& v, string sep)
|
||||
template <typename T> std::string
|
||||
Util::elementsToString (const std::vector<T>& v, std::string sep)
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; i < v.size(); i++) {
|
||||
ss << ((i != 0) ? sep : "") << v[i];
|
||||
}
|
||||
@ -245,7 +258,7 @@ Util::logSum (double x, double y)
|
||||
|
||||
|
||||
inline unsigned
|
||||
Util::maxUnsigned (void)
|
||||
Util::maxUnsigned()
|
||||
{
|
||||
return std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
@ -277,106 +290,106 @@ void pow (Params&, unsigned);
|
||||
|
||||
void pow (Params&, double);
|
||||
|
||||
};
|
||||
} // namespace LogAware
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& v, double val)
|
||||
template <typename T> void
|
||||
operator+=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (plus<double>(), val));
|
||||
std::bind2nd (std::plus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& v, double val)
|
||||
template <typename T> void
|
||||
operator-=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (minus<double>(), val));
|
||||
std::bind2nd (std::minus<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& v, double val)
|
||||
template <typename T> void
|
||||
operator*=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (multiplies<double>(), val));
|
||||
std::bind2nd (std::multiplies<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& v, double val)
|
||||
template <typename T> void
|
||||
operator/=(std::vector<T>& v, double val)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (divides<double>(), val));
|
||||
std::bind2nd (std::divides<double>(), val));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator+=(std::vector<T>& a, const std::vector<T>& b)
|
||||
template <typename T> void
|
||||
operator+=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
plus<double>());
|
||||
std::plus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator-=(std::vector<T>& a, const std::vector<T>& b)
|
||||
template <typename T> void
|
||||
operator-=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
minus<double>());
|
||||
std::minus<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator*=(std::vector<T>& a, const std::vector<T>& b)
|
||||
template <typename T> void
|
||||
operator*=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
multiplies<double>());
|
||||
std::multiplies<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator/=(std::vector<T>& a, const std::vector<T>& b)
|
||||
template <typename T> void
|
||||
operator/=(std::vector<T>& a, const std::vector<T>& b)
|
||||
{
|
||||
assert (a.size() == b.size());
|
||||
std::transform (a.begin(), a.end(), b.begin(), a.begin(),
|
||||
divides<double>());
|
||||
std::divides<double>());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, double exp)
|
||||
template <typename T> void
|
||||
operator^=(std::vector<T>& v, double exp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, double, double> (std::pow), exp));
|
||||
std::bind2nd (std::ptr_fun<double, double, double> (std::pow), exp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void operator^=(std::vector<T>& v, int iexp)
|
||||
template <typename T> void
|
||||
operator^=(std::vector<T>& v, int iexp)
|
||||
{
|
||||
std::transform (v.begin(), v.end(), v.begin(),
|
||||
std::bind2nd (ptr_fun<double, int, double> (std::pow), iexp));
|
||||
std::bind2nd (std::ptr_fun<double, int, double> (std::pow), iexp));
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<< (std::ostream& os, const vector<T>& v)
|
||||
template <typename T> std::ostream&
|
||||
operator<< (std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
os << "[" ;
|
||||
os << Util::elementsToString (v, ", ");
|
||||
@ -385,40 +398,33 @@ std::ostream& operator<< (std::ostream& os, const vector<T>& v)
|
||||
}
|
||||
|
||||
|
||||
namespace FuncObject {
|
||||
namespace FuncObj {
|
||||
|
||||
template<typename T>
|
||||
struct max : public std::binary_function<T, T, T>
|
||||
{
|
||||
T operator() (const T& x, const T& y) const
|
||||
{
|
||||
struct max : public std::binary_function<T, T, T> {
|
||||
T operator() (const T& x, const T& y) const {
|
||||
return x < y ? y : x;
|
||||
}
|
||||
};
|
||||
}};
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct abs_diff : public std::binary_function<T, T, T>
|
||||
{
|
||||
T operator() (const T& x, const T& y) const
|
||||
{
|
||||
struct abs_diff : public std::binary_function<T, T, T> {
|
||||
T operator() (const T& x, const T& y) const {
|
||||
return std::abs (x - y);
|
||||
}
|
||||
};
|
||||
}};
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct abs_diff_exp : public std::binary_function<T, T, T>
|
||||
{
|
||||
T operator() (const T& x, const T& y) const
|
||||
{
|
||||
struct abs_diff_exp : public std::binary_function<T, T, T> {
|
||||
T operator() (const T& x, const T& y) const {
|
||||
return std::abs (std::exp (x) - std::exp (y));
|
||||
}
|
||||
};
|
||||
}};
|
||||
|
||||
}
|
||||
} // namespace FuncObj
|
||||
|
||||
#endif // HORUS_UTIL_H
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_UTIL_H_
|
||||
|
||||
|
@ -3,7 +3,9 @@
|
||||
#include "Var.h"
|
||||
|
||||
|
||||
unordered_map<VarId, VarInfo> Var::varsInfo_;
|
||||
namespace Horus {
|
||||
|
||||
std::unordered_map<VarId, Var::VarInfo> Var::varsInfo_;
|
||||
|
||||
|
||||
Var::Var (const Var* v)
|
||||
@ -45,13 +47,14 @@ Var::setEvidence (int evidence)
|
||||
|
||||
|
||||
|
||||
string
|
||||
Var::label (void) const
|
||||
std::string
|
||||
Var::label() const
|
||||
{
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).label;
|
||||
assert (Util::contains (varsInfo_, varId_));
|
||||
return varsInfo_.find (varId_)->second.first;
|
||||
}
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
return ss.str();
|
||||
}
|
||||
@ -59,17 +62,46 @@ Var::label (void) const
|
||||
|
||||
|
||||
States
|
||||
Var::states (void) const
|
||||
Var::states() const
|
||||
{
|
||||
if (Var::varsHaveInfo()) {
|
||||
return Var::getVarInfo (varId_).states;
|
||||
assert (Util::contains (varsInfo_, varId_));
|
||||
return varsInfo_.find (varId_)->second.second;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < range_; i++) {
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << i ;
|
||||
states.push_back (ss.str());
|
||||
}
|
||||
return states;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::addVarInfo (
|
||||
VarId vid, std::string label, const States& states)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (std::make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
Var::varsHaveInfo()
|
||||
{
|
||||
return varsInfo_.empty() == false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Var::clearVarsInfo()
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
@ -1,102 +1,105 @@
|
||||
#ifndef HORUS_VAR_H
|
||||
#define HORUS_VAR_H
|
||||
#ifndef YAP_PACKAGES_CLPBN_HORUS_VAR_H_
|
||||
#define YAP_PACKAGES_CLPBN_HORUS_VAR_H_
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
|
||||
#include "Util.h"
|
||||
#include "Horus.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
namespace Horus {
|
||||
|
||||
|
||||
struct VarInfo
|
||||
{
|
||||
VarInfo (string l, const States& sts)
|
||||
: label(l), states(sts) { }
|
||||
string label;
|
||||
States states;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class Var
|
||||
{
|
||||
class Var {
|
||||
public:
|
||||
Var (const Var*);
|
||||
|
||||
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||
Var (VarId, unsigned range, int evidence = Constants::unobserved);
|
||||
|
||||
virtual ~Var (void) { };
|
||||
virtual ~Var() { };
|
||||
|
||||
VarId varId (void) const { return varId_; }
|
||||
VarId varId() const { return varId_; }
|
||||
|
||||
unsigned range (void) const { return range_; }
|
||||
unsigned range() const { return range_; }
|
||||
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
int getEvidence() const { return evidence_; }
|
||||
|
||||
size_t getIndex (void) const { return index_; }
|
||||
size_t getIndex() const { return index_; }
|
||||
|
||||
void setIndex (size_t idx) { index_ = idx; }
|
||||
|
||||
bool hasEvidence (void) const
|
||||
{
|
||||
return evidence_ != Constants::NO_EVIDENCE;
|
||||
}
|
||||
bool hasEvidence() const;
|
||||
|
||||
operator size_t (void) const { return index_; }
|
||||
operator size_t() const;
|
||||
|
||||
bool operator== (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
bool operator== (const Var& var) const;
|
||||
|
||||
bool operator!= (const Var& var) const
|
||||
{
|
||||
return !(*this == var);
|
||||
}
|
||||
bool operator!= (const Var& var) const;
|
||||
|
||||
bool isValidState (int);
|
||||
|
||||
void setEvidence (int);
|
||||
|
||||
string label (void) const;
|
||||
std::string label() const;
|
||||
|
||||
States states (void) const;
|
||||
States states() const;
|
||||
|
||||
static void addVarInfo (
|
||||
VarId vid, string label, const States& states)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid) == false);
|
||||
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||
}
|
||||
VarId vid, std::string label, const States& states);
|
||||
|
||||
static VarInfo getVarInfo (VarId vid)
|
||||
{
|
||||
assert (Util::contains (varsInfo_, vid));
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
static bool varsHaveInfo();
|
||||
|
||||
static bool varsHaveInfo (void)
|
||||
{
|
||||
return varsInfo_.empty() == false;
|
||||
}
|
||||
|
||||
static void clearVarsInfo (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
static void clearVarsInfo();
|
||||
|
||||
private:
|
||||
typedef std::pair<std::string, States> VarInfo;
|
||||
|
||||
VarId varId_;
|
||||
unsigned range_;
|
||||
int evidence_;
|
||||
size_t index_;
|
||||
|
||||
static unordered_map<VarId, VarInfo> varsInfo_;
|
||||
static std::unordered_map<VarId, VarInfo> varsInfo_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(Var);
|
||||
};
|
||||
|
||||
#endif // HORUS_VAR_H
|
||||
|
||||
|
||||
inline bool
|
||||
Var::hasEvidence() const
|
||||
{
|
||||
return evidence_ != Constants::unobserved;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline
|
||||
Var::operator size_t() const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
Var::operator== (const Var& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline bool
|
||||
Var::operator!= (const Var& var) const
|
||||
{
|
||||
return !(*this == var);
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
#endif // YAP_PACKAGES_CLPBN_HORUS_VAR_H_
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "VarElim.h"
|
||||
#include "ElimGraph.h"
|
||||
@ -6,16 +8,18 @@
|
||||
#include "Util.h"
|
||||
|
||||
|
||||
namespace Horus {
|
||||
|
||||
Params
|
||||
VarElim::solveQuery (VarIds queryVids)
|
||||
{
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "Solving query on " ;
|
||||
std::cout << "Solving query on " ;
|
||||
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||
if (i != 0) cout << ", " ;
|
||||
cout << fg.getVarNode (queryVids[i])->label();
|
||||
if (i != 0) std::cout << ", " ;
|
||||
std::cout << fg.getVarNode (queryVids[i])->label();
|
||||
}
|
||||
cout << endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
totalFactorSize_ = 0;
|
||||
largestFactorSize_ = 0;
|
||||
@ -33,27 +37,28 @@ VarElim::solveQuery (VarIds queryVids)
|
||||
|
||||
|
||||
void
|
||||
VarElim::printSolverFlags (void) const
|
||||
VarElim::printSolverFlags() const
|
||||
{
|
||||
stringstream ss;
|
||||
std::stringstream ss;
|
||||
ss << "variable elimination [" ;
|
||||
ss << "elim_heuristic=" ;
|
||||
typedef ElimGraph::ElimHeuristic ElimHeuristic;
|
||||
switch (ElimGraph::elimHeuristic()) {
|
||||
case ElimHeuristic::SEQUENTIAL: ss << "sequential"; break;
|
||||
case ElimHeuristic::MIN_NEIGHBORS: ss << "min_neighbors"; break;
|
||||
case ElimHeuristic::MIN_WEIGHT: ss << "min_weight"; break;
|
||||
case ElimHeuristic::MIN_FILL: ss << "min_fill"; break;
|
||||
case ElimHeuristic::WEIGHTED_MIN_FILL: ss << "weighted_min_fill"; break;
|
||||
case ElimHeuristic::sequentialEh: ss << "sequential"; break;
|
||||
case ElimHeuristic::minNeighborsEh: ss << "min_neighbors"; break;
|
||||
case ElimHeuristic::minWeightEh: ss << "min_weight"; break;
|
||||
case ElimHeuristic::minFillEh: ss << "min_fill"; break;
|
||||
case ElimHeuristic::weightedMinFillEh: ss << "weighted_min_fill"; break;
|
||||
}
|
||||
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||
ss << "]" ;
|
||||
cout << ss.str() << endl;
|
||||
std::cout << ss.str() << std::endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElim::createFactorList (void)
|
||||
VarElim::createFactorList()
|
||||
{
|
||||
const FacNodes& facNodes = fg.facNodes();
|
||||
factorList_.reserve (facNodes.size() * 2);
|
||||
@ -61,7 +66,7 @@ VarElim::createFactorList (void)
|
||||
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||
const VarIds& args = facNodes[i]->factor().arguments();
|
||||
for (size_t j = 0; j < args.size(); j++) {
|
||||
unordered_map<VarId, vector<size_t>>::iterator it;
|
||||
std::unordered_map<VarId, std::vector<size_t>>::iterator it;
|
||||
it = varMap_.find (args[j]);
|
||||
if (it != varMap_.end()) {
|
||||
it->second.push_back (i);
|
||||
@ -75,22 +80,22 @@ VarElim::createFactorList (void)
|
||||
|
||||
|
||||
void
|
||||
VarElim::absorveEvidence (void)
|
||||
VarElim::absorveEvidence()
|
||||
{
|
||||
if (Globals::verbosity > 2) {
|
||||
Util::printDashedLine();
|
||||
cout << "(initial factor list)" << endl;
|
||||
std::cout << "(initial factor list)" << std::endl;
|
||||
printActiveFactors();
|
||||
}
|
||||
const VarNodes& varNodes = fg.varNodes();
|
||||
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
if (Globals::verbosity > 1) {
|
||||
cout << "-> aborving evidence on ";
|
||||
cout << varNodes[i]->label() << " = " ;
|
||||
cout << varNodes[i]->getEvidence() << endl;
|
||||
std::cout << "-> aborving evidence on ";
|
||||
std::cout << varNodes[i]->label() << " = " ;
|
||||
std::cout << varNodes[i]->getEvidence() << std::endl;
|
||||
}
|
||||
const vector<size_t>& indices = varMap_[varNodes[i]->varId()];
|
||||
const std::vector<size_t>& indices = varMap_[varNodes[i]->varId()];
|
||||
for (size_t j = 0; j < indices.size(); j++) {
|
||||
size_t idx = indices[j];
|
||||
if (factorList_[idx]->nrArguments() > 1) {
|
||||
@ -118,8 +123,8 @@ VarElim::processFactorList (const VarIds& queryVids)
|
||||
Util::printDashedLine();
|
||||
printActiveFactors();
|
||||
}
|
||||
cout << "-> summing out " ;
|
||||
cout << fg.getVarNode (elimOrder[i])->label() << endl;
|
||||
std::cout << "-> summing out " ;
|
||||
std::cout << fg.getVarNode (elimOrder[i])->label() << std::endl;
|
||||
}
|
||||
eliminate (elimOrder[i]);
|
||||
}
|
||||
@ -143,9 +148,9 @@ VarElim::processFactorList (const VarIds& queryVids)
|
||||
result.reorderArguments (unobservedVids);
|
||||
result.normalize();
|
||||
if (Globals::verbosity > 0) {
|
||||
cout << "total factor size: " << totalFactorSize_ << endl;
|
||||
cout << "largest factor size: " << largestFactorSize_ << endl;
|
||||
cout << endl;
|
||||
std::cout << "total factor size: " << totalFactorSize_ << std::endl;
|
||||
std::cout << "largest factor size: " << largestFactorSize_ << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
return result.params();
|
||||
}
|
||||
@ -156,7 +161,7 @@ void
|
||||
VarElim::eliminate (VarId vid)
|
||||
{
|
||||
Factor* result = new Factor();
|
||||
const vector<size_t>& indices = varMap_[vid];
|
||||
const std::vector<size_t>& indices = varMap_[vid];
|
||||
for (size_t i = 0; i < indices.size(); i++) {
|
||||
size_t idx = indices[i];
|
||||
if (factorList_[idx]) {
|
||||
@ -173,7 +178,7 @@ VarElim::eliminate (VarId vid)
|
||||
result->sumOut (vid);
|
||||
const VarIds& args = result->arguments();
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
vector<size_t>& indices2 = varMap_[args[i]];
|
||||
std::vector<size_t>& indices2 = varMap_[args[i]];
|
||||
indices2.push_back (factorList_.size());
|
||||
}
|
||||
factorList_.push_back (result);
|
||||
@ -185,14 +190,16 @@ VarElim::eliminate (VarId vid)
|
||||
|
||||
|
||||
void
|
||||
VarElim::printActiveFactors (void)
|
||||
VarElim::printActiveFactors()
|
||||
{
|
||||
for (size_t i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i]) {
|
||||
cout << factorList_[i]->getLabel() << " " ;
|
||||
cout << factorList_[i]->params();
|
||||
cout << endl;
|
||||
std::cout << factorList_[i]->getLabel() << " " ;
|
||||
std::cout << factorList_[i]->params();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Horus
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user