Merge branch 'master' of https://github.com/tacgomes/yap6.3
Conflicts: packages/CLPBN/clpbn/horus.yap
This commit is contained in:
commit
3669cb894f
28
C/index.c
28
C/index.c
@ -3298,6 +3298,28 @@ code_to_indexcl(yamop *ipc, int is_lu)
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
increase_expand_depth(yamop *ipc, struct intermediates *cint)
|
||||||
|
{
|
||||||
|
yamop *ncode;
|
||||||
|
|
||||||
|
cint->term_depth++;
|
||||||
|
if (/* ipc->opc == Yap_opcode(_switch_on_sub_arg_type) && */
|
||||||
|
(ncode = ipc->u.sllll.l4)->opc == Yap_opcode(_expand_clauses)) {
|
||||||
|
if (ncode->u.sssllp.s2 != cint->last_depth_size) {
|
||||||
|
cint->last_index_new_depth = cint->term_depth;
|
||||||
|
cint->last_depth_size = ncode->u.sssllp.s2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
zero_expand_depth(PredEntry *ap, struct intermediates *cint)
|
||||||
|
{
|
||||||
|
cint->term_depth = cint->last_index_new_depth;
|
||||||
|
cint->last_depth_size = ap->cs.p_code.NOfClauses;
|
||||||
|
}
|
||||||
|
|
||||||
static yamop **
|
static yamop **
|
||||||
expand_index(struct intermediates *cint) {
|
expand_index(struct intermediates *cint) {
|
||||||
CACHE_REGS
|
CACHE_REGS
|
||||||
@ -3499,6 +3521,7 @@ expand_index(struct intermediates *cint) {
|
|||||||
break;
|
break;
|
||||||
/* instructions type e */
|
/* instructions type e */
|
||||||
case _switch_on_type:
|
case _switch_on_type:
|
||||||
|
zero_expand_depth(ap, cint);
|
||||||
t = Deref(ARG1);
|
t = Deref(ARG1);
|
||||||
argno = 1;
|
argno = 1;
|
||||||
i = 0;
|
i = 0;
|
||||||
@ -3520,6 +3543,7 @@ expand_index(struct intermediates *cint) {
|
|||||||
parentcl = index_jmp(parentcl, parentcl, ipc, is_lu, e_code);
|
parentcl = index_jmp(parentcl, parentcl, ipc, is_lu, e_code);
|
||||||
break;
|
break;
|
||||||
case _switch_list_nl:
|
case _switch_list_nl:
|
||||||
|
zero_expand_depth(ap, cint);
|
||||||
t = Deref(ARG1);
|
t = Deref(ARG1);
|
||||||
argno = 1;
|
argno = 1;
|
||||||
i = 0;
|
i = 0;
|
||||||
@ -3548,6 +3572,7 @@ expand_index(struct intermediates *cint) {
|
|||||||
parentcl = index_jmp(parentcl, parentcl, ipc, is_lu, e_code);
|
parentcl = index_jmp(parentcl, parentcl, ipc, is_lu, e_code);
|
||||||
break;
|
break;
|
||||||
case _switch_on_arg_type:
|
case _switch_on_arg_type:
|
||||||
|
zero_expand_depth(ap, cint);
|
||||||
argno = arg_from_x(ipc->u.xllll.x);
|
argno = arg_from_x(ipc->u.xllll.x);
|
||||||
i = 0;
|
i = 0;
|
||||||
t = Deref(XREGS[argno]);
|
t = Deref(XREGS[argno]);
|
||||||
@ -3578,12 +3603,14 @@ expand_index(struct intermediates *cint) {
|
|||||||
ipc = ipc->u.sllll.l4;
|
ipc = ipc->u.sllll.l4;
|
||||||
i++;
|
i++;
|
||||||
} else if (IsPairTerm(t)) {
|
} else if (IsPairTerm(t)) {
|
||||||
|
increase_expand_depth(ipc, cint);
|
||||||
s_reg = RepPair(t);
|
s_reg = RepPair(t);
|
||||||
sp = push_stack(sp, -i-1, AbsPair(NULL), TermNil, cint);
|
sp = push_stack(sp, -i-1, AbsPair(NULL), TermNil, cint);
|
||||||
labp = &(ipc->u.sllll.l1);
|
labp = &(ipc->u.sllll.l1);
|
||||||
ipc = ipc->u.sllll.l1;
|
ipc = ipc->u.sllll.l1;
|
||||||
i = 0;
|
i = 0;
|
||||||
} else if (IsApplTerm(t)) {
|
} else if (IsApplTerm(t)) {
|
||||||
|
increase_expand_depth(ipc, cint);
|
||||||
sp = push_stack(sp, -i-1, AbsAppl((CELL *)FunctorOfTerm(t)), TermNil, cint);
|
sp = push_stack(sp, -i-1, AbsAppl((CELL *)FunctorOfTerm(t)), TermNil, cint);
|
||||||
ipc = ipc->u.sllll.l3;
|
ipc = ipc->u.sllll.l3;
|
||||||
i = 0;
|
i = 0;
|
||||||
@ -3591,6 +3618,7 @@ expand_index(struct intermediates *cint) {
|
|||||||
/* We don't push stack here, instead we go over to next argument
|
/* We don't push stack here, instead we go over to next argument
|
||||||
sp = push_stack(sp, -i-1, t, cint);
|
sp = push_stack(sp, -i-1, t, cint);
|
||||||
*/
|
*/
|
||||||
|
increase_expand_depth(ipc, cint);
|
||||||
sp = push_stack(sp, -i-1, t, TermNil, cint);
|
sp = push_stack(sp, -i-1, t, TermNil, cint);
|
||||||
ipc = ipc->u.sllll.l2;
|
ipc = ipc->u.sllll.l2;
|
||||||
i++;
|
i++;
|
||||||
|
@ -706,7 +706,7 @@ all: startup.yss
|
|||||||
@ENABLE_SEMWEB@ @INSTALL_DLLS@ (cd packages/semweb; $(MAKE))
|
@ENABLE_SEMWEB@ @INSTALL_DLLS@ (cd packages/semweb; $(MAKE))
|
||||||
@ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE))
|
@ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE))
|
||||||
@ENABLE_REAL@ (cd packages/real; $(MAKE))
|
@ENABLE_REAL@ (cd packages/real; $(MAKE))
|
||||||
@ENABLE_CLPBN_BP@ (cd packages/CLPBN/clpbn/bp ; $(MAKE))
|
@ENABLE_CLPBN_BP@ (cd packages/CLPBN/horus; $(MAKE))
|
||||||
@ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE))
|
@ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE))
|
||||||
@ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE))
|
@ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE))
|
||||||
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE))
|
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE))
|
||||||
@ -775,7 +775,7 @@ install_unix: startup.yss libYap.a
|
|||||||
@ENABLE_SEMWEB@ @INSTALL_DLLS@ (cd packages/semweb; $(MAKE) install)
|
@ENABLE_SEMWEB@ @INSTALL_DLLS@ (cd packages/semweb; $(MAKE) install)
|
||||||
@ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE) install)
|
@ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE) install)
|
||||||
@ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE) @ZLIB_INSTALL@)
|
@ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE) @ZLIB_INSTALL@)
|
||||||
@ENABLE_CLPBN_BP@ @INSTALL_DLLS@ (cd packages/CLPBN/clpbn/bp ; $(MAKE) install)
|
@ENABLE_CLPBN_BP@ @INSTALL_DLLS@ (cd packages/CLPBN/horus; $(MAKE) install)
|
||||||
@ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE) install)
|
@ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE) install)
|
||||||
@INSTALL_MATLAB@ (cd library/matlab; $(MAKE) install)
|
@INSTALL_MATLAB@ (cd library/matlab; $(MAKE) install)
|
||||||
@ENABLE_REAL@ (cd packages/real; $(MAKE) install)
|
@ENABLE_REAL@ (cd packages/real; $(MAKE) install)
|
||||||
@ -839,7 +839,7 @@ install_win32: startup.yss @ENABLE_WINCONSOLE@ pl-yap@EXEC_SUFFIX@
|
|||||||
@ENABLE_SGML@ (cd packages/sgml; $(MAKE) install)
|
@ENABLE_SGML@ (cd packages/sgml; $(MAKE) install)
|
||||||
@ENABLE_ZLIB@ (cd packages/zlib; $(MAKE) @ZLIB_INSTALL@)
|
@ENABLE_ZLIB@ (cd packages/zlib; $(MAKE) @ZLIB_INSTALL@)
|
||||||
(cd packages/CLPBN ; $(MAKE) install)
|
(cd packages/CLPBN ; $(MAKE) install)
|
||||||
@ENABLE_CLPBN_BP@ (cd packages/CLPBN/clpbn/bp ; $(MAKE) install)
|
@ENABLE_CLPBN_BP@ (cd packages/CLPBN/horus; $(MAKE) install)
|
||||||
@ENABLE_JPL@ (cd packages/jpl ; $(MAKE) install)
|
@ENABLE_JPL@ (cd packages/jpl ; $(MAKE) install)
|
||||||
@ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE) install)
|
@ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE) install)
|
||||||
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install)
|
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install)
|
||||||
@ -904,7 +904,7 @@ clean: clean_docs
|
|||||||
@ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE) clean)
|
@ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE) clean)
|
||||||
@ENABLE_REAL@ (cd packages/real; $(MAKE) clean)
|
@ENABLE_REAL@ (cd packages/real; $(MAKE) clean)
|
||||||
@ENABLE_MINISAT@ (cd packages/swi-minisat2; $(MAKE) clean)
|
@ENABLE_MINISAT@ (cd packages/swi-minisat2; $(MAKE) clean)
|
||||||
@ENABLE_CLPBN_BP@ (cd packages/CLPBN/clpbn/bp; $(MAKE) clean)
|
@ENABLE_CLPBN_BP@ (cd packages/CLPBN/horus; $(MAKE) clean)
|
||||||
@ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE) clean)
|
@ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE) clean)
|
||||||
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) clean)
|
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) clean)
|
||||||
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean)
|
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean)
|
||||||
|
8
configure
vendored
8
configure
vendored
@ -10519,8 +10519,7 @@ mkdir -p packages/clib/maildrop/rfc822
|
|||||||
mkdir -p packages/clib/maildrop/rfc2045
|
mkdir -p packages/clib/maildrop/rfc2045
|
||||||
mkdir -p packages/CLPBN
|
mkdir -p packages/CLPBN
|
||||||
mkdir -p packages/CLPBN/clpbn
|
mkdir -p packages/CLPBN/clpbn
|
||||||
mkdir -p packages/CLPBN/clpbn/bp
|
mkdir -p packages/CLPBN/horus
|
||||||
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
|
|
||||||
mkdir -p packages/clpqr
|
mkdir -p packages/clpqr
|
||||||
mkdir -p packages/cplint
|
mkdir -p packages/cplint
|
||||||
mkdir -p packages/cplint/approx
|
mkdir -p packages/cplint/approx
|
||||||
@ -10673,6 +10672,7 @@ ac_config_files="$ac_config_files packages/zlib/Makefile"
|
|||||||
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
if test "$ENABLE_CUDD" = ""; then
|
if test "$ENABLE_CUDD" = ""; then
|
||||||
ac_config_files="$ac_config_files packages/bdd/Makefile"
|
ac_config_files="$ac_config_files packages/bdd/Makefile"
|
||||||
|
|
||||||
@ -10695,7 +10695,7 @@ ac_config_files="$ac_config_files packages/real/Makefile"
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if test "$ENABLE_CLPBN_BP" = ""; then
|
if test "$ENABLE_CLPBN_BP" = ""; then
|
||||||
ac_config_files="$ac_config_files packages/CLPBN/clpbn/bp/Makefile"
|
ac_config_files="$ac_config_files packages/CLPBN/horus/Makefile"
|
||||||
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -11464,7 +11464,7 @@ do
|
|||||||
"packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;;
|
"packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;;
|
||||||
"packages/swi-minisat2/C/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/C/Makefile" ;;
|
"packages/swi-minisat2/C/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/C/Makefile" ;;
|
||||||
"packages/real/Makefile") CONFIG_FILES="$CONFIG_FILES packages/real/Makefile" ;;
|
"packages/real/Makefile") CONFIG_FILES="$CONFIG_FILES packages/real/Makefile" ;;
|
||||||
"packages/CLPBN/clpbn/bp/Makefile") CONFIG_FILES="$CONFIG_FILES packages/CLPBN/clpbn/bp/Makefile" ;;
|
"packages/CLPBN/horus/Makefile") CONFIG_FILES="$CONFIG_FILES packages/CLPBN/horus/Makefile" ;;
|
||||||
"library/gecode/Makefile") CONFIG_FILES="$CONFIG_FILES library/gecode/Makefile" ;;
|
"library/gecode/Makefile") CONFIG_FILES="$CONFIG_FILES library/gecode/Makefile" ;;
|
||||||
"packages/prism/src/c/Makefile") CONFIG_FILES="$CONFIG_FILES packages/prism/src/c/Makefile" ;;
|
"packages/prism/src/c/Makefile") CONFIG_FILES="$CONFIG_FILES packages/prism/src/c/Makefile" ;;
|
||||||
"packages/prism/src/prolog/Makefile") CONFIG_FILES="$CONFIG_FILES packages/prism/src/prolog/Makefile" ;;
|
"packages/prism/src/prolog/Makefile") CONFIG_FILES="$CONFIG_FILES packages/prism/src/prolog/Makefile" ;;
|
||||||
|
@ -2300,8 +2300,7 @@ mkdir -p packages/clib/maildrop/rfc822
|
|||||||
mkdir -p packages/clib/maildrop/rfc2045
|
mkdir -p packages/clib/maildrop/rfc2045
|
||||||
mkdir -p packages/CLPBN
|
mkdir -p packages/CLPBN
|
||||||
mkdir -p packages/CLPBN/clpbn
|
mkdir -p packages/CLPBN/clpbn
|
||||||
mkdir -p packages/CLPBN/clpbn/bp
|
mkdir -p packages/CLPBN/horus
|
||||||
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
|
|
||||||
mkdir -p packages/clpqr
|
mkdir -p packages/clpqr
|
||||||
mkdir -p packages/cplint
|
mkdir -p packages/cplint
|
||||||
mkdir -p packages/cplint/approx
|
mkdir -p packages/cplint/approx
|
||||||
@ -2415,6 +2414,7 @@ if test "$ENABLE_ZLIB" = ""; then
|
|||||||
AC_CONFIG_FILES([packages/zlib/Makefile])
|
AC_CONFIG_FILES([packages/zlib/Makefile])
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
if test "$ENABLE_CUDD" = ""; then
|
if test "$ENABLE_CUDD" = ""; then
|
||||||
AC_CONFIG_FILES([packages/bdd/Makefile])
|
AC_CONFIG_FILES([packages/bdd/Makefile])
|
||||||
AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile])
|
AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile])
|
||||||
@ -2431,7 +2431,7 @@ AC_CONFIG_FILES([packages/real/Makefile])
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if test "$ENABLE_CLPBN_BP" = ""; then
|
if test "$ENABLE_CLPBN_BP" = ""; then
|
||||||
AC_CONFIG_FILES([packages/CLPBN/clpbn/bp/Makefile])
|
AC_CONFIG_FILES([packages/CLPBN/horus/Makefile])
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if test "$ENABLE_GECODE" = ""; then
|
if test "$ENABLE_GECODE" = ""; then
|
||||||
|
@ -42,19 +42,19 @@ CLPBN_PROGRAMS= \
|
|||||||
$(CLPBN_SRCDIR)/aggregates.yap \
|
$(CLPBN_SRCDIR)/aggregates.yap \
|
||||||
$(CLPBN_SRCDIR)/bdd.yap \
|
$(CLPBN_SRCDIR)/bdd.yap \
|
||||||
$(CLPBN_SRCDIR)/bnt.yap \
|
$(CLPBN_SRCDIR)/bnt.yap \
|
||||||
$(CLPBN_SRCDIR)/bp.yap \
|
|
||||||
$(CLPBN_SRCDIR)/connected.yap \
|
$(CLPBN_SRCDIR)/connected.yap \
|
||||||
$(CLPBN_SRCDIR)/discrete_utils.yap \
|
$(CLPBN_SRCDIR)/discrete_utils.yap \
|
||||||
$(CLPBN_SRCDIR)/display.yap \
|
$(CLPBN_SRCDIR)/display.yap \
|
||||||
$(CLPBN_SRCDIR)/dists.yap \
|
$(CLPBN_SRCDIR)/dists.yap \
|
||||||
$(CLPBN_SRCDIR)/evidence.yap \
|
$(CLPBN_SRCDIR)/evidence.yap \
|
||||||
$(CLPBN_SRCDIR)/fove.yap \
|
|
||||||
$(CLPBN_SRCDIR)/gibbs.yap \
|
$(CLPBN_SRCDIR)/gibbs.yap \
|
||||||
$(CLPBN_SRCDIR)/graphs.yap \
|
$(CLPBN_SRCDIR)/graphs.yap \
|
||||||
$(CLPBN_SRCDIR)/graphviz.yap \
|
$(CLPBN_SRCDIR)/graphviz.yap \
|
||||||
$(CLPBN_SRCDIR)/ground_factors.yap \
|
$(CLPBN_SRCDIR)/ground_factors.yap \
|
||||||
$(CLPBN_SRCDIR)/hmm.yap \
|
$(CLPBN_SRCDIR)/hmm.yap \
|
||||||
$(CLPBN_SRCDIR)/horus.yap \
|
$(CLPBN_SRCDIR)/horus.yap \
|
||||||
|
$(CLPBN_SRCDIR)/horus_ground.yap \
|
||||||
|
$(CLPBN_SRCDIR)/horus_lifted.yap \
|
||||||
$(CLPBN_SRCDIR)/jt.yap \
|
$(CLPBN_SRCDIR)/jt.yap \
|
||||||
$(CLPBN_SRCDIR)/matrix_cpt_utils.yap \
|
$(CLPBN_SRCDIR)/matrix_cpt_utils.yap \
|
||||||
$(CLPBN_SRCDIR)/pgrammar.yap \
|
$(CLPBN_SRCDIR)/pgrammar.yap \
|
||||||
@ -94,8 +94,16 @@ CLPBN_HMMER_EXAMPLES= \
|
|||||||
$(CLPBN_EXDIR)/HMMer/score.yap
|
$(CLPBN_EXDIR)/HMMer/score.yap
|
||||||
|
|
||||||
CLPBN_EXAMPLES= \
|
CLPBN_EXAMPLES= \
|
||||||
|
$(CLPBN_EXDIR)/burglary-alarm.fg \
|
||||||
|
$(CLPBN_EXDIR)/burglary-alarm.yap \
|
||||||
|
$(CLPBN_EXDIR)/burglary-alarm.uai \
|
||||||
$(CLPBN_EXDIR)/cg.yap \
|
$(CLPBN_EXDIR)/cg.yap \
|
||||||
$(CLPBN_EXDIR)/sprinkler.yap
|
$(CLPBN_EXDIR)/city.yap \
|
||||||
|
$(CLPBN_EXDIR)/comp_workshops.yap \
|
||||||
|
$(CLPBN_EXDIR)/social_domain1.yap \
|
||||||
|
$(CLPBN_EXDIR)/social_domain2.yap \
|
||||||
|
$(CLPBN_EXDIR)/sprinkler.yap \
|
||||||
|
$(CLPBN_EXDIR)/workshop_attrs.yap
|
||||||
|
|
||||||
|
|
||||||
install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS)
|
install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS)
|
||||||
|
78
packages/CLPBN/benchmarks/benchs.sh
Executable file
78
packages/CLPBN/benchmarks/benchs.sh
Executable file
@ -0,0 +1,78 @@
|
|||||||
|
|
||||||
|
|
||||||
|
function prepare_new_run
|
||||||
|
{
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
LOG_FILE=$SOLVER.log
|
||||||
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILE
|
||||||
|
|
||||||
|
cp ~/bin/yap $YAP
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
function run_solver
|
||||||
|
{
|
||||||
|
constraint=$1
|
||||||
|
solver_flag=true
|
||||||
|
if [ -n "$2" ]; then
|
||||||
|
if [ $SOLVER = hve ]; then
|
||||||
|
solver_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
|
||||||
|
elif [ $SOLVER = bp ]; then
|
||||||
|
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||||
|
elif [ $SOLVER = cbp ]; then
|
||||||
|
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||||
|
else
|
||||||
|
echo "unknow flag $2"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
/usr/bin/time -o $LOG_FILE -a -f "%U\t%S\t%e\t%M" \
|
||||||
|
$YAP << EOF >> $LOG_FILE &>> ignore.$LOG_FILE
|
||||||
|
nogc.
|
||||||
|
[$NETWORK].
|
||||||
|
[$constraint].
|
||||||
|
clpbn_horus:set_solver($SOLVER).
|
||||||
|
clpbn_horus:set_horus_flag(use_logarithms, true).
|
||||||
|
clpbn_horus:set_horus_flag(verbosity, 1).
|
||||||
|
$solver_flag.
|
||||||
|
$QUERY.
|
||||||
|
open("$LOG_FILE", 'append', S), format(S, '$constraint ~15+ ', []), close(S).
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
function clear_log_files
|
||||||
|
{
|
||||||
|
rm -f *~
|
||||||
|
rm -f ../*~
|
||||||
|
rm -f school/*.log school/*~
|
||||||
|
rm -f ../school/*.log ../school/*~
|
||||||
|
rm -f city/*.log city/*~
|
||||||
|
rm -f ../city/*.log ../city/*~
|
||||||
|
rm -f workshop_attrs/*.log workshop_attrs/*~
|
||||||
|
rm -f ../workshop_attrs/*.log ../workshop_attrs/*~
|
||||||
|
echo all done!
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
function write_header
|
||||||
|
{
|
||||||
|
echo -n "****************************************" >> $LOG_FILE
|
||||||
|
echo "****************************************" >> $LOG_FILE
|
||||||
|
echo "results for solver $1 user(s) sys(s) real(s), mem(kB)" >> $LOG_FILE
|
||||||
|
echo -n "****************************************" >> $LOG_FILE
|
||||||
|
echo "****************************************" >> $LOG_FILE
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if [ $1 ] && [ $1 == "clean" ]; then
|
||||||
|
clear_log_files
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
37
packages/CLPBN/benchmarks/city/bp_tests.sh
Executable file
37
packages/CLPBN/benchmarks/city/bp_tests.sh
Executable file
@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="bp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver city1000 $2
|
||||||
|
run_solver city5000 $2
|
||||||
|
run_solver city10000 $2
|
||||||
|
run_solver city15000 $2
|
||||||
|
run_solver city20000 $2
|
||||||
|
run_solver city25000 $2
|
||||||
|
run_solver city30000 $2
|
||||||
|
run_solver city35000 $2
|
||||||
|
run_solver city40000 $2
|
||||||
|
run_solver city45000 $2
|
||||||
|
run_solver city50000 $2
|
||||||
|
run_solver city55000 $2
|
||||||
|
run_solver city60000 $2
|
||||||
|
run_solver city65000 $2
|
||||||
|
return
|
||||||
|
run_solver city70000 $2
|
||||||
|
run_solver city75000 $2
|
||||||
|
run_solver city80000 $2
|
||||||
|
run_solver city85000 $2
|
||||||
|
run_solver city90000 $2
|
||||||
|
run_solver city95000 $2
|
||||||
|
run_solver city100000 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
36
packages/CLPBN/benchmarks/city/cbp_tests.sh
Executable file
36
packages/CLPBN/benchmarks/city/cbp_tests.sh
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="cbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver city1000 $2
|
||||||
|
run_solver city5000 $2
|
||||||
|
run_solver city10000 $2
|
||||||
|
run_solver city15000 $2
|
||||||
|
run_solver city20000 $2
|
||||||
|
run_solver city25000 $2
|
||||||
|
run_solver city30000 $2
|
||||||
|
run_solver city35000 $2
|
||||||
|
run_solver city40000 $2
|
||||||
|
run_solver city45000 $2
|
||||||
|
run_solver city50000 $2
|
||||||
|
run_solver city55000 $2
|
||||||
|
run_solver city60000 $2
|
||||||
|
run_solver city65000 $2
|
||||||
|
run_solver city70000 $2
|
||||||
|
run_solver city75000 $2
|
||||||
|
run_solver city80000 $2
|
||||||
|
run_solver city85000 $2
|
||||||
|
run_solver city90000 $2
|
||||||
|
run_solver city95000 $2
|
||||||
|
run_solver city100000 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
6
packages/CLPBN/benchmarks/city/city.sh
Executable file
6
packages/CLPBN/benchmarks/city/city.sh
Executable file
@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
NETWORK="'../../examples/city'"
|
||||||
|
SHORTNAME="city"
|
||||||
|
QUERY="is_joe_guilty(X)"
|
||||||
|
|
36
packages/CLPBN/benchmarks/city/fove_tests.sh
Executable file
36
packages/CLPBN/benchmarks/city/fove_tests.sh
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="fove"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver city1000 $2
|
||||||
|
run_solver city5000 $2
|
||||||
|
run_solver city10000 $2
|
||||||
|
run_solver city15000 $2
|
||||||
|
run_solver city20000 $2
|
||||||
|
run_solver city25000 $2
|
||||||
|
run_solver city30000 $2
|
||||||
|
run_solver city35000 $2
|
||||||
|
run_solver city40000 $2
|
||||||
|
run_solver city45000 $2
|
||||||
|
run_solver city50000 $2
|
||||||
|
run_solver city55000 $2
|
||||||
|
run_solver city60000 $2
|
||||||
|
run_solver city65000 $2
|
||||||
|
run_solver city70000 $2
|
||||||
|
run_solver city75000 $2
|
||||||
|
run_solver city80000 $2
|
||||||
|
run_solver city85000 $2
|
||||||
|
run_solver city90000 $2
|
||||||
|
run_solver city95000 $2
|
||||||
|
run_solver city100000 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "fove "
|
||||||
|
|
@ -1,21 +1,17 @@
|
|||||||
#!/home/tiago/bin/yap -L --
|
#! /home/tgomes/bin/yap -L --
|
||||||
|
|
||||||
|
|
||||||
:- initialization(main).
|
:- initialization(main).
|
||||||
|
|
||||||
|
|
||||||
main :-
|
main :-
|
||||||
unix(argv([H])),
|
unix(argv([N])),
|
||||||
generate_town(H).
|
atomic_concat(['city', N, '.yap'], FileName),
|
||||||
|
|
||||||
|
|
||||||
generate_town(N) :-
|
|
||||||
atomic_concat(['city_', N, '.yap'], FileName),
|
|
||||||
open(FileName, 'write', S),
|
open(FileName, 'write', S),
|
||||||
atom_number(N, N2),
|
atom_number(N, N2),
|
||||||
generate_people(S, N2, 4),
|
generate_people(S, N2, 1),
|
||||||
write(S, '\n'),
|
write(S, '\n'),
|
||||||
generate_query(S, N2, 4),
|
generate_evidence(S, N2, 1),
|
||||||
write(S, '\n'),
|
write(S, '\n'),
|
||||||
close(S).
|
close(S).
|
||||||
|
|
||||||
@ -28,10 +24,10 @@ generate_people(S, N, Counting) :-
|
|||||||
generate_people(S, N, Counting1).
|
generate_people(S, N, Counting1).
|
||||||
|
|
||||||
|
|
||||||
generate_query(S, N, Counting) :-
|
generate_evidence(S, N, Counting) :-
|
||||||
Counting > N, !.
|
Counting > N, !.
|
||||||
generate_query(S, N, Counting) :- !,
|
generate_evidence(S, N, Counting) :- !,
|
||||||
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
|
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
|
||||||
Counting1 is Counting + 1,
|
Counting1 is Counting + 1,
|
||||||
generate_query(S, N, Counting1).
|
generate_evidence(S, N, Counting1).
|
||||||
|
|
37
packages/CLPBN/benchmarks/city/hve_tests.sh
Executable file
37
packages/CLPBN/benchmarks/city/hve_tests.sh
Executable file
@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="hve"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver city1000 $2
|
||||||
|
run_solver city5000 $2
|
||||||
|
run_solver city10000 $2
|
||||||
|
run_solver city15000 $2
|
||||||
|
run_solver city20000 $2
|
||||||
|
run_solver city25000 $2
|
||||||
|
run_solver city30000 $2
|
||||||
|
run_solver city35000 $2
|
||||||
|
run_solver city40000 $2
|
||||||
|
run_solver city45000 $2
|
||||||
|
run_solver city50000 $2
|
||||||
|
run_solver city55000 $2
|
||||||
|
run_solver city60000 $2
|
||||||
|
run_solver city65000 $2
|
||||||
|
run_solver city70000 $2
|
||||||
|
|
||||||
|
run_solver city75000 $2
|
||||||
|
run_solver city80000 $2
|
||||||
|
run_solver city85000 $2
|
||||||
|
run_solver city90000 $2
|
||||||
|
run_solver city95000 $2
|
||||||
|
run_solver city100000 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||||
|
|
31
packages/CLPBN/benchmarks/comp_workshops/bp_tests.sh
Executable file
31
packages/CLPBN/benchmarks/comp_workshops/bp_tests.sh
Executable file
@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source cw.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="bp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000w$N_WORKSHOPS $2
|
||||||
|
run_solver p5000w$N_WORKSHOPS $2
|
||||||
|
run_solver p10000w$N_WORKSHOPS $2
|
||||||
|
run_solver p15000w$N_WORKSHOPS $2
|
||||||
|
run_solver p20000w$N_WORKSHOPS $2
|
||||||
|
run_solver p25000w$N_WORKSHOPS $2
|
||||||
|
return
|
||||||
|
run_solver p30000w$N_WORKSHOPS $2
|
||||||
|
run_solver p35000w$N_WORKSHOPS $2
|
||||||
|
run_solver p40000w$N_WORKSHOPS $2
|
||||||
|
run_solver p45000w$N_WORKSHOPS $2
|
||||||
|
run_solver p50000w$N_WORKSHOPS $2
|
||||||
|
run_solver p55000w$N_WORKSHOPS $2
|
||||||
|
run_solver p60000w$N_WORKSHOPS $2
|
||||||
|
run_solver p65000w$N_WORKSHOPS $2
|
||||||
|
run_solver p70000w$N_WORKSHOPS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
30
packages/CLPBN/benchmarks/comp_workshops/cbp_tests.sh
Executable file
30
packages/CLPBN/benchmarks/comp_workshops/cbp_tests.sh
Executable file
@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source cw.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="cbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000w$N_WORKSHOPS $2
|
||||||
|
run_solver p5000w$N_WORKSHOPS $2
|
||||||
|
run_solver p10000w$N_WORKSHOPS $2
|
||||||
|
run_solver p15000w$N_WORKSHOPS $2
|
||||||
|
run_solver p20000w$N_WORKSHOPS $2
|
||||||
|
run_solver p25000w$N_WORKSHOPS $2
|
||||||
|
run_solver p30000w$N_WORKSHOPS $2
|
||||||
|
run_solver p35000w$N_WORKSHOPS $2
|
||||||
|
run_solver p40000w$N_WORKSHOPS $2
|
||||||
|
run_solver p45000w$N_WORKSHOPS $2
|
||||||
|
run_solver p50000w$N_WORKSHOPS $2
|
||||||
|
run_solver p55000w$N_WORKSHOPS $2
|
||||||
|
run_solver p60000w$N_WORKSHOPS $2
|
||||||
|
run_solver p65000w$N_WORKSHOPS $2
|
||||||
|
run_solver p70000w$N_WORKSHOPS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
8
packages/CLPBN/benchmarks/comp_workshops/cw.sh
Executable file
8
packages/CLPBN/benchmarks/comp_workshops/cw.sh
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
NETWORK="'../../examples/comp_workshops'"
|
||||||
|
SHORTNAME="cw"
|
||||||
|
QUERY="series(X)"
|
||||||
|
|
||||||
|
N_WORKSHOPS=10
|
||||||
|
|
31
packages/CLPBN/benchmarks/comp_workshops/fove_tests.sh
Executable file
31
packages/CLPBN/benchmarks/comp_workshops/fove_tests.sh
Executable file
@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source cw.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="fove"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000w$N_WORKSHOPS $2
|
||||||
|
run_solver p5000w$N_WORKSHOPS $2
|
||||||
|
run_solver p10000w$N_WORKSHOPS $2
|
||||||
|
run_solver p15000w$N_WORKSHOPS $2
|
||||||
|
run_solver p20000w$N_WORKSHOPS $2
|
||||||
|
run_solver p25000w$N_WORKSHOPS $2
|
||||||
|
run_solver p30000w$N_WORKSHOPS $2
|
||||||
|
run_solver p35000w$N_WORKSHOPS $2
|
||||||
|
run_solver p40000w$N_WORKSHOPS $2
|
||||||
|
run_solver p45000w$N_WORKSHOPS $2
|
||||||
|
run_solver p50000w$N_WORKSHOPS $2
|
||||||
|
run_solver p55000w$N_WORKSHOPS $2
|
||||||
|
run_solver p60000w$N_WORKSHOPS $2
|
||||||
|
run_solver p65000w$N_WORKSHOPS $2
|
||||||
|
run_solver p70000w$N_WORKSHOPS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "fove "
|
||||||
|
|
||||||
|
|
35
packages/CLPBN/benchmarks/comp_workshops/gen_workshops.sh
Executable file
35
packages/CLPBN/benchmarks/comp_workshops/gen_workshops.sh
Executable file
@ -0,0 +1,35 @@
|
|||||||
|
#!/home/tgomes/bin/yap -L --
|
||||||
|
|
||||||
|
:- use_module(library(lists)).
|
||||||
|
|
||||||
|
:- initialization(main).
|
||||||
|
|
||||||
|
|
||||||
|
main :-
|
||||||
|
unix(argv(Args)),
|
||||||
|
nth(1, Args, NP), % number of invitees
|
||||||
|
nth(2, Args, NW), % number of workshops
|
||||||
|
atomic_concat(['p', NP , 'w', NW, '.yap'], FileName),
|
||||||
|
open(FileName, 'write', S),
|
||||||
|
atom_number(NP, NP2),
|
||||||
|
atom_number(NW, NW2),
|
||||||
|
gen(S, NP2, NW2, 1),
|
||||||
|
write(S, '\n'),
|
||||||
|
close(S).
|
||||||
|
|
||||||
|
|
||||||
|
gen(_, NP, _, Count) :-
|
||||||
|
Count > NP, !.
|
||||||
|
gen(S, NP, NW, Count) :-
|
||||||
|
gen_workshops(S, Count, NW, 1),
|
||||||
|
Count1 is Count + 1,
|
||||||
|
gen(S, NP, NW, Count1).
|
||||||
|
|
||||||
|
|
||||||
|
gen_workshops(_, _, NW, Count) :-
|
||||||
|
Count > NW, !.
|
||||||
|
gen_workshops(S, P, NW, Count) :-
|
||||||
|
format(S, 'c(p~w,w~w).~n', [P,Count]),
|
||||||
|
Count1 is Count + 1,
|
||||||
|
gen_workshops(S, P, NW, Count1).
|
||||||
|
|
30
packages/CLPBN/benchmarks/comp_workshops/hve_tests.sh
Executable file
30
packages/CLPBN/benchmarks/comp_workshops/hve_tests.sh
Executable file
@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source cw.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="hve"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000w$N_WORKSHOPS $2
|
||||||
|
run_solver p5000w$N_WORKSHOPS $2
|
||||||
|
run_solver p10000w$N_WORKSHOPS $2
|
||||||
|
run_solver p15000w$N_WORKSHOPS $2
|
||||||
|
run_solver p20000w$N_WORKSHOPS $2
|
||||||
|
run_solver p25000w$N_WORKSHOPS $2
|
||||||
|
run_solver p30000w$N_WORKSHOPS $2
|
||||||
|
run_solver p35000w$N_WORKSHOPS $2
|
||||||
|
run_solver p40000w$N_WORKSHOPS $2
|
||||||
|
run_solver p45000w$N_WORKSHOPS $2
|
||||||
|
run_solver p50000w$N_WORKSHOPS $2
|
||||||
|
run_solver p55000w$N_WORKSHOPS $2
|
||||||
|
run_solver p60000w$N_WORKSHOPS $2
|
||||||
|
run_solver p65000w$N_WORKSHOPS $2
|
||||||
|
run_solver p70000w$N_WORKSHOPS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||||
|
|
30
packages/CLPBN/benchmarks/smokers/bp_tests.sh
Executable file
30
packages/CLPBN/benchmarks/smokers/bp_tests.sh
Executable file
@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="bp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver pop100 $2
|
||||||
|
run_solver pop200 $2
|
||||||
|
run_solver pop300 $2
|
||||||
|
run_solver pop400 $2
|
||||||
|
run_solver pop500 $2
|
||||||
|
run_solver pop600 $2
|
||||||
|
run_solver pop700 $2
|
||||||
|
run_solver pop800 $2
|
||||||
|
run_solver pop900 $2
|
||||||
|
run_solver pop1000 $2
|
||||||
|
run_solver pop1100 $2
|
||||||
|
run_solver pop1200 $2
|
||||||
|
run_solver pop1300 $2
|
||||||
|
run_solver pop1400 $2
|
||||||
|
run_solver pop1500 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
30
packages/CLPBN/benchmarks/smokers/cbp_tests.sh
Executable file
30
packages/CLPBN/benchmarks/smokers/cbp_tests.sh
Executable file
@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="cbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver pop100 $2
|
||||||
|
run_solver pop200 $2
|
||||||
|
run_solver pop300 $2
|
||||||
|
run_solver pop400 $2
|
||||||
|
run_solver pop500 $2
|
||||||
|
run_solver pop600 $2
|
||||||
|
run_solver pop700 $2
|
||||||
|
run_solver pop800 $2
|
||||||
|
run_solver pop900 $2
|
||||||
|
run_solver pop1000 $2
|
||||||
|
run_solver pop1100 $2
|
||||||
|
run_solver pop1200 $2
|
||||||
|
run_solver pop1300 $2
|
||||||
|
run_solver pop1400 $2
|
||||||
|
run_solver pop1500 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
31
packages/CLPBN/benchmarks/smokers/fove_tests.sh
Executable file
31
packages/CLPBN/benchmarks/smokers/fove_tests.sh
Executable file
@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="fove"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver pop100 $2
|
||||||
|
run_solver pop200 $2
|
||||||
|
run_solver pop300 $2
|
||||||
|
run_solver pop400 $2
|
||||||
|
run_solver pop500 $2
|
||||||
|
run_solver pop600 $2
|
||||||
|
run_solver pop700 $2
|
||||||
|
run_solver pop800 $2
|
||||||
|
run_solver pop900 $2
|
||||||
|
run_solver pop1000 $2
|
||||||
|
run_solver pop1100 $2
|
||||||
|
run_solver pop1200 $2
|
||||||
|
run_solver pop1300 $2
|
||||||
|
run_solver pop1400 $2
|
||||||
|
run_solver pop1500 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "fove "
|
||||||
|
|
||||||
|
|
@ -1,16 +1,12 @@
|
|||||||
#!/home/tiago/bin/yap -L --
|
#!/home/tgomes/bin/yap -L --
|
||||||
|
|
||||||
|
|
||||||
:- initialization(main).
|
:- initialization(main).
|
||||||
|
|
||||||
|
|
||||||
main :-
|
main :-
|
||||||
unix(argv([H])),
|
unix(argv([N])),
|
||||||
generate_town(H).
|
atomic_concat(['pop', N, '.yap'], FileName),
|
||||||
|
|
||||||
|
|
||||||
generate_town(N) :-
|
|
||||||
atomic_concat(['pop_', N, '.yap'], FileName),
|
|
||||||
open(FileName, 'write', S),
|
open(FileName, 'write', S),
|
||||||
atom_number(N, N2),
|
atom_number(N, N2),
|
||||||
generate_people(S, N2, 4),
|
generate_people(S, N2, 4),
|
33
packages/CLPBN/benchmarks/smokers/hve_tests.sh
Executable file
33
packages/CLPBN/benchmarks/smokers/hve_tests.sh
Executable file
@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source sm.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="hve"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver pop100 $2
|
||||||
|
#run_solver pop200 $2
|
||||||
|
#run_solver pop300 $2
|
||||||
|
#run_solver pop400 $2
|
||||||
|
#run_solver pop500 $2
|
||||||
|
#run_solver pop600 $2
|
||||||
|
#run_solver pop700 $2
|
||||||
|
#run_solver pop800 $2
|
||||||
|
#run_solver pop900 $2
|
||||||
|
#run_solver pop1000 $2
|
||||||
|
#run_solver pop1100 $2
|
||||||
|
#run_solver pop1200 $2
|
||||||
|
#run_solver pop1300 $2
|
||||||
|
#run_solver pop1400 $2
|
||||||
|
#run_solver pop1500 $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||||
|
#run_all_graphs "hve(elim_heuristic=min_weight) " min_weight
|
||||||
|
#run_all_graphs "hve(elim_heuristic=min_fill) " min_fill
|
||||||
|
#run_all_graphs "hve(elim_heuristic=weighted_min_fill) " weighted_min_fill
|
||||||
|
|
6
packages/CLPBN/benchmarks/smokers/sm.sh
Executable file
6
packages/CLPBN/benchmarks/smokers/sm.sh
Executable file
@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
NETWORK="'../../examples/social_domain2'"
|
||||||
|
SHORTNAME="sm"
|
||||||
|
QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)"
|
||||||
|
|
37
packages/CLPBN/benchmarks/workshop_attrs/bp_tests.sh
Executable file
37
packages/CLPBN/benchmarks/workshop_attrs/bp_tests.sh
Executable file
@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source wa.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="bp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000attrs$N_ATTRS $2
|
||||||
|
run_solver p5000attrs$N_ATTRS $2
|
||||||
|
run_solver p10000attrs$N_ATTRS $2
|
||||||
|
run_solver p15000attrs$N_ATTRS $2
|
||||||
|
run_solver p20000attrs$N_ATTRS $2
|
||||||
|
run_solver p25000attrs$N_ATTRS $2
|
||||||
|
run_solver p30000attrs$N_ATTRS $2
|
||||||
|
run_solver p35000attrs$N_ATTRS $2
|
||||||
|
return
|
||||||
|
run_solver p40000attrs$N_ATTRS $2
|
||||||
|
run_solver p45000attrs$N_ATTRS $2
|
||||||
|
run_solver p50000attrs$N_ATTRS $2
|
||||||
|
run_solver p55000attrs$N_ATTRS $2
|
||||||
|
run_solver p60000attrs$N_ATTRS $2
|
||||||
|
run_solver p65000attrs$N_ATTRS $2
|
||||||
|
run_solver p70000attrs$N_ATTRS $2
|
||||||
|
run_solver p75000attrs$N_ATTRS $2
|
||||||
|
run_solver p80000attrs$N_ATTRS $2
|
||||||
|
run_solver p85000attrs$N_ATTRS $2
|
||||||
|
run_solver p90000attrs$N_ATTRS $2
|
||||||
|
run_solver p95000attrs$N_ATTRS $2
|
||||||
|
run_solver p100000attrs$N_ATTRS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
36
packages/CLPBN/benchmarks/workshop_attrs/cbp_tests.sh
Executable file
36
packages/CLPBN/benchmarks/workshop_attrs/cbp_tests.sh
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source wa.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="cbp"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000attrs$N_ATTRS $2
|
||||||
|
run_solver p5000attrs$N_ATTRS $2
|
||||||
|
run_solver p10000attrs$N_ATTRS $2
|
||||||
|
run_solver p15000attrs$N_ATTRS $2
|
||||||
|
run_solver p20000attrs$N_ATTRS $2
|
||||||
|
run_solver p25000attrs$N_ATTRS $2
|
||||||
|
run_solver p30000attrs$N_ATTRS $2
|
||||||
|
run_solver p35000attrs$N_ATTRS $2
|
||||||
|
run_solver p40000attrs$N_ATTRS $2
|
||||||
|
run_solver p45000attrs$N_ATTRS $2
|
||||||
|
run_solver p50000attrs$N_ATTRS $2
|
||||||
|
run_solver p55000attrs$N_ATTRS $2
|
||||||
|
run_solver p60000attrs$N_ATTRS $2
|
||||||
|
run_solver p65000attrs$N_ATTRS $2
|
||||||
|
run_solver p70000attrs$N_ATTRS $2
|
||||||
|
run_solver p75000attrs$N_ATTRS $2
|
||||||
|
run_solver p80000attrs$N_ATTRS $2
|
||||||
|
run_solver p85000attrs$N_ATTRS $2
|
||||||
|
run_solver p90000attrs$N_ATTRS $2
|
||||||
|
run_solver p95000attrs$N_ATTRS $2
|
||||||
|
run_solver p100000attrs$N_ATTRS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
37
packages/CLPBN/benchmarks/workshop_attrs/fove_tests.sh
Executable file
37
packages/CLPBN/benchmarks/workshop_attrs/fove_tests.sh
Executable file
@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source wa.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="fove"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000attrs$N_ATTRS $2
|
||||||
|
run_solver p5000attrs$N_ATTRS $2
|
||||||
|
run_solver p10000attrs$N_ATTRS $2
|
||||||
|
run_solver p15000attrs$N_ATTRS $2
|
||||||
|
run_solver p20000attrs$N_ATTRS $2
|
||||||
|
run_solver p25000attrs$N_ATTRS $2
|
||||||
|
run_solver p30000attrs$N_ATTRS $2
|
||||||
|
run_solver p35000attrs$N_ATTRS $2
|
||||||
|
run_solver p40000attrs$N_ATTRS $2
|
||||||
|
run_solver p45000attrs$N_ATTRS $2
|
||||||
|
run_solver p50000attrs$N_ATTRS $2
|
||||||
|
run_solver p55000attrs$N_ATTRS $2
|
||||||
|
run_solver p60000attrs$N_ATTRS $2
|
||||||
|
run_solver p65000attrs$N_ATTRS $2
|
||||||
|
run_solver p70000attrs$N_ATTRS $2
|
||||||
|
run_solver p75000attrs$N_ATTRS $2
|
||||||
|
run_solver p80000attrs$N_ATTRS $2
|
||||||
|
run_solver p85000attrs$N_ATTRS $2
|
||||||
|
run_solver p90000attrs$N_ATTRS $2
|
||||||
|
run_solver p95000attrs$N_ATTRS $2
|
||||||
|
run_solver p100000attrs$N_ATTRS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "fove "
|
||||||
|
|
||||||
|
|
39
packages/CLPBN/benchmarks/workshop_attrs/gen_attrs.sh
Executable file
39
packages/CLPBN/benchmarks/workshop_attrs/gen_attrs.sh
Executable file
@ -0,0 +1,39 @@
|
|||||||
|
#!/home/tgomes/bin/yap -L --
|
||||||
|
|
||||||
|
:- use_module(library(lists)).
|
||||||
|
|
||||||
|
:- initialization(main).
|
||||||
|
|
||||||
|
|
||||||
|
main :-
|
||||||
|
unix(argv(Args)),
|
||||||
|
nth(1, Args, NP), % number of invitees
|
||||||
|
nth(2, Args, NA), % number of attributes
|
||||||
|
atomic_concat(['p', NP , 'attrs', NA, '.yap'], FileName),
|
||||||
|
open(FileName, 'write', S),
|
||||||
|
atom_number(NP, NP2),
|
||||||
|
atom_number(NA, NA2),
|
||||||
|
generate_people(S, NP2, 1),
|
||||||
|
write(S, '\n'),
|
||||||
|
generate_attrs(S, NA2, 7),
|
||||||
|
write(S, '\n'),
|
||||||
|
close(S).
|
||||||
|
|
||||||
|
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
Counting > N, !.
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
format(S, 'people(p~w).~n', [Counting]),
|
||||||
|
Counting1 is Counting + 1,
|
||||||
|
generate_people(S, N, Counting1).
|
||||||
|
|
||||||
|
|
||||||
|
generate_attrs(S, N, Counting) :-
|
||||||
|
Counting > N, !.
|
||||||
|
generate_attrs(S, N, Counting) :-
|
||||||
|
%format(S, 'people(p~w).~n', [Counting]),
|
||||||
|
format(S, 'markov attends(P)::[t,f], attr~w::[t,f]', [Counting]),
|
||||||
|
format(S, '; [0.7, 0.3, 0.3, 0.3] ; [people(P)].~n',[]),
|
||||||
|
Counting1 is Counting + 1,
|
||||||
|
generate_attrs(S, N, Counting1).
|
||||||
|
|
36
packages/CLPBN/benchmarks/workshop_attrs/hve_tests.sh
Executable file
36
packages/CLPBN/benchmarks/workshop_attrs/hve_tests.sh
Executable file
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source wa.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="hve"
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
write_header $1
|
||||||
|
run_solver p1000attrs$N_ATTRS $2
|
||||||
|
run_solver p5000attrs$N_ATTRS $2
|
||||||
|
run_solver p10000attrs$N_ATTRS $2
|
||||||
|
run_solver p15000attrs$N_ATTRS $2
|
||||||
|
run_solver p20000attrs$N_ATTRS $2
|
||||||
|
run_solver p25000attrs$N_ATTRS $2
|
||||||
|
run_solver p30000attrs$N_ATTRS $2
|
||||||
|
run_solver p35000attrs$N_ATTRS $2
|
||||||
|
run_solver p40000attrs$N_ATTRS $2
|
||||||
|
run_solver p45000attrs$N_ATTRS $2
|
||||||
|
run_solver p50000attrs$N_ATTRS $2
|
||||||
|
run_solver p55000attrs$N_ATTRS $2
|
||||||
|
run_solver p60000attrs$N_ATTRS $2
|
||||||
|
run_solver p65000attrs$N_ATTRS $2
|
||||||
|
run_solver p70000attrs$N_ATTRS $2
|
||||||
|
run_solver p75000attrs$N_ATTRS $2
|
||||||
|
run_solver p80000attrs$N_ATTRS $2
|
||||||
|
run_solver p85000attrs$N_ATTRS $2
|
||||||
|
run_solver p90000attrs$N_ATTRS $2
|
||||||
|
run_solver p95000attrs$N_ATTRS $2
|
||||||
|
run_solver p100000attrs$N_ATTRS $2
|
||||||
|
}
|
||||||
|
|
||||||
|
prepare_new_run
|
||||||
|
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||||
|
|
9
packages/CLPBN/benchmarks/workshop_attrs/wa.sh
Executable file
9
packages/CLPBN/benchmarks/workshop_attrs/wa.sh
Executable file
@ -0,0 +1,9 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
NETWORK="'../../examples/workshop_attrs'"
|
||||||
|
SHORTNAME="wa"
|
||||||
|
QUERY="series(X)"
|
||||||
|
|
||||||
|
N_ATTRS=6
|
||||||
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
|||||||
|
|
||||||
|
|
||||||
:- module(clpbn, [{}/1,
|
:- module(clpbn, [{}/1,
|
||||||
clpbn_flag/2,
|
clpbn_flag/2,
|
||||||
set_clpbn_flag/2,
|
set_clpbn_flag/2,
|
||||||
@ -39,24 +38,22 @@
|
|||||||
run_ve_solver/3
|
run_ve_solver/3
|
||||||
]).
|
]).
|
||||||
|
|
||||||
:- use_module('clpbn/bp',
|
:- use_module('clpbn/horus_ground',
|
||||||
[bp/3,
|
[call_horus_ground_solver/6,
|
||||||
check_if_bp_done/1,
|
check_if_horus_ground_solver_done/1,
|
||||||
init_bp_solver/4,
|
init_horus_ground_solver/4,
|
||||||
run_bp_solver/3,
|
run_horus_ground_solver/3,
|
||||||
call_bp_ground/6,
|
finalize_horus_ground_solver/1
|
||||||
finalize_bp_solver/1
|
|
||||||
]).
|
]).
|
||||||
|
|
||||||
:- use_module('clpbn/fove',
|
:- use_module('clpbn/horus_lifted',
|
||||||
[fove/3,
|
[call_horus_lifted_solver/3,
|
||||||
check_if_fove_done/1,
|
check_if_horus_lifted_solver_done/1,
|
||||||
init_fove_solver/4,
|
init_horus_lifted_solver/4,
|
||||||
run_fove_solver/3,
|
run_horus_lifted_solver/3,
|
||||||
finalize_fove_solver/1
|
finalize_horus_lifted_solver/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module('clpbn/jt',
|
:- use_module('clpbn/jt',
|
||||||
[jt/3,
|
[jt/3,
|
||||||
init_jt_solver/4,
|
init_jt_solver/4,
|
||||||
@ -306,18 +303,19 @@ write_out(jt, GVars, AVars, DiffVars) :-
|
|||||||
jt(GVars, AVars, DiffVars).
|
jt(GVars, AVars, DiffVars).
|
||||||
write_out(bdd, GVars, AVars, DiffVars) :-
|
write_out(bdd, GVars, AVars, DiffVars) :-
|
||||||
bdd(GVars, AVars, DiffVars).
|
bdd(GVars, AVars, DiffVars).
|
||||||
write_out(bp, GVars, AVars, DiffVars) :-
|
write_out(bp, _GVars, _AVars, _DiffVars) :-
|
||||||
bp(GVars, AVars, DiffVars).
|
writeln('interface not supported anymore').
|
||||||
|
%bp(GVars, AVars, DiffVars).
|
||||||
write_out(gibbs, GVars, AVars, DiffVars) :-
|
write_out(gibbs, GVars, AVars, DiffVars) :-
|
||||||
gibbs(GVars, AVars, DiffVars).
|
gibbs(GVars, AVars, DiffVars).
|
||||||
write_out(bnt, GVars, AVars, DiffVars) :-
|
write_out(bnt, GVars, AVars, DiffVars) :-
|
||||||
do_bnt(GVars, AVars, DiffVars).
|
do_bnt(GVars, AVars, DiffVars).
|
||||||
write_out(fove, GVars, AVars, DiffVars) :-
|
write_out(fove, GVars, AVars, DiffVars) :-
|
||||||
fove(GVars, AVars, DiffVars).
|
call_horus_lifted_solver(GVars, AVars, DiffVars).
|
||||||
|
|
||||||
% call a solver with keys, not actual variables
|
% call a solver with keys, not actual variables
|
||||||
call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence, Answ) :-
|
call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence, Answ) :-
|
||||||
call_bp_ground(GVars, GoalKeys, Keys, Factors, Evidence, Answ).
|
call_horus_ground_solver(GVars, GoalKeys, Keys, Factors, Evidence, Answ).
|
||||||
|
|
||||||
|
|
||||||
get_bnode(Var, Goal) :-
|
get_bnode(Var, Goal) :-
|
||||||
@ -400,7 +398,7 @@ bind_clpbn(_, Var, _, _, _, _, []) :-
|
|||||||
check_if_ve_done(Var), !.
|
check_if_ve_done(Var), !.
|
||||||
bind_clpbn(_, Var, _, _, _, _, []) :-
|
bind_clpbn(_, Var, _, _, _, _, []) :-
|
||||||
use(bp),
|
use(bp),
|
||||||
check_if_bp_done(Var), !.
|
check_if_horus_ground_solver_done(Var), !.
|
||||||
bind_clpbn(_, Var, _, _, _, _, []) :-
|
bind_clpbn(_, Var, _, _, _, _, []) :-
|
||||||
use(jt),
|
use(jt),
|
||||||
check_if_ve_done(Var), !.
|
check_if_ve_done(Var), !.
|
||||||
@ -475,7 +473,7 @@ clpbn_init_solver(gibbs, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
|||||||
clpbn_init_solver(ve, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
clpbn_init_solver(ve, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||||
init_ve_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
init_ve_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||||
clpbn_init_solver(bp, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
clpbn_init_solver(bp, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||||
init_bp_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
init_horus_ground_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||||
clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||||
init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||||
clpbn_init_solver(bdd, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
clpbn_init_solver(bdd, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||||
@ -501,7 +499,7 @@ clpbn_run_solver(ve, LVs, LPs, State) :-
|
|||||||
run_ve_solver(LVs, LPs, State).
|
run_ve_solver(LVs, LPs, State).
|
||||||
|
|
||||||
clpbn_run_solver(bp, LVs, LPs, State) :-
|
clpbn_run_solver(bp, LVs, LPs, State) :-
|
||||||
run_bp_solver(LVs, LPs, State).
|
run_horus_ground_solver(LVs, LPs, State).
|
||||||
|
|
||||||
clpbn_run_solver(jt, LVs, LPs, State) :-
|
clpbn_run_solver(jt, LVs, LPs, State) :-
|
||||||
run_jt_solver(LVs, LPs, State).
|
run_jt_solver(LVs, LPs, State).
|
||||||
@ -522,7 +520,7 @@ clpbn_finalize_solver(State) :-
|
|||||||
solver(bp), !,
|
solver(bp), !,
|
||||||
functor(State, _, Last),
|
functor(State, _, Last),
|
||||||
arg(Last, State, Info),
|
arg(Last, State, Info),
|
||||||
finalize_bp_solver(Info).
|
finalize_horus_ground_solver(Info).
|
||||||
clpbn_finalize_solver(_State).
|
clpbn_finalize_solver(_State).
|
||||||
|
|
||||||
probability(Goal, Prob) :-
|
probability(Goal, Prob) :-
|
||||||
|
@ -1,493 +0,0 @@
|
|||||||
#include <cassert>
|
|
||||||
#include <limits>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "BpSolver.h"
|
|
||||||
#include "FactorGraph.h"
|
|
||||||
#include "Factor.h"
|
|
||||||
#include "Indexer.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
|
|
||||||
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
|
||||||
{
|
|
||||||
fg_ = &fg;
|
|
||||||
runned_ = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BpSolver::~BpSolver (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
|
||||||
delete varsI_[i];
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
|
||||||
delete facsI_[i];
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
delete links_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BpSolver::solveQuery (VarIds queryVids)
|
|
||||||
{
|
|
||||||
assert (queryVids.empty() == false);
|
|
||||||
if (queryVids.size() == 1) {
|
|
||||||
return getPosterioriOf (queryVids[0]);
|
|
||||||
} else {
|
|
||||||
return getJointDistributionOf (queryVids);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BpSolver::getPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
if (runned_ == false) {
|
|
||||||
runSolver();
|
|
||||||
}
|
|
||||||
assert (fg_->getVarNode (vid));
|
|
||||||
VarNode* var = fg_->getVarNode (vid);
|
|
||||||
Params probs;
|
|
||||||
if (var->hasEvidence()) {
|
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
|
||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
|
||||||
} else {
|
|
||||||
probs.resize (var->range(), LogAware::multIdenty());
|
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
Util::add (probs, links[i]->getMessage());
|
|
||||||
}
|
|
||||||
LogAware::normalize (probs);
|
|
||||||
Util::fromLog (probs);
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
Util::multiply (probs, links[i]->getMessage());
|
|
||||||
}
|
|
||||||
LogAware::normalize (probs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return probs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|
||||||
{
|
|
||||||
if (runned_ == false) {
|
|
||||||
runSolver();
|
|
||||||
}
|
|
||||||
int idx = -1;
|
|
||||||
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
|
|
||||||
const FacNodes& facNodes = vn->neighbors();
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
if (facNodes[i]->factor().contains (jointVarIds)) {
|
|
||||||
idx = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (idx == -1) {
|
|
||||||
return getJointByConditioning (jointVarIds);
|
|
||||||
} else {
|
|
||||||
Factor res (facNodes[idx]->factor());
|
|
||||||
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
Factor msg ({links[i]->getVariable()->varId()},
|
|
||||||
{links[i]->getVariable()->range()},
|
|
||||||
getVar2FactorMsg (links[i]));
|
|
||||||
res.multiply (msg);
|
|
||||||
}
|
|
||||||
res.sumOutAllExcept (jointVarIds);
|
|
||||||
res.reorderArguments (jointVarIds);
|
|
||||||
res.normalize();
|
|
||||||
Params jointDist = res.params();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::fromLog (jointDist);
|
|
||||||
}
|
|
||||||
return jointDist;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpSolver::runSolver (void)
|
|
||||||
{
|
|
||||||
clock_t start;
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
start = clock();
|
|
||||||
}
|
|
||||||
initializeSolver();
|
|
||||||
nIters_ = 0;
|
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
|
||||||
nIters_ ++;
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
|
||||||
// cout << endl;
|
|
||||||
}
|
|
||||||
switch (BpOptions::schedule) {
|
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
|
||||||
random_shuffle (links_.begin(), links_.end());
|
|
||||||
// no break
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateAndUpdateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
updateMessage(links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
|
||||||
maxResidualSchedule();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
|
||||||
cout << "Sum-Product converged in " ;
|
|
||||||
cout << nIters_ << " iterations" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
unsigned size = fg_->varNodes().size();
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool loopy = fg_->isTree() == false;
|
|
||||||
if (loopy) nIters = nIters_;
|
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
|
||||||
}
|
|
||||||
runned_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpSolver::createLinks (void)
|
|
||||||
{
|
|
||||||
const FacNodes& facNodes = fg_->facNodes();
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
const VarNodes& neighbors = facNodes[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
|
||||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpSolver::maxResidualSchedule (void)
|
|
||||||
{
|
|
||||||
if (nIters_ == 1) {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned c = 0; c < links_.size(); c++) {
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << "current residuals:" << endl;
|
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
it != sortedOrder_.end(); it ++) {
|
|
||||||
cout << " " << setw (30) << left << (*it)->toString();
|
|
||||||
cout << "residual = " << (*it)->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
SpLink* link = *it;
|
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
updateMessage (link);
|
|
||||||
link->clearResidual();
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
|
||||||
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
|
||||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
|
||||||
if (factorNeighbors[i] != link->getFactor()) {
|
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
|
||||||
for (unsigned j = 0; j < links.size(); j++) {
|
|
||||||
if (links[j]->getVariable() != link->getVariable()) {
|
|
||||||
calculateMessage (links[j]);
|
|
||||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (links[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printDashedLine();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
|
||||||
{
|
|
||||||
FacNode* src = link->getFactor();
|
|
||||||
const VarNode* dst = link->getVariable();
|
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
|
||||||
// calculate the product of messages that were sent
|
|
||||||
// to factor `src', except from var `dst'
|
|
||||||
unsigned msgSize = 1;
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
msgSize *= links[i]->getVariable()->range();
|
|
||||||
}
|
|
||||||
unsigned repetitions = 1;
|
|
||||||
Params msgProduct (msgSize, LogAware::multIdenty());
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
|
||||||
if (links[i]->getVariable() != dst) {
|
|
||||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
|
||||||
repetitions *= links[i]->getVariable()->range();
|
|
||||||
} else {
|
|
||||||
unsigned ds = links[i]->getVariable()->range();
|
|
||||||
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
|
||||||
repetitions *= ds;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
|
||||||
if (links[i]->getVariable() != dst) {
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
|
||||||
cout << ": " << endl;
|
|
||||||
}
|
|
||||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
|
||||||
repetitions *= links[i]->getVariable()->range();
|
|
||||||
} else {
|
|
||||||
unsigned ds = links[i]->getVariable()->range();
|
|
||||||
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
|
||||||
repetitions *= ds;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Factor result (src->factor().arguments(),
|
|
||||||
src->factor().ranges(), msgProduct);
|
|
||||||
result.multiply (src->factor());
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " message product: " << msgProduct << endl;
|
|
||||||
cout << " original factor: " << src->factor().params() << endl;
|
|
||||||
cout << " factor product: " << result.params() << endl;
|
|
||||||
}
|
|
||||||
result.sumOutAllExcept (dst->varId());
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " marginalized: " ;
|
|
||||||
cout << result.params() << endl;
|
|
||||||
}
|
|
||||||
const Params& resultParams = result.params();
|
|
||||||
Params& message = link->getNextMessage();
|
|
||||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
|
||||||
message[i] = resultParams[i];
|
|
||||||
}
|
|
||||||
LogAware::normalize (message);
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " curr msg: " << link->getMessage() << endl;
|
|
||||||
cout << " next msg: " << message << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BpSolver::getVar2FactorMsg (const SpLink* link) const
|
|
||||||
{
|
|
||||||
const VarNode* src = link->getVariable();
|
|
||||||
const FacNode* dst = link->getFactor();
|
|
||||||
Params msg;
|
|
||||||
if (src->hasEvidence()) {
|
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
|
||||||
msg[src->getEvidence()] = LogAware::withEvidence();
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << msg;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg.resize (src->range(), LogAware::one());
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << msg;
|
|
||||||
}
|
|
||||||
const SpLinkSet& links = ninf (src)->getLinks();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
if (links[i]->getFactor() != dst) {
|
|
||||||
Util::add (msg, links[i]->getMessage());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
if (links[i]->getFactor() != dst) {
|
|
||||||
Util::multiply (msg, links[i]->getMessage());
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " x " << links[i]->getMessage();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " = " << msg;
|
|
||||||
}
|
|
||||||
return msg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|
||||||
{
|
|
||||||
VarNodes jointVars;
|
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
assert (fg_->getVarNode (jointVarIds[i]));
|
|
||||||
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
FactorGraph* fg = new FactorGraph (*fg_);
|
|
||||||
BpSolver solver (*fg);
|
|
||||||
solver.runSolver();
|
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
|
||||||
|
|
||||||
VarIds observedVids = {jointVars[0]->varId()};
|
|
||||||
|
|
||||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
|
||||||
Params newBeliefs;
|
|
||||||
Vars observedVars;
|
|
||||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
|
||||||
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
|
||||||
}
|
|
||||||
StatesIndexer idx (observedVars, false);
|
|
||||||
while (idx.valid()) {
|
|
||||||
for (unsigned j = 0; j < observedVars.size(); j++) {
|
|
||||||
observedVars[j]->setEvidence (idx[j]);
|
|
||||||
}
|
|
||||||
++ idx;
|
|
||||||
BpSolver solver (*fg);
|
|
||||||
solver.runSolver();
|
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
|
||||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
|
||||||
newBeliefs.push_back (beliefs[k]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int count = -1;
|
|
||||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
|
||||||
if (j % jointVars[i]->range() == 0) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
newBeliefs[j] *= prevBeliefs[count];
|
|
||||||
}
|
|
||||||
prevBeliefs = newBeliefs;
|
|
||||||
observedVids.push_back (jointVars[i]->varId());
|
|
||||||
}
|
|
||||||
return prevBeliefs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpSolver::initializeSolver (void)
|
|
||||||
{
|
|
||||||
const VarNodes& varNodes = fg_->varNodes();
|
|
||||||
varsI_.reserve (varNodes.size());
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
varsI_.push_back (new SPNodeInfo());
|
|
||||||
}
|
|
||||||
const FacNodes& facNodes = fg_->facNodes();
|
|
||||||
facsI_.reserve (facNodes.size());
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
facsI_.push_back (new SPNodeInfo());
|
|
||||||
}
|
|
||||||
createLinks();
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
FacNode* src = links_[i]->getFactor();
|
|
||||||
VarNode* dst = links_[i]->getVariable();
|
|
||||||
ninf (dst)->addSpLink (links_[i]);
|
|
||||||
ninf (src)->addSpLink (links_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BpSolver::converged (void)
|
|
||||||
{
|
|
||||||
if (links_.size() == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (nIters_ <= 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool converged = true;
|
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
|
||||||
if (maxResidual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
} else {
|
|
||||||
converged = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
double residual = links_[i]->getResidual();
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
|
||||||
}
|
|
||||||
if (residual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
if (Constants::DEBUG == 0) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return converged;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BpSolver::printLinkInformation (void) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
SpLink* l = links_[i];
|
|
||||||
cout << l->toString() << ":" << endl;
|
|
||||||
cout << " curr msg = " ;
|
|
||||||
cout << l->getMessage() << endl;
|
|
||||||
cout << " next msg = " ;
|
|
||||||
cout << l->getNextMessage() << endl;
|
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,339 +0,0 @@
|
|||||||
|
|
||||||
#include "CFactorGraph.h"
|
|
||||||
#include "Factor.h"
|
|
||||||
|
|
||||||
|
|
||||||
bool CFactorGraph::checkForIdenticalFactors = true;
|
|
||||||
|
|
||||||
CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
|
||||||
{
|
|
||||||
groundFg_ = &fg;
|
|
||||||
freeColor_ = 0;
|
|
||||||
|
|
||||||
const VarNodes& varNodes = fg.varNodes();
|
|
||||||
varSignatures_.reserve (varNodes.size());
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
|
||||||
varSignatures_.push_back (Signature (c));
|
|
||||||
}
|
|
||||||
|
|
||||||
const FacNodes& facNodes = fg.facNodes();
|
|
||||||
facSignatures_.reserve (facNodes.size());
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
|
||||||
facSignatures_.push_back (Signature (c));
|
|
||||||
}
|
|
||||||
|
|
||||||
varColors_.resize (varNodes.size());
|
|
||||||
facColors_.resize (facNodes.size());
|
|
||||||
setInitialColors();
|
|
||||||
createGroups();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CFactorGraph::~CFactorGraph (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
|
||||||
delete varClusters_[i];
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
|
||||||
delete facClusters_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CFactorGraph::setInitialColors (void)
|
|
||||||
{
|
|
||||||
// create the initial variable colors
|
|
||||||
VarColorMap colorMap;
|
|
||||||
const VarNodes& varNodes = groundFg_->varNodes();
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
unsigned dsize = varNodes[i]->range();
|
|
||||||
VarColorMap::iterator it = colorMap.find (dsize);
|
|
||||||
if (it == colorMap.end()) {
|
|
||||||
it = colorMap.insert (make_pair (
|
|
||||||
dsize, vector<Color> (dsize+1,-1))).first;
|
|
||||||
}
|
|
||||||
unsigned idx;
|
|
||||||
if (varNodes[i]->hasEvidence()) {
|
|
||||||
idx = varNodes[i]->getEvidence();
|
|
||||||
} else {
|
|
||||||
idx = dsize;
|
|
||||||
}
|
|
||||||
vector<Color>& stateColors = it->second;
|
|
||||||
if (stateColors[idx] == -1) {
|
|
||||||
stateColors[idx] = getFreeColor();
|
|
||||||
}
|
|
||||||
setColor (varNodes[i], stateColors[idx]);
|
|
||||||
}
|
|
||||||
|
|
||||||
const FacNodes& facNodes = groundFg_->facNodes();
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
|
||||||
}
|
|
||||||
// FIXME FIXME FIXME : pfl should give correct dist ids.
|
|
||||||
if (checkForIdenticalFactors || true) {
|
|
||||||
unsigned groupCount = 1;
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
Factor& f1 = facNodes[i]->factor();
|
|
||||||
if (f1.distId() != Util::maxUnsigned()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
f1.setDistId (groupCount);
|
|
||||||
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
|
||||||
Factor& f2 = facNodes[j]->factor();
|
|
||||||
if (f2.distId() != Util::maxUnsigned()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (f1.size() == f2.size() &&
|
|
||||||
f1.ranges() == f2.ranges() &&
|
|
||||||
f1.params() == f2.params()) {
|
|
||||||
f2.setDistId (groupCount);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
groupCount ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// create the initial factor colors
|
|
||||||
DistColorMap distColors;
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
unsigned distId = facNodes[i]->factor().distId();
|
|
||||||
DistColorMap::iterator it = distColors.find (distId);
|
|
||||||
if (it == distColors.end()) {
|
|
||||||
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
|
||||||
}
|
|
||||||
setColor (facNodes[i], it->second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CFactorGraph::createGroups (void)
|
|
||||||
{
|
|
||||||
VarSignMap varGroups;
|
|
||||||
FacSignMap facGroups;
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool groupsHaveChanged = true;
|
|
||||||
const VarNodes& varNodes = groundFg_->varNodes();
|
|
||||||
const FacNodes& facNodes = groundFg_->facNodes();
|
|
||||||
|
|
||||||
while (groupsHaveChanged || nIters == 1) {
|
|
||||||
nIters ++;
|
|
||||||
|
|
||||||
unsigned prevFactorGroupsSize = facGroups.size();
|
|
||||||
facGroups.clear();
|
|
||||||
// set a new color to the factors with the same signature
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
const Signature& signature = getSignature (facNodes[i]);
|
|
||||||
FacSignMap::iterator it = facGroups.find (signature);
|
|
||||||
if (it == facGroups.end()) {
|
|
||||||
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
|
||||||
}
|
|
||||||
it->second.push_back (facNodes[i]);
|
|
||||||
}
|
|
||||||
for (FacSignMap::iterator it = facGroups.begin();
|
|
||||||
it != facGroups.end(); it++) {
|
|
||||||
Color newColor = getFreeColor();
|
|
||||||
FacNodes& groupMembers = it->second;
|
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
|
||||||
setColor (groupMembers[i], newColor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// set a new color to the variables with the same signature
|
|
||||||
unsigned prevVarGroupsSize = varGroups.size();
|
|
||||||
varGroups.clear();
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
const Signature& signature = getSignature (varNodes[i]);
|
|
||||||
VarSignMap::iterator it = varGroups.find (signature);
|
|
||||||
if (it == varGroups.end()) {
|
|
||||||
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
|
||||||
}
|
|
||||||
it->second.push_back (varNodes[i]);
|
|
||||||
}
|
|
||||||
for (VarSignMap::iterator it = varGroups.begin();
|
|
||||||
it != varGroups.end(); it++) {
|
|
||||||
Color newColor = getFreeColor();
|
|
||||||
VarNodes& groupMembers = it->second;
|
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
|
||||||
setColor (groupMembers[i], newColor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
|
||||||
|| prevFactorGroupsSize != facGroups.size();
|
|
||||||
}
|
|
||||||
printGroups (varGroups, facGroups);
|
|
||||||
createClusters (varGroups, facGroups);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CFactorGraph::createClusters (
|
|
||||||
const VarSignMap& varGroups,
|
|
||||||
const FacSignMap& facGroups)
|
|
||||||
{
|
|
||||||
varClusters_.reserve (varGroups.size());
|
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
|
||||||
it != varGroups.end(); it++) {
|
|
||||||
const VarNodes& groupVars = it->second;
|
|
||||||
VarCluster* vc = new VarCluster (groupVars);
|
|
||||||
for (unsigned i = 0; i < groupVars.size(); i++) {
|
|
||||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
|
||||||
}
|
|
||||||
varClusters_.push_back (vc);
|
|
||||||
}
|
|
||||||
|
|
||||||
facClusters_.reserve (facGroups.size());
|
|
||||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
|
||||||
it != facGroups.end(); it++) {
|
|
||||||
FacNode* groupFactor = it->second[0];
|
|
||||||
const VarNodes& neighs = groupFactor->neighbors();
|
|
||||||
VarClusters varClusters;
|
|
||||||
varClusters.reserve (neighs.size());
|
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
|
||||||
VarId vid = neighs[i]->varId();
|
|
||||||
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
|
||||||
}
|
|
||||||
facClusters_.push_back (new FacCluster (it->second, varClusters));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const Signature&
|
|
||||||
CFactorGraph::getSignature (const VarNode* varNode)
|
|
||||||
{
|
|
||||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
|
||||||
vector<Color>::iterator it = sign.colors.begin();
|
|
||||||
const FacNodes& neighs = varNode->neighbors();
|
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
|
||||||
*it = getColor (neighs[i]);
|
|
||||||
it ++;
|
|
||||||
*it = neighs[i]->factor().indexOf (varNode->varId());
|
|
||||||
it ++;
|
|
||||||
}
|
|
||||||
*it = getColor (varNode);
|
|
||||||
return sign;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const Signature&
|
|
||||||
CFactorGraph::getSignature (const FacNode* facNode)
|
|
||||||
{
|
|
||||||
Signature& sign = facSignatures_[facNode->getIndex()];
|
|
||||||
vector<Color>::iterator it = sign.colors.begin();
|
|
||||||
const VarNodes& neighs = facNode->neighbors();
|
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
|
||||||
*it = getColor (neighs[i]);
|
|
||||||
it ++;
|
|
||||||
}
|
|
||||||
*it = getColor (facNode);
|
|
||||||
return sign;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
|
||||||
CFactorGraph::getGroundFactorGraph (void) const
|
|
||||||
{
|
|
||||||
FactorGraph* fg = new FactorGraph();
|
|
||||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
|
||||||
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
|
||||||
VarNode* newVar = new VarNode (var);
|
|
||||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
|
||||||
fg->addVarNode (newVar);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
|
||||||
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
|
|
||||||
Vars myGroundVars;
|
|
||||||
myGroundVars.reserve (myVarClusters.size());
|
|
||||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
|
||||||
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
|
||||||
myGroundVars.push_back (v);
|
|
||||||
}
|
|
||||||
FacNode* fn = new FacNode (Factor (myGroundVars,
|
|
||||||
facClusters_[i]->getGroundFactors()[0]->factor().params()));
|
|
||||||
facClusters_[i]->setRepresentativeFactor (fn);
|
|
||||||
fg->addFacNode (fn);
|
|
||||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
|
||||||
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
CFactorGraph::getEdgeCount (
|
|
||||||
const FacCluster* fc,
|
|
||||||
const VarCluster* vc) const
|
|
||||||
{
|
|
||||||
unsigned count = 0;
|
|
||||||
VarId vid = vc->getGroundVarNodes().front()->varId();
|
|
||||||
const FacNodes& clusterGroundFactors = fc->getGroundFactors();
|
|
||||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
|
||||||
if (clusterGroundFactors[i]->factor().contains (vid)) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// CVarNodes vars = vc->getGroundVarNodes();
|
|
||||||
// for (unsigned i = 1; i < vars.size(); i++) {
|
|
||||||
// VarNode* var = vc->getGroundVarNodes()[i];
|
|
||||||
// unsigned count2 = 0;
|
|
||||||
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
|
||||||
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
|
||||||
// count2 ++;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// if (count != count2) { cout << "oops!" << endl; abort(); }
|
|
||||||
// }
|
|
||||||
return count;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CFactorGraph::printGroups (
|
|
||||||
const VarSignMap& varGroups,
|
|
||||||
const FacSignMap& facGroups) const
|
|
||||||
{
|
|
||||||
unsigned count = 1;
|
|
||||||
cout << "variable groups:" << endl;
|
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
|
||||||
it != varGroups.end(); it++) {
|
|
||||||
const VarNodes& groupMembers = it->second;
|
|
||||||
if (groupMembers.size() > 0) {
|
|
||||||
cout << count << ": " ;
|
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
|
||||||
cout << groupMembers[i]->label() << " " ;
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
count = 1;
|
|
||||||
cout << endl << "factor groups:" << endl;
|
|
||||||
for (FacSignMap::const_iterator it = facGroups.begin();
|
|
||||||
it != facGroups.end(); it++) {
|
|
||||||
const FacNodes& groupMembers = it->second;
|
|
||||||
if (groupMembers.size() > 0) {
|
|
||||||
cout << ++count << ": " ;
|
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
|
||||||
cout << groupMembers[i]->getLabel() << " " ;
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,247 +0,0 @@
|
|||||||
#ifndef HORUS_CFACTORGRAPH_H
|
|
||||||
#define HORUS_CFACTORGRAPH_H
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
|
||||||
#include "Factor.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
class VarCluster;
|
|
||||||
class FacCluster;
|
|
||||||
class Distribution;
|
|
||||||
class Signature;
|
|
||||||
|
|
||||||
class SignatureHash;
|
|
||||||
|
|
||||||
|
|
||||||
typedef long Color;
|
|
||||||
|
|
||||||
typedef unordered_map<unsigned, vector<Color>> VarColorMap;
|
|
||||||
|
|
||||||
typedef unordered_map<unsigned, Color> DistColorMap;
|
|
||||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
|
||||||
|
|
||||||
typedef vector<VarCluster*> VarClusters;
|
|
||||||
typedef vector<FacCluster*> FacClusters;
|
|
||||||
|
|
||||||
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
|
||||||
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
struct Signature
|
|
||||||
{
|
|
||||||
Signature (unsigned size) : colors(size) { }
|
|
||||||
|
|
||||||
bool operator< (const Signature& sig) const
|
|
||||||
{
|
|
||||||
if (colors.size() < sig.colors.size()) {
|
|
||||||
return true;
|
|
||||||
} else if (colors.size() > sig.colors.size()) {
|
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < colors.size(); i++) {
|
|
||||||
if (colors[i] < sig.colors[i]) {
|
|
||||||
return true;
|
|
||||||
} else if (colors[i] > sig.colors[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator== (const Signature& sig) const
|
|
||||||
{
|
|
||||||
if (colors.size() != sig.colors.size()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < colors.size(); i++) {
|
|
||||||
if (colors[i] != sig.colors[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<Color> colors;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
struct SignatureHash
|
|
||||||
{
|
|
||||||
size_t operator() (const Signature &sig) const
|
|
||||||
{
|
|
||||||
size_t val = hash<size_t>()(sig.colors.size());
|
|
||||||
for (unsigned i = 0; i < sig.colors.size(); i++) {
|
|
||||||
val ^= hash<size_t>()(sig.colors[i]);
|
|
||||||
}
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VarCluster
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
VarCluster (const VarNodes& vs)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < vs.size(); i++) {
|
|
||||||
groundVars_.push_back (vs[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void addFacCluster (FacCluster* fc)
|
|
||||||
{
|
|
||||||
facClusters_.push_back (fc);
|
|
||||||
}
|
|
||||||
|
|
||||||
const FacClusters& getFacClusters (void) const
|
|
||||||
{
|
|
||||||
return facClusters_;
|
|
||||||
}
|
|
||||||
|
|
||||||
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
|
||||||
|
|
||||||
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
|
||||||
|
|
||||||
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
VarNodes groundVars_;
|
|
||||||
FacClusters facClusters_;
|
|
||||||
VarNode* representVar_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class FacCluster
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
|
|
||||||
{
|
|
||||||
groundFactors_ = groundFactors;
|
|
||||||
varClusters_ = vcs;
|
|
||||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
|
||||||
varClusters_[i]->addFacCluster (this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const VarClusters& getVarClusters (void) const
|
|
||||||
{
|
|
||||||
return varClusters_;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool containsGround (const FacNode* fn)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
|
||||||
if (groundFactors_[i] == fn) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
FacNode* getRepresentativeFactor (void) const
|
|
||||||
{
|
|
||||||
return representFactor_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void setRepresentativeFactor (FacNode* fn)
|
|
||||||
{
|
|
||||||
representFactor_ = fn;
|
|
||||||
}
|
|
||||||
|
|
||||||
const FacNodes& getGroundFactors (void) const
|
|
||||||
{
|
|
||||||
return groundFactors_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
|
||||||
FacNodes groundFactors_;
|
|
||||||
VarClusters varClusters_;
|
|
||||||
FacNode* representFactor_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class CFactorGraph
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
CFactorGraph (const FactorGraph&);
|
|
||||||
|
|
||||||
~CFactorGraph (void);
|
|
||||||
|
|
||||||
const VarClusters& getVarClusters (void) { return varClusters_; }
|
|
||||||
|
|
||||||
const FacClusters& getFacClusters (void) { return facClusters_; }
|
|
||||||
|
|
||||||
VarNode* getEquivalentVariable (VarId vid)
|
|
||||||
{
|
|
||||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
|
||||||
return vc->getRepresentativeVariable();
|
|
||||||
}
|
|
||||||
|
|
||||||
FactorGraph* getGroundFactorGraph (void) const;
|
|
||||||
|
|
||||||
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
|
|
||||||
|
|
||||||
static bool checkForIdenticalFactors;
|
|
||||||
|
|
||||||
private:
|
|
||||||
Color getFreeColor (void)
|
|
||||||
{
|
|
||||||
++ freeColor_;
|
|
||||||
return freeColor_ - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
Color getColor (const VarNode* vn) const
|
|
||||||
{
|
|
||||||
return varColors_[vn->getIndex()];
|
|
||||||
}
|
|
||||||
Color getColor (const FacNode* fn) const {
|
|
||||||
return facColors_[fn->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
void setColor (const VarNode* vn, Color c)
|
|
||||||
{
|
|
||||||
varColors_[vn->getIndex()] = c;
|
|
||||||
}
|
|
||||||
|
|
||||||
void setColor (const FacNode* fn, Color c)
|
|
||||||
{
|
|
||||||
facColors_[fn->getIndex()] = c;
|
|
||||||
}
|
|
||||||
|
|
||||||
VarCluster* getVariableCluster (VarId vid) const
|
|
||||||
{
|
|
||||||
return vid2VarCluster_.find (vid)->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
void setInitialColors (void);
|
|
||||||
|
|
||||||
void createGroups (void);
|
|
||||||
|
|
||||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
|
||||||
|
|
||||||
const Signature& getSignature (const VarNode*);
|
|
||||||
|
|
||||||
const Signature& getSignature (const FacNode*);
|
|
||||||
|
|
||||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
|
||||||
|
|
||||||
Color freeColor_;
|
|
||||||
vector<Color> varColors_;
|
|
||||||
vector<Color> facColors_;
|
|
||||||
vector<Signature> varSignatures_;
|
|
||||||
vector<Signature> facSignatures_;
|
|
||||||
VarClusters varClusters_;
|
|
||||||
FacClusters facClusters_;
|
|
||||||
VarId2VarCluster vid2VarCluster_;
|
|
||||||
const FactorGraph* groundFg_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_CFACTORGRAPH_H
|
|
||||||
|
|
@ -1,244 +0,0 @@
|
|||||||
#include "CbpSolver.h"
|
|
||||||
|
|
||||||
|
|
||||||
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
|
||||||
{
|
|
||||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
nGroundVars = fg_->varNodes().size();
|
|
||||||
nGroundFacs = fg_->facNodes().size();
|
|
||||||
const VarNodes& vars = fg_->varNodes();
|
|
||||||
nWithoutNeighs = 0;
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
const FacNodes& factors = vars[i]->neighbors();
|
|
||||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
|
||||||
nWithoutNeighs ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cfg_ = new CFactorGraph (fg);
|
|
||||||
fg_ = cfg_->getGroundFactorGraph();
|
|
||||||
if (Constants::COLLECT_STATS) {
|
|
||||||
unsigned nClusterVars = fg_->varNodes().size();
|
|
||||||
unsigned nClusterFacs = fg_->facNodes().size();
|
|
||||||
Statistics::updateCompressingStatistics (nGroundVars,
|
|
||||||
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
|
||||||
}
|
|
||||||
Util::printHeader ("Uncompressed Factor Graph");
|
|
||||||
fg.print();
|
|
||||||
Util::printHeader ("Compressed Factor Graph");
|
|
||||||
fg_->print();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CbpSolver::~CbpSolver (void)
|
|
||||||
{
|
|
||||||
delete cfg_;
|
|
||||||
delete fg_;
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
delete links_[i];
|
|
||||||
}
|
|
||||||
links_.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
CbpSolver::getPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
if (runned_ == false) {
|
|
||||||
runSolver();
|
|
||||||
}
|
|
||||||
assert (cfg_->getEquivalentVariable (vid));
|
|
||||||
VarNode* var = cfg_->getEquivalentVariable (vid);
|
|
||||||
Params probs;
|
|
||||||
if (var->hasEvidence()) {
|
|
||||||
probs.resize (var->range(), LogAware::noEvidence());
|
|
||||||
probs[var->getEvidence()] = LogAware::withEvidence();
|
|
||||||
} else {
|
|
||||||
probs.resize (var->range(), LogAware::multIdenty());
|
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
|
||||||
Util::add (probs, l->poweredMessage());
|
|
||||||
}
|
|
||||||
LogAware::normalize (probs);
|
|
||||||
Util::fromLog (probs);
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
|
||||||
Util::multiply (probs, l->poweredMessage());
|
|
||||||
}
|
|
||||||
LogAware::normalize (probs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return probs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
|
||||||
{
|
|
||||||
VarIds eqVarIds;
|
|
||||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
|
||||||
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
|
|
||||||
eqVarIds.push_back (vn->varId());
|
|
||||||
}
|
|
||||||
return BpSolver::getJointDistributionOf (eqVarIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CbpSolver::createLinks (void)
|
|
||||||
{
|
|
||||||
const FacClusters& fcs = cfg_->getFacClusters();
|
|
||||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
|
||||||
const VarClusters& vcs = fcs[i]->getVarClusters();
|
|
||||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
|
||||||
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
|
|
||||||
links_.push_back (new CbpSolverLink (
|
|
||||||
fcs[i]->getRepresentativeFactor(),
|
|
||||||
vcs[j]->getRepresentativeVariable(), c));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CbpSolver::maxResidualSchedule (void)
|
|
||||||
{
|
|
||||||
if (nIters_ == 1) {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
|
||||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
|
||||||
cout << "calculating " << links_[i]->toString() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned c = 0; c < links_.size(); c++) {
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl << "current residuals:" << endl;
|
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
it != sortedOrder_.end(); it ++) {
|
|
||||||
cout << " " << setw (30) << left << (*it)->toString();
|
|
||||||
cout << "residual = " << (*it)->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
SpLink* link = *it;
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
|
||||||
}
|
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
link->clearResidual();
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
|
||||||
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
|
||||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
|
||||||
for (unsigned j = 0; j < links.size(); j++) {
|
|
||||||
if (links[j]->getVariable() != link->getVariable()) {
|
|
||||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
|
||||||
cout << " calculating " << links[j]->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateMessage (links[j]);
|
|
||||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (links[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// in counting bp, the message that a variable X sends to
|
|
||||||
// to a factor F depends on the message that F sent to the X
|
|
||||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
if (links[i]->getVariable() != link->getVariable()) {
|
|
||||||
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
|
||||||
cout << " calculating " << links[i]->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateMessage (links[i]);
|
|
||||||
SpLinkMap::iterator iter = linkMap_.find (links[i]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (links[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
|
||||||
{
|
|
||||||
Params msg;
|
|
||||||
const VarNode* src = link->getVariable();
|
|
||||||
const FacNode* dst = link->getFactor();
|
|
||||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
|
||||||
if (src->hasEvidence()) {
|
|
||||||
msg.resize (src->range(), LogAware::noEvidence());
|
|
||||||
double value = link->getMessage()[src->getEvidence()];
|
|
||||||
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
|
|
||||||
} else {
|
|
||||||
msg = link->getMessage();
|
|
||||||
LogAware::pow (msg, l->nrEdges() - 1);
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " " << "init: " << msg << endl;
|
|
||||||
}
|
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
if (links[i]->getFactor() != dst) {
|
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
|
||||||
Util::add (msg, l->poweredMessage());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
|
||||||
if (links[i]->getFactor() != dst) {
|
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
|
||||||
Util::multiply (msg, l->poweredMessage());
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
|
||||||
cout << l->poweredMessage() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Constants::DEBUG >= 5) {
|
|
||||||
cout << " result = " << msg << endl;
|
|
||||||
}
|
|
||||||
return msg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CbpSolver::printLinkInformation (void) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
|
||||||
cout << l->toString() << ":" << endl;
|
|
||||||
cout << " curr msg = " << l->getMessage() << endl;
|
|
||||||
cout << " next msg = " << l->getNextMessage() << endl;
|
|
||||||
cout << " powered = " << l->poweredMessage() << endl;
|
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,60 +0,0 @@
|
|||||||
#ifndef HORUS_CBP_H
|
|
||||||
#define HORUS_CBP_H
|
|
||||||
|
|
||||||
#include "BpSolver.h"
|
|
||||||
#include "CFactorGraph.h"
|
|
||||||
|
|
||||||
class Factor;
|
|
||||||
|
|
||||||
class CbpSolverLink : public SpLink
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
|
|
||||||
: SpLink (fn, vn), nrEdges_(c),
|
|
||||||
pwdMsg_(vn->range(), LogAware::one()) { }
|
|
||||||
|
|
||||||
unsigned nrEdges (void) const { return nrEdges_; }
|
|
||||||
|
|
||||||
const Params& poweredMessage (void) const { return pwdMsg_; }
|
|
||||||
|
|
||||||
void updateMessage (void)
|
|
||||||
{
|
|
||||||
pwdMsg_ = *nextMsg_;
|
|
||||||
swap (currMsg_, nextMsg_);
|
|
||||||
msgSended_ = true;
|
|
||||||
LogAware::pow (pwdMsg_, nrEdges_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
unsigned nrEdges_;
|
|
||||||
Params pwdMsg_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CbpSolver : public BpSolver
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
CbpSolver (const FactorGraph& fg);
|
|
||||||
|
|
||||||
~CbpSolver (void);
|
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
|
||||||
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
private:
|
|
||||||
|
|
||||||
void createLinks (void);
|
|
||||||
|
|
||||||
void maxResidualSchedule (void);
|
|
||||||
|
|
||||||
Params getVar2FactorMsg (const SpLink*) const;
|
|
||||||
|
|
||||||
void printLinkInformation (void) const;
|
|
||||||
|
|
||||||
CFactorGraph* cfg_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_CBP_H
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
@ -1,303 +0,0 @@
|
|||||||
#include <limits>
|
|
||||||
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
#include "ElimGraph.h"
|
|
||||||
|
|
||||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
|
||||||
|
|
||||||
|
|
||||||
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < factors.size(); i++) {
|
|
||||||
const VarIds& vids = factors[i]->arguments();
|
|
||||||
for (unsigned j = 0; j < vids.size() - 1; j++) {
|
|
||||||
EgNode* n1 = getEgNode (vids[j]);
|
|
||||||
if (n1 == 0) {
|
|
||||||
n1 = new EgNode (vids[j], factors[i]->range (j));
|
|
||||||
addNode (n1);
|
|
||||||
}
|
|
||||||
for (unsigned k = j + 1; k < vids.size(); k++) {
|
|
||||||
EgNode* n2 = getEgNode (vids[k]);
|
|
||||||
if (n2 == 0) {
|
|
||||||
n2 = new EgNode (vids[k], factors[i]->range (k));
|
|
||||||
addNode (n2);
|
|
||||||
}
|
|
||||||
if (neighbors (n1, n2) == false) {
|
|
||||||
addEdge (n1, n2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (vids.size() == 1) {
|
|
||||||
if (getEgNode (vids[0]) == 0) {
|
|
||||||
addNode (new EgNode (vids[0], factors[i]->range (0)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ElimGraph::~ElimGraph (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
delete nodes_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarIds
|
|
||||||
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
|
||||||
{
|
|
||||||
VarIds elimOrder;
|
|
||||||
marked_.resize (nodes_.size(), false);
|
|
||||||
for (unsigned i = 0; i < exclude.size(); i++) {
|
|
||||||
assert (getEgNode (exclude[i]));
|
|
||||||
EgNode* node = getEgNode (exclude[i]);
|
|
||||||
marked_[*node] = true;
|
|
||||||
}
|
|
||||||
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
|
||||||
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
|
||||||
EgNode* node = getLowestCostNode();
|
|
||||||
marked_[*node] = true;
|
|
||||||
elimOrder.push_back (node->varId());
|
|
||||||
connectAllNeighbors (node);
|
|
||||||
}
|
|
||||||
return elimOrder;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::print (void) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
|
||||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
cout << " " << neighs[j]->label();
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::exportToGraphViz (
|
|
||||||
const char* fileName,
|
|
||||||
bool showNeighborless,
|
|
||||||
const VarIds& highlightVarIds) const
|
|
||||||
{
|
|
||||||
ofstream out (fileName);
|
|
||||||
if (!out.is_open()) {
|
|
||||||
cerr << "error: cannot open file to write at " ;
|
|
||||||
cerr << "Markov::exportToDotFile()" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
out << "strict graph {" << endl;
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
|
||||||
out << '"' << nodes_[i]->label() << '"' << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
|
||||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
|
||||||
if (node) {
|
|
||||||
out << '"' << node->label() << '"' ;
|
|
||||||
out << " [shape=box3d]" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
|
||||||
out << '"' << neighs[j]->label() << '"' << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
out << "}" << endl;
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarIds
|
|
||||||
ElimGraph::getEliminationOrder (
|
|
||||||
const vector<Factor*> factors,
|
|
||||||
VarIds excludedVids)
|
|
||||||
{
|
|
||||||
ElimGraph graph (factors);
|
|
||||||
// graph.print();
|
|
||||||
// graph.exportToGraphViz ("_egg.dot");
|
|
||||||
return graph.getEliminatingOrder (excludedVids);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::addNode (EgNode* n)
|
|
||||||
{
|
|
||||||
nodes_.push_back (n);
|
|
||||||
n->setIndex (nodes_.size() - 1);
|
|
||||||
varMap_.insert (make_pair (n->varId(), n));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
EgNode*
|
|
||||||
ElimGraph::getEgNode (VarId vid) const
|
|
||||||
{
|
|
||||||
unordered_map<VarId, EgNode*>::const_iterator it;
|
|
||||||
it = varMap_.find (vid);
|
|
||||||
return (it != varMap_.end()) ? it->second : 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
EgNode*
|
|
||||||
ElimGraph::getLowestCostNode (void) const
|
|
||||||
{
|
|
||||||
EgNode* bestNode = 0;
|
|
||||||
unsigned minCost = std::numeric_limits<unsigned>::max();
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (marked_[i]) continue;
|
|
||||||
unsigned cost = 0;
|
|
||||||
switch (elimHeuristic_) {
|
|
||||||
case MIN_NEIGHBORS:
|
|
||||||
cost = getNeighborsCost (nodes_[i]);
|
|
||||||
break;
|
|
||||||
case MIN_WEIGHT:
|
|
||||||
cost = getWeightCost (nodes_[i]);
|
|
||||||
break;
|
|
||||||
case MIN_FILL:
|
|
||||||
cost = getFillCost (nodes_[i]);
|
|
||||||
break;
|
|
||||||
case WEIGHTED_MIN_FILL:
|
|
||||||
cost = getWeightedFillCost (nodes_[i]);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
assert (false);
|
|
||||||
}
|
|
||||||
if (cost < minCost) {
|
|
||||||
bestNode = nodes_[i];
|
|
||||||
minCost = cost;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (bestNode);
|
|
||||||
return bestNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
ElimGraph::getNeighborsCost (const EgNode* n) const
|
|
||||||
{
|
|
||||||
unsigned cost = 0;
|
|
||||||
const vector<EgNode*>& neighs = n->neighbors();
|
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
|
||||||
if (marked_[*neighs[i]] == false) {
|
|
||||||
cost ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cost;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
ElimGraph::getWeightCost (const EgNode* n) const
|
|
||||||
{
|
|
||||||
unsigned cost = 1;
|
|
||||||
const vector<EgNode*>& neighs = n->neighbors();
|
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
|
||||||
if (marked_[*neighs[i]] == false) {
|
|
||||||
cost *= neighs[i]->range();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cost;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
ElimGraph::getFillCost (const EgNode* n) const
|
|
||||||
{
|
|
||||||
unsigned cost = 0;
|
|
||||||
const vector<EgNode*>& neighs = n->neighbors();
|
|
||||||
if (neighs.size() > 0) {
|
|
||||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
|
||||||
if (marked_[*neighs[i]] == true) continue;
|
|
||||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
|
||||||
if (marked_[*neighs[j]] == true) continue;
|
|
||||||
if (!neighbors (neighs[i], neighs[j])) {
|
|
||||||
cost ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cost;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
ElimGraph::getWeightedFillCost (const EgNode* n) const
|
|
||||||
{
|
|
||||||
unsigned cost = 0;
|
|
||||||
const vector<EgNode*>& neighs = n->neighbors();
|
|
||||||
if (neighs.size() > 0) {
|
|
||||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
|
||||||
if (marked_[*neighs[i]] == true) continue;
|
|
||||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
|
||||||
if (marked_[*neighs[j]] == true) continue;
|
|
||||||
if (!neighbors (neighs[i], neighs[j])) {
|
|
||||||
cost += neighs[i]->range() * neighs[j]->range();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cost;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::connectAllNeighbors (const EgNode* n)
|
|
||||||
{
|
|
||||||
const vector<EgNode*>& neighs = n->neighbors();
|
|
||||||
if (neighs.size() > 0) {
|
|
||||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
|
||||||
if (marked_[*neighs[i]] == true) continue;
|
|
||||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
|
||||||
if (marked_[*neighs[j]] == true) continue;
|
|
||||||
if (!neighbors (neighs[i], neighs[j])) {
|
|
||||||
addEdge (neighs[i], neighs[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
|
|
||||||
{
|
|
||||||
const vector<EgNode*>& neighs = n1->neighbors();
|
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
|
||||||
if (neighs[i] == n2) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
@ -1,88 +0,0 @@
|
|||||||
#ifndef HORUS_ELIMGRAPH_H
|
|
||||||
#define HORUS_ELIMGRAPH_H
|
|
||||||
|
|
||||||
#include "unordered_map"
|
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
enum ElimHeuristic
|
|
||||||
{
|
|
||||||
MIN_NEIGHBORS,
|
|
||||||
MIN_WEIGHT,
|
|
||||||
MIN_FILL,
|
|
||||||
WEIGHTED_MIN_FILL
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class EgNode : public Var
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
|
||||||
|
|
||||||
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
|
||||||
|
|
||||||
const vector<EgNode*>& neighbors (void) const { return neighs_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
vector<EgNode*> neighs_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class ElimGraph
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
ElimGraph (const vector<Factor*>&); // TODO
|
|
||||||
|
|
||||||
~ElimGraph (void);
|
|
||||||
|
|
||||||
VarIds getEliminatingOrder (const VarIds&);
|
|
||||||
|
|
||||||
void print (void) const;
|
|
||||||
|
|
||||||
void exportToGraphViz (const char*, bool = true,
|
|
||||||
const VarIds& = VarIds()) const;
|
|
||||||
|
|
||||||
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
|
|
||||||
|
|
||||||
static void setEliminationHeuristic (ElimHeuristic h)
|
|
||||||
{
|
|
||||||
elimHeuristic_ = h;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
|
|
||||||
void addEdge (EgNode* n1, EgNode* n2)
|
|
||||||
{
|
|
||||||
assert (n1 != n2);
|
|
||||||
n1->addNeighbor (n2);
|
|
||||||
n2->addNeighbor (n1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void addNode (EgNode*);
|
|
||||||
|
|
||||||
EgNode* getEgNode (VarId) const;
|
|
||||||
EgNode* getLowestCostNode (void) const;
|
|
||||||
|
|
||||||
unsigned getNeighborsCost (const EgNode*) const;
|
|
||||||
|
|
||||||
unsigned getWeightCost (const EgNode*) const;
|
|
||||||
|
|
||||||
unsigned getFillCost (const EgNode*) const;
|
|
||||||
|
|
||||||
unsigned getWeightedFillCost (const EgNode*) const;
|
|
||||||
|
|
||||||
void connectAllNeighbors (const EgNode*);
|
|
||||||
|
|
||||||
bool neighbors (const EgNode*, const EgNode*) const;
|
|
||||||
|
|
||||||
vector<EgNode*> nodes_;
|
|
||||||
vector<bool> marked_;
|
|
||||||
unordered_map<VarId, EgNode*> varMap_;
|
|
||||||
static ElimHeuristic elimHeuristic_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_ELIMGRAPH_H
|
|
||||||
|
|
@ -1,265 +0,0 @@
|
|||||||
#include <cstdlib>
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "Factor.h"
|
|
||||||
#include "Indexer.h"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const Factor& g)
|
|
||||||
{
|
|
||||||
copyFromFactor (g);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (
|
|
||||||
const VarIds& vids,
|
|
||||||
const Ranges& ranges,
|
|
||||||
const Params& params,
|
|
||||||
unsigned distId)
|
|
||||||
{
|
|
||||||
args_ = vids;
|
|
||||||
ranges_ = ranges;
|
|
||||||
params_ = params;
|
|
||||||
distId_ = distId;
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (
|
|
||||||
const Vars& vars,
|
|
||||||
const Params& params,
|
|
||||||
unsigned distId)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
args_.push_back (vars[i]->varId());
|
|
||||||
ranges_.push_back (vars[i]->range());
|
|
||||||
}
|
|
||||||
params_ = params;
|
|
||||||
distId_ = distId;
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::sumOutAllExcept (VarId vid)
|
|
||||||
{
|
|
||||||
assert (indexOf (vid) != -1);
|
|
||||||
while (args_.back() != vid) {
|
|
||||||
sumOutLastVariable();
|
|
||||||
}
|
|
||||||
while (args_.front() != vid) {
|
|
||||||
sumOutFirstVariable();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::sumOutAllExcept (const VarIds& vids)
|
|
||||||
{
|
|
||||||
for (int i = 0; i < (int)args_.size(); i++) {
|
|
||||||
if (Util::contains (vids, args_[i]) == false) {
|
|
||||||
sumOut (args_[i]);
|
|
||||||
i --;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::sumOut (VarId vid)
|
|
||||||
{
|
|
||||||
int idx = indexOf (vid);
|
|
||||||
assert (idx != -1);
|
|
||||||
|
|
||||||
if (vid == args_.back()) {
|
|
||||||
sumOutLastVariable(); // optimization
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (vid == args_.front()) {
|
|
||||||
sumOutFirstVariable(); // optimization
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// number of parameters separating a different state of `var',
|
|
||||||
// with the states of the remaining variables fixed
|
|
||||||
unsigned varOffset = 1;
|
|
||||||
|
|
||||||
// number of parameters separating a different state of the variable
|
|
||||||
// on the left of `var', with the states of the remaining vars fixed
|
|
||||||
unsigned leftVarOffset = 1;
|
|
||||||
|
|
||||||
for (int i = args_.size() - 1; i > idx; i--) {
|
|
||||||
varOffset *= ranges_[i];
|
|
||||||
leftVarOffset *= ranges_[i];
|
|
||||||
}
|
|
||||||
leftVarOffset *= ranges_[idx];
|
|
||||||
|
|
||||||
unsigned offset = 0;
|
|
||||||
unsigned count1 = 0;
|
|
||||||
unsigned count2 = 0;
|
|
||||||
unsigned newpsSize = params_.size() / ranges_[idx];
|
|
||||||
|
|
||||||
Params newps;
|
|
||||||
newps.reserve (newpsSize);
|
|
||||||
|
|
||||||
while (newps.size() < newpsSize) {
|
|
||||||
double sum = LogAware::addIdenty();
|
|
||||||
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
sum = Util::logSum (sum, params_[offset]);
|
|
||||||
} else {
|
|
||||||
sum += params_[offset];
|
|
||||||
}
|
|
||||||
offset += varOffset;
|
|
||||||
}
|
|
||||||
newps.push_back (sum);
|
|
||||||
count1 ++;
|
|
||||||
if (idx == (int)args_.size() - 1) {
|
|
||||||
offset = count1 * ranges_[idx];
|
|
||||||
} else {
|
|
||||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
|
||||||
count1 = 0;
|
|
||||||
count2 ++;
|
|
||||||
}
|
|
||||||
offset = (leftVarOffset * count2) + count1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
args_.erase (args_.begin() + idx);
|
|
||||||
ranges_.erase (ranges_.begin() + idx);
|
|
||||||
params_ = newps;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::sumOutFirstVariable (void)
|
|
||||||
{
|
|
||||||
unsigned range = ranges_.front();
|
|
||||||
unsigned sep = params_.size() / range;
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = sep; i < params_.size(); i++) {
|
|
||||||
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = sep; i < params_.size(); i++) {
|
|
||||||
params_[i % sep] += params_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
params_.resize (sep);
|
|
||||||
args_.erase (args_.begin());
|
|
||||||
ranges_.erase (ranges_.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::sumOutLastVariable (void)
|
|
||||||
{
|
|
||||||
unsigned range = ranges_.back();
|
|
||||||
unsigned idx1 = 0;
|
|
||||||
unsigned idx2 = 0;
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
while (idx1 < params_.size()) {
|
|
||||||
params_[idx2] = params_[idx1];
|
|
||||||
idx1 ++;
|
|
||||||
for (unsigned j = 1; j < range; j++) {
|
|
||||||
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
|
||||||
idx1 ++;
|
|
||||||
}
|
|
||||||
idx2 ++;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
while (idx1 < params_.size()) {
|
|
||||||
params_[idx2] = params_[idx1];
|
|
||||||
idx1 ++;
|
|
||||||
for (unsigned j = 1; j < range; j++) {
|
|
||||||
params_[idx2] += params_[idx1];
|
|
||||||
idx1 ++;
|
|
||||||
}
|
|
||||||
idx2 ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
params_.resize (idx2);
|
|
||||||
args_.pop_back();
|
|
||||||
ranges_.pop_back();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::multiply (Factor& g)
|
|
||||||
{
|
|
||||||
if (args_.size() == 0) {
|
|
||||||
copyFromFactor (g);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
TFactor<VarId>::multiply (g);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::reorderAccordingVarIds (void)
|
|
||||||
{
|
|
||||||
VarIds sortedVarIds = args_;
|
|
||||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
|
||||||
reorderArguments (sortedVarIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
Factor::getLabel (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << "f(" ;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (i != 0) ss << "," ;
|
|
||||||
ss << Var (args_[i], ranges_[i]).label();
|
|
||||||
}
|
|
||||||
ss << ")" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::print (void) const
|
|
||||||
{
|
|
||||||
Vars vars;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
vars.push_back (new Var (args_[i], ranges_[i]));
|
|
||||||
}
|
|
||||||
vector<string> jointStrings = Util::getStateLines (vars);
|
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
|
||||||
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
|
|
||||||
cout << " = " << params_[i] << endl;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
delete vars[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::copyFromFactor (const Factor& g)
|
|
||||||
{
|
|
||||||
args_ = g.arguments();
|
|
||||||
ranges_ = g.ranges();
|
|
||||||
params_ = g.params();
|
|
||||||
distId_ = g.distId();
|
|
||||||
}
|
|
||||||
|
|
@ -1,288 +0,0 @@
|
|||||||
#ifndef HORUS_FACTOR_H
|
|
||||||
#define HORUS_FACTOR_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "Var.h"
|
|
||||||
#include "Indexer.h"
|
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class TFactor
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
const vector<T>& arguments (void) const { return args_; }
|
|
||||||
|
|
||||||
vector<T>& arguments (void) { return args_; }
|
|
||||||
|
|
||||||
const Ranges& ranges (void) const { return ranges_; }
|
|
||||||
|
|
||||||
const Params& params (void) const { return params_; }
|
|
||||||
|
|
||||||
Params& params (void) { return params_; }
|
|
||||||
|
|
||||||
unsigned nrArguments (void) const { return args_.size(); }
|
|
||||||
|
|
||||||
unsigned size (void) const { return params_.size(); }
|
|
||||||
|
|
||||||
unsigned distId (void) const { return distId_; }
|
|
||||||
|
|
||||||
void setDistId (unsigned id) { distId_ = id; }
|
|
||||||
|
|
||||||
void normalize (void) { LogAware::normalize (params_); }
|
|
||||||
|
|
||||||
void setParams (const Params& newParams)
|
|
||||||
{
|
|
||||||
params_ = newParams;
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
int indexOf (const T& t) const
|
|
||||||
{
|
|
||||||
int idx = -1;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (args_[i] == t) {
|
|
||||||
idx = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
const T& argument (unsigned idx) const
|
|
||||||
{
|
|
||||||
assert (idx < args_.size());
|
|
||||||
return args_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
T& argument (unsigned idx)
|
|
||||||
{
|
|
||||||
assert (idx < args_.size());
|
|
||||||
return args_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned range (unsigned idx) const
|
|
||||||
{
|
|
||||||
assert (idx < ranges_.size());
|
|
||||||
return ranges_[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
void multiply (TFactor<T>& g)
|
|
||||||
{
|
|
||||||
const vector<T>& g_args = g.arguments();
|
|
||||||
const Ranges& g_ranges = g.ranges();
|
|
||||||
const Params& g_params = g.params();
|
|
||||||
if (args_ == g_args) {
|
|
||||||
// optimization: if the factors contain the same set of args,
|
|
||||||
// we can do a 1 to 1 operation on the parameters
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::add (params_, g_params);
|
|
||||||
} else {
|
|
||||||
Util::multiply (params_, g_params);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bool sharedArgs = false;
|
|
||||||
vector<unsigned> gvarpos;
|
|
||||||
for (unsigned i = 0; i < g_args.size(); i++) {
|
|
||||||
int idx = indexOf (g_args[i]);
|
|
||||||
if (idx == -1) {
|
|
||||||
insertArgument (g_args[i], g_ranges[i]);
|
|
||||||
gvarpos.push_back (args_.size() - 1);
|
|
||||||
} else {
|
|
||||||
sharedArgs = true;
|
|
||||||
gvarpos.push_back (idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (sharedArgs == false) {
|
|
||||||
// optimization: if the original factors doesn't have common args,
|
|
||||||
// we don't need to marry the states of the common args
|
|
||||||
unsigned count = 0;
|
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
params_[i] += g_params[count];
|
|
||||||
} else {
|
|
||||||
params_[i] *= g_params[count];
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
if (count >= g_params.size()) {
|
|
||||||
count = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
StatesIndexer indexer (ranges_, false);
|
|
||||||
while (indexer.valid()) {
|
|
||||||
unsigned g_li = 0;
|
|
||||||
unsigned prod = 1;
|
|
||||||
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
|
||||||
g_li += indexer[gvarpos[j]] * prod;
|
|
||||||
prod *= g_ranges[j];
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
params_[indexer] += g_params[g_li];
|
|
||||||
} else {
|
|
||||||
params_[indexer] *= g_params[g_li];
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void absorveEvidence (const T& arg, unsigned evidence)
|
|
||||||
{
|
|
||||||
int idx = indexOf (arg);
|
|
||||||
assert (idx != -1);
|
|
||||||
assert (evidence < ranges_[idx]);
|
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (copy.size() / ranges_[idx]);
|
|
||||||
StatesIndexer indexer (ranges_);
|
|
||||||
for (unsigned i = 0; i < evidence; i++) {
|
|
||||||
indexer.increment (idx);
|
|
||||||
}
|
|
||||||
while (indexer.valid()) {
|
|
||||||
params_.push_back (copy[indexer]);
|
|
||||||
indexer.incrementExcluding (idx);
|
|
||||||
}
|
|
||||||
args_.erase (args_.begin() + idx);
|
|
||||||
ranges_.erase (ranges_.begin() + idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void reorderArguments (const vector<T> newArgs)
|
|
||||||
{
|
|
||||||
assert (newArgs.size() == args_.size());
|
|
||||||
if (newArgs == args_) {
|
|
||||||
return; // already in the wanted order
|
|
||||||
}
|
|
||||||
Ranges newRanges;
|
|
||||||
vector<unsigned> positions;
|
|
||||||
for (unsigned i = 0; i < newArgs.size(); i++) {
|
|
||||||
unsigned idx = indexOf (newArgs[i]);
|
|
||||||
newRanges.push_back (ranges_[idx]);
|
|
||||||
positions.push_back (idx);
|
|
||||||
}
|
|
||||||
unsigned N = ranges_.size();
|
|
||||||
Params newParams (params_.size());
|
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
|
||||||
unsigned li = i;
|
|
||||||
// calculate vector index corresponding to linear index
|
|
||||||
vector<unsigned> vi (N);
|
|
||||||
for (int k = N-1; k >= 0; k--) {
|
|
||||||
vi[k] = li % ranges_[k];
|
|
||||||
li /= ranges_[k];
|
|
||||||
}
|
|
||||||
// convert permuted vector index to corresponding linear index
|
|
||||||
unsigned prod = 1;
|
|
||||||
unsigned new_li = 0;
|
|
||||||
for (int k = N - 1; k >= 0; k--) {
|
|
||||||
new_li += vi[positions[k]] * prod;
|
|
||||||
prod *= ranges_[positions[k]];
|
|
||||||
}
|
|
||||||
newParams[new_li] = params_[i];
|
|
||||||
}
|
|
||||||
args_ = newArgs;
|
|
||||||
ranges_ = newRanges;
|
|
||||||
params_ = newParams;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool contains (const T& arg) const
|
|
||||||
{
|
|
||||||
return Util::contains (args_, arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool contains (const vector<T>& args) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (contains (args[i]) == false) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
vector<T> args_;
|
|
||||||
Ranges ranges_;
|
|
||||||
Params params_;
|
|
||||||
unsigned distId_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
void insertArgument (const T& arg, unsigned range)
|
|
||||||
{
|
|
||||||
assert (indexOf (arg) == -1);
|
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (copy.size() * range);
|
|
||||||
for (unsigned i = 0; i < copy.size(); i++) {
|
|
||||||
for (unsigned reps = 0; reps < range; reps++) {
|
|
||||||
params_.push_back (copy[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
args_.push_back (arg);
|
|
||||||
ranges_.push_back (range);
|
|
||||||
}
|
|
||||||
|
|
||||||
void insertArguments (const vector<T>& args, const Ranges& ranges)
|
|
||||||
{
|
|
||||||
Params copy = params_;
|
|
||||||
unsigned nrStates = 1;
|
|
||||||
for (unsigned i = 0; i < args.size(); i++) {
|
|
||||||
assert (indexOf (args[i]) == -1);
|
|
||||||
args_.push_back (args[i]);
|
|
||||||
ranges_.push_back (ranges[i]);
|
|
||||||
nrStates *= ranges[i];
|
|
||||||
}
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (copy.size() * nrStates);
|
|
||||||
for (unsigned i = 0; i < copy.size(); i++) {
|
|
||||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
|
||||||
params_.push_back (copy[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Factor : public TFactor<VarId>
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
Factor (void) { }
|
|
||||||
|
|
||||||
Factor (const Factor&);
|
|
||||||
|
|
||||||
Factor (const VarIds&, const Ranges&, const Params&,
|
|
||||||
unsigned = Util::maxUnsigned());
|
|
||||||
|
|
||||||
Factor (const Vars&, const Params&,
|
|
||||||
unsigned = Util::maxUnsigned());
|
|
||||||
|
|
||||||
void sumOutAllExcept (VarId);
|
|
||||||
|
|
||||||
void sumOutAllExcept (const VarIds&);
|
|
||||||
|
|
||||||
void sumOut (VarId);
|
|
||||||
|
|
||||||
void sumOutFirstVariable (void);
|
|
||||||
|
|
||||||
void sumOutLastVariable (void);
|
|
||||||
|
|
||||||
void multiply (Factor&);
|
|
||||||
|
|
||||||
void reorderAccordingVarIds (void);
|
|
||||||
|
|
||||||
string getLabel (void) const;
|
|
||||||
|
|
||||||
void print (void) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
void copyFromFactor (const Factor& f);
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_FACTOR_H
|
|
||||||
|
|
@ -1,711 +0,0 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
#include "FoveSolver.h"
|
|
||||||
#include "Histogram.h"
|
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
|
|
||||||
vector<LiftedOperator*>
|
|
||||||
LiftedOperator::getValidOps (
|
|
||||||
ParfactorList& pfList,
|
|
||||||
const Grounds& query)
|
|
||||||
{
|
|
||||||
vector<LiftedOperator*> validOps;
|
|
||||||
vector<SumOutOperator*> sumOutOps;
|
|
||||||
vector<CountingOperator*> countOps;
|
|
||||||
vector<GroundOperator*> groundOps;
|
|
||||||
|
|
||||||
sumOutOps = SumOutOperator::getValidOps (pfList, query);
|
|
||||||
countOps = CountingOperator::getValidOps (pfList);
|
|
||||||
groundOps = GroundOperator::getValidOps (pfList);
|
|
||||||
|
|
||||||
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
|
|
||||||
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
|
|
||||||
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
|
|
||||||
return validOps;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
LiftedOperator::printValidOps (
|
|
||||||
ParfactorList& pfList,
|
|
||||||
const Grounds& query)
|
|
||||||
{
|
|
||||||
vector<LiftedOperator*> validOps;
|
|
||||||
validOps = LiftedOperator::getValidOps (pfList, query);
|
|
||||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
|
||||||
cout << "-> " << validOps[i]->toString() << endl;
|
|
||||||
delete validOps[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
SumOutOperator::getCost (void)
|
|
||||||
{
|
|
||||||
TinySet<unsigned> groupSet;
|
|
||||||
ParfactorList::const_iterator pfIter = pfList_.begin();
|
|
||||||
while (pfIter != pfList_.end()) {
|
|
||||||
if ((*pfIter)->containsGroup (group_)) {
|
|
||||||
vector<unsigned> groups = (*pfIter)->getAllGroups();
|
|
||||||
groupSet |= TinySet<unsigned> (groups);
|
|
||||||
}
|
|
||||||
++ pfIter;
|
|
||||||
}
|
|
||||||
unsigned cost = 1;
|
|
||||||
for (unsigned i = 0; i < groupSet.size(); i++) {
|
|
||||||
pfIter = pfList_.begin();
|
|
||||||
while (pfIter != pfList_.end()) {
|
|
||||||
if ((*pfIter)->containsGroup (groupSet[i])) {
|
|
||||||
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
|
|
||||||
cost *= (*pfIter)->range (idx);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
++ pfIter;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cost;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
SumOutOperator::apply (void)
|
|
||||||
{
|
|
||||||
vector<ParfactorList::iterator> iters
|
|
||||||
= parfactorsWithGroup (pfList_, group_);
|
|
||||||
Parfactor* product = *(iters[0]);
|
|
||||||
pfList_.remove (iters[0]);
|
|
||||||
for (unsigned i = 1; i < iters.size(); i++) {
|
|
||||||
product->multiply (**(iters[i]));
|
|
||||||
pfList_.removeAndDelete (iters[i]);
|
|
||||||
}
|
|
||||||
if (product->nrArguments() == 1) {
|
|
||||||
delete product;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
int fIdx = product->indexOfGroup (group_);
|
|
||||||
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
|
||||||
if (product->constr()->isCountNormalized (excl)) {
|
|
||||||
product->sumOut (fIdx);
|
|
||||||
pfList_.addShattered (product);
|
|
||||||
} else {
|
|
||||||
Parfactors pfs = FoveSolver::countNormalize (product, excl);
|
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
|
||||||
pfs[i]->sumOut (fIdx);
|
|
||||||
pfList_.add (pfs[i]);
|
|
||||||
}
|
|
||||||
delete product;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<SumOutOperator*>
|
|
||||||
SumOutOperator::getValidOps (
|
|
||||||
ParfactorList& pfList,
|
|
||||||
const Grounds& query)
|
|
||||||
{
|
|
||||||
vector<SumOutOperator*> validOps;
|
|
||||||
set<unsigned> allGroups;
|
|
||||||
ParfactorList::const_iterator it = pfList.begin();
|
|
||||||
while (it != pfList.end()) {
|
|
||||||
const ProbFormulas& formulas = (*it)->arguments();
|
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
|
||||||
allGroups.insert (formulas[i].group());
|
|
||||||
}
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
set<unsigned>::const_iterator groupIt = allGroups.begin();
|
|
||||||
while (groupIt != allGroups.end()) {
|
|
||||||
if (validOp (*groupIt, pfList, query)) {
|
|
||||||
validOps.push_back (new SumOutOperator (*groupIt, pfList));
|
|
||||||
}
|
|
||||||
++ groupIt;
|
|
||||||
}
|
|
||||||
return validOps;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
SumOutOperator::toString (void)
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
vector<ParfactorList::iterator> pfIters;
|
|
||||||
pfIters = parfactorsWithGroup (pfList_, group_);
|
|
||||||
int idx = (*pfIters[0])->indexOfGroup (group_);
|
|
||||||
ProbFormula f = (*pfIters[0])->argument (idx);
|
|
||||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
|
||||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
|
||||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
SumOutOperator::validOp (
|
|
||||||
unsigned group,
|
|
||||||
ParfactorList& pfList,
|
|
||||||
const Grounds& query)
|
|
||||||
{
|
|
||||||
vector<ParfactorList::iterator> pfIters;
|
|
||||||
pfIters = parfactorsWithGroup (pfList, group);
|
|
||||||
if (isToEliminate (*pfIters[0], group, query) == false) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
unordered_map<unsigned, unsigned> groupToRange;
|
|
||||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
|
||||||
int fIdx = (*pfIters[i])->indexOfGroup (group);
|
|
||||||
if ((*pfIters[i])->argument (fIdx).contains (
|
|
||||||
(*pfIters[i])->elimLogVars()) == false) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
vector<unsigned> ranges = (*pfIters[i])->ranges();
|
|
||||||
vector<unsigned> groups = (*pfIters[i])->getAllGroups();
|
|
||||||
for (unsigned i = 0; i < groups.size(); i++) {
|
|
||||||
unordered_map<unsigned, unsigned>::iterator it;
|
|
||||||
it = groupToRange.find (groups[i]);
|
|
||||||
if (it == groupToRange.end()) {
|
|
||||||
groupToRange.insert (make_pair (groups[i], ranges[i]));
|
|
||||||
} else {
|
|
||||||
if (it->second != ranges[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<ParfactorList::iterator>
|
|
||||||
SumOutOperator::parfactorsWithGroup (
|
|
||||||
ParfactorList& pfList,
|
|
||||||
unsigned group)
|
|
||||||
{
|
|
||||||
vector<ParfactorList::iterator> iters;
|
|
||||||
ParfactorList::iterator pflIt = pfList.begin();
|
|
||||||
while (pflIt != pfList.end()) {
|
|
||||||
if ((*pflIt)->containsGroup (group)) {
|
|
||||||
iters.push_back (pflIt);
|
|
||||||
}
|
|
||||||
++ pflIt;
|
|
||||||
}
|
|
||||||
return iters;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
SumOutOperator::isToEliminate (
|
|
||||||
Parfactor* g,
|
|
||||||
unsigned group,
|
|
||||||
const Grounds& query)
|
|
||||||
{
|
|
||||||
int fIdx = g->indexOfGroup (group);
|
|
||||||
const ProbFormula& formula = g->argument (fIdx);
|
|
||||||
bool toElim = true;
|
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
|
||||||
if (formula.functor() == query[i].functor() &&
|
|
||||||
formula.arity() == query[i].arity()) {
|
|
||||||
g->constr()->moveToTop (formula.logVars());
|
|
||||||
if (g->constr()->containsTuple (query[i].args())) {
|
|
||||||
toElim = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return toElim;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
CountingOperator::getCost (void)
|
|
||||||
{
|
|
||||||
unsigned cost = 0;
|
|
||||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
|
||||||
unsigned range = (*pfIter_)->range (fIdx);
|
|
||||||
unsigned size = (*pfIter_)->size() / range;
|
|
||||||
TinySet<unsigned> counts;
|
|
||||||
counts = (*pfIter_)->constr()->getConditionalCounts (X_);
|
|
||||||
for (unsigned i = 0; i < counts.size(); i++) {
|
|
||||||
cost += size * HistogramSet::nrHistograms (counts[i], range);
|
|
||||||
}
|
|
||||||
return cost;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CountingOperator::apply (void)
|
|
||||||
{
|
|
||||||
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
|
||||||
(*pfIter_)->countConvert (X_);
|
|
||||||
} else {
|
|
||||||
Parfactor* pf = *pfIter_;
|
|
||||||
pfList_.remove (pfIter_);
|
|
||||||
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
|
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
|
||||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
|
||||||
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
|
|
||||||
pfs[i]->countedLogVars() | X_);
|
|
||||||
if (condCount > 1 && cartProduct) {
|
|
||||||
pfs[i]->countConvert (X_);
|
|
||||||
}
|
|
||||||
pfList_.add (pfs[i]);
|
|
||||||
}
|
|
||||||
delete pf;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<CountingOperator*>
|
|
||||||
CountingOperator::getValidOps (ParfactorList& pfList)
|
|
||||||
{
|
|
||||||
vector<CountingOperator*> validOps;
|
|
||||||
ParfactorList::iterator it = pfList.begin();
|
|
||||||
while (it != pfList.end()) {
|
|
||||||
LogVarSet candidates = (*it)->uncountedLogVars();
|
|
||||||
for (unsigned i = 0; i < candidates.size(); i++) {
|
|
||||||
if (validOp (*it, candidates[i])) {
|
|
||||||
validOps.push_back (new CountingOperator (
|
|
||||||
it, candidates[i], pfList));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
return validOps;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
CountingOperator::toString (void)
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << "count convert " << X_ << " in " ;
|
|
||||||
ss << (*pfIter_)->getLabel();
|
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
|
||||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
|
||||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
|
||||||
ss << " º " << pfs[i]->getLabel() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
|
||||||
delete pfs[i];
|
|
||||||
}
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
CountingOperator::validOp (Parfactor* g, LogVar X)
|
|
||||||
{
|
|
||||||
if (g->nrFormulas (X) != 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
int fIdx = g->indexOfLogVar (X);
|
|
||||||
if (g->argument (fIdx).isCounting()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool countNormalized = g->constr()->isCountNormalized (X);
|
|
||||||
if (countNormalized) {
|
|
||||||
unsigned condCount = g->constr()->getConditionalCount (X);
|
|
||||||
bool cartProduct = g->constr()->isCarteesianProduct (
|
|
||||||
g->countedLogVars() | X);
|
|
||||||
if (condCount == 1 || cartProduct == false) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
GroundOperator::getCost (void)
|
|
||||||
{
|
|
||||||
unsigned cost = 0;
|
|
||||||
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
|
||||||
if (isCountingLv) {
|
|
||||||
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
|
||||||
unsigned currSize = (*pfIter_)->size();
|
|
||||||
unsigned nrHists = (*pfIter_)->range (fIdx);
|
|
||||||
unsigned range = (*pfIter_)->argument (fIdx).range();
|
|
||||||
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
|
||||||
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
|
||||||
} else {
|
|
||||||
cost = (*pfIter_)->constr()->nrSymbols (X_) * (*pfIter_)->size();
|
|
||||||
}
|
|
||||||
return cost;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
GroundOperator::apply (void)
|
|
||||||
{
|
|
||||||
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
|
|
||||||
Parfactor* pf = *pfIter_;
|
|
||||||
pfList_.remove (pfIter_);
|
|
||||||
if (countedLv) {
|
|
||||||
pf->fullExpand (X_);
|
|
||||||
pfList_.add (pf);
|
|
||||||
} else {
|
|
||||||
ConstraintTrees cts = pf->constr()->ground (X_);
|
|
||||||
for (unsigned i = 0; i < cts.size(); i++) {
|
|
||||||
pfList_.add (new Parfactor (pf, cts[i]));
|
|
||||||
}
|
|
||||||
delete pf;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<GroundOperator*>
|
|
||||||
GroundOperator::getValidOps (ParfactorList& pfList)
|
|
||||||
{
|
|
||||||
vector<GroundOperator*> validOps;
|
|
||||||
ParfactorList::iterator pfIter = pfList.begin();
|
|
||||||
while (pfIter != pfList.end()) {
|
|
||||||
LogVarSet set = (*pfIter)->logVarSet();
|
|
||||||
for (unsigned i = 0; i < set.size(); i++) {
|
|
||||||
if ((*pfIter)->constr()->isSingleton (set[i]) == false) {
|
|
||||||
validOps.push_back (new GroundOperator (pfIter, set[i], pfList));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
++ pfIter;
|
|
||||||
}
|
|
||||||
return validOps;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
GroundOperator::toString (void)
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
((*pfIter_)->countedLogVars().contains (X_))
|
|
||||||
? ss << "full expanding "
|
|
||||||
: ss << "grounding " ;
|
|
||||||
ss << X_ << " in " << (*pfIter_)->getLabel();
|
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
FoveSolver::getPosterioriOf (const Ground& query)
|
|
||||||
{
|
|
||||||
return getJointDistributionOf ({query});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
FoveSolver::getJointDistributionOf (const Grounds& query)
|
|
||||||
{
|
|
||||||
runSolver (query);
|
|
||||||
(*pfList_.begin())->normalize();
|
|
||||||
Params params = (*pfList_.begin())->params();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::fromLog (params);
|
|
||||||
}
|
|
||||||
return params;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FoveSolver::absorveEvidence (
|
|
||||||
ParfactorList& pfList,
|
|
||||||
ObservedFormulas& obsFormulas)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
|
||||||
Parfactors newPfs;
|
|
||||||
ParfactorList::iterator it = pfList.begin();
|
|
||||||
while (it != pfList.end()) {
|
|
||||||
Parfactor* pf = *it;
|
|
||||||
it = pfList.remove (it);
|
|
||||||
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
|
||||||
if (absorvedPfs.empty() == false) {
|
|
||||||
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
|
||||||
// just remove pf;
|
|
||||||
} else {
|
|
||||||
Util::addToVector (newPfs, absorvedPfs);
|
|
||||||
}
|
|
||||||
delete pf;
|
|
||||||
} else {
|
|
||||||
it = pfList.insertShattered (it, pf);
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pfList.add (newPfs);
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
|
|
||||||
Util::printAsteriskLine();
|
|
||||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
|
||||||
cout << " -> " << obsFormulas[i] << endl;
|
|
||||||
}
|
|
||||||
Util::printAsteriskLine();
|
|
||||||
pfList.print();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactors
|
|
||||||
FoveSolver::countNormalize (
|
|
||||||
Parfactor* g,
|
|
||||||
const LogVarSet& set)
|
|
||||||
{
|
|
||||||
Parfactors normPfs;
|
|
||||||
if (set.empty()) {
|
|
||||||
normPfs.push_back (new Parfactor (*g));
|
|
||||||
} else {
|
|
||||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
|
||||||
for (unsigned i = 0; i < normCts.size(); i++) {
|
|
||||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return normPfs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FoveSolver::runSolver (const Grounds& query)
|
|
||||||
{
|
|
||||||
shatterAgainstQuery (query);
|
|
||||||
runWeakBayesBall (query);
|
|
||||||
while (true) {
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printDashedLine();
|
|
||||||
pfList_.print();
|
|
||||||
LiftedOperator::printValidOps (pfList_, query);
|
|
||||||
}
|
|
||||||
LiftedOperator* op = getBestOperation (query);
|
|
||||||
if (op == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << "best operation: " << op->toString() << endl;
|
|
||||||
}
|
|
||||||
op->apply();
|
|
||||||
delete op;
|
|
||||||
}
|
|
||||||
assert (pfList_.size() > 0);
|
|
||||||
if (pfList_.size() > 1) {
|
|
||||||
ParfactorList::iterator pfIter = pfList_.begin();
|
|
||||||
pfIter ++;
|
|
||||||
while (pfIter != pfList_.end()) {
|
|
||||||
(*pfList_.begin())->multiply (**pfIter);
|
|
||||||
++ pfIter;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(*pfList_.begin())->reorderAccordingGrounds (query);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LiftedOperator*
|
|
||||||
FoveSolver::getBestOperation (const Grounds& query)
|
|
||||||
{
|
|
||||||
unsigned bestCost;
|
|
||||||
LiftedOperator* bestOp = 0;
|
|
||||||
vector<LiftedOperator*> validOps;
|
|
||||||
validOps = LiftedOperator::getValidOps (pfList_, query);
|
|
||||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
|
||||||
unsigned cost = validOps[i]->getCost();
|
|
||||||
if ((bestOp == 0) || (cost < bestCost)) {
|
|
||||||
bestOp = validOps[i];
|
|
||||||
bestCost = cost;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
|
||||||
if (validOps[i] != bestOp) {
|
|
||||||
delete validOps[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return bestOp;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FoveSolver::runWeakBayesBall (const Grounds& query)
|
|
||||||
{
|
|
||||||
queue<unsigned> todo; // groups to process
|
|
||||||
set<unsigned> done; // processed or in queue
|
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
|
||||||
while (it != pfList_.end()) {
|
|
||||||
int group = (*it)->findGroup (query[i]);
|
|
||||||
if (group != -1) {
|
|
||||||
todo.push (group);
|
|
||||||
done.insert (group);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
set<Parfactor*> requiredPfs;
|
|
||||||
while (todo.empty() == false) {
|
|
||||||
unsigned group = todo.front();
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
|
||||||
while (it != pfList_.end()) {
|
|
||||||
if (Util::contains (requiredPfs, *it) == false &&
|
|
||||||
(*it)->containsGroup (group)) {
|
|
||||||
vector<unsigned> groups = (*it)->getAllGroups();
|
|
||||||
for (unsigned i = 0; i < groups.size(); i++) {
|
|
||||||
if (Util::contains (done, groups[i]) == false) {
|
|
||||||
todo.push (groups[i]);
|
|
||||||
done.insert (groups[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
requiredPfs.insert (*it);
|
|
||||||
}
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
todo.pop();
|
|
||||||
}
|
|
||||||
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
|
||||||
while (it != pfList_.end()) {
|
|
||||||
if (Util::contains (requiredPfs, *it) == false) {
|
|
||||||
it = pfList_.removeAndDelete (it);
|
|
||||||
} else {
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
Util::printHeader ("REQUIRED PARFACTORS");
|
|
||||||
pfList_.print();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
|
||||||
if (query[i].isAtom()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
bool found = false;
|
|
||||||
Parfactors newPfs;
|
|
||||||
ParfactorList::iterator it = pfList_.begin();
|
|
||||||
while (it != pfList_.end()) {
|
|
||||||
if ((*it)->containsGround (query[i])) {
|
|
||||||
found = true;
|
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split =
|
|
||||||
(*it)->constr()->split (query[i].args(), query[i].arity());
|
|
||||||
ConstraintTree* commCt = split.first;
|
|
||||||
ConstraintTree* exclCt = split.second;
|
|
||||||
newPfs.push_back (new Parfactor (*it, commCt));
|
|
||||||
if (exclCt->empty() == false) {
|
|
||||||
newPfs.push_back (new Parfactor (*it, exclCt));
|
|
||||||
} else {
|
|
||||||
delete exclCt;
|
|
||||||
}
|
|
||||||
it = pfList_.removeAndDelete (it);
|
|
||||||
} else {
|
|
||||||
++ it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (found == false) {
|
|
||||||
cerr << "error: could not find a parfactor with ground " ;
|
|
||||||
cerr << "`" << query[i] << "'" << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
pfList_.add (newPfs);
|
|
||||||
}
|
|
||||||
if (Constants::DEBUG >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
Util::printAsteriskLine();
|
|
||||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
|
||||||
cout << " -> " << query[i] << endl;
|
|
||||||
}
|
|
||||||
Util::printAsteriskLine();
|
|
||||||
pfList_.print();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactors
|
|
||||||
FoveSolver::absorve (
|
|
||||||
ObservedFormula& obsFormula,
|
|
||||||
Parfactor* g)
|
|
||||||
{
|
|
||||||
Parfactors absorvedPfs;
|
|
||||||
const ProbFormulas& formulas = g->arguments();
|
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
|
||||||
if (obsFormula.functor() == formulas[i].functor() &&
|
|
||||||
obsFormula.arity() == formulas[i].arity()) {
|
|
||||||
|
|
||||||
if (obsFormula.isAtom()) {
|
|
||||||
if (formulas.size() > 1) {
|
|
||||||
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
|
||||||
} else {
|
|
||||||
// hack to erase parfactor g
|
|
||||||
absorvedPfs.push_back (0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
g->constr()->moveToTop (formulas[i].logVars());
|
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> res
|
|
||||||
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
|
|
||||||
ConstraintTree* commCt = res.first;
|
|
||||||
ConstraintTree* exclCt = res.second;
|
|
||||||
|
|
||||||
if (commCt->empty() == false) {
|
|
||||||
if (formulas.size() > 1) {
|
|
||||||
LogVarSet excl = g->exclusiveLogVars (i);
|
|
||||||
Parfactors countNormPfs = countNormalize (g, excl);
|
|
||||||
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
|
||||||
countNormPfs[j]->absorveEvidence (
|
|
||||||
formulas[i], obsFormula.evidence());
|
|
||||||
absorvedPfs.push_back (countNormPfs[j]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
delete commCt;
|
|
||||||
}
|
|
||||||
if (exclCt->empty() == false) {
|
|
||||||
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
|
||||||
} else {
|
|
||||||
delete exclCt;
|
|
||||||
}
|
|
||||||
if (absorvedPfs.empty()) {
|
|
||||||
// hack to erase parfactor g
|
|
||||||
absorvedPfs.push_back (0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
delete commCt;
|
|
||||||
delete exclCt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return absorvedPfs;
|
|
||||||
}
|
|
||||||
|
|
@ -1,76 +0,0 @@
|
|||||||
#ifndef HORUS_HORUS_H
|
|
||||||
#define HORUS_HORUS_H
|
|
||||||
|
|
||||||
#include <limits>
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
|
||||||
TypeName(const TypeName&); \
|
|
||||||
void operator=(const TypeName&)
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
class Var;
|
|
||||||
class Factor;
|
|
||||||
class VarNode;
|
|
||||||
class FacNode;
|
|
||||||
|
|
||||||
typedef vector<double> Params;
|
|
||||||
typedef unsigned VarId;
|
|
||||||
typedef vector<VarId> VarIds;
|
|
||||||
typedef vector<Var*> Vars;
|
|
||||||
typedef vector<VarNode*> VarNodes;
|
|
||||||
typedef vector<FacNode*> FacNodes;
|
|
||||||
typedef vector<Factor*> Factors;
|
|
||||||
typedef vector<string> States;
|
|
||||||
typedef vector<unsigned> Ranges;
|
|
||||||
|
|
||||||
|
|
||||||
enum InfAlgorithms
|
|
||||||
{
|
|
||||||
VE, // variable elimination
|
|
||||||
BP, // belief propagation
|
|
||||||
CBP // counting belief propagation
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
|
||||||
|
|
||||||
extern bool logDomain;
|
|
||||||
|
|
||||||
extern InfAlgorithms infAlgorithm;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
namespace Constants {
|
|
||||||
|
|
||||||
// level of debug information
|
|
||||||
const unsigned DEBUG = 0;
|
|
||||||
|
|
||||||
const int NO_EVIDENCE = -1;
|
|
||||||
|
|
||||||
// number of digits to show when printing a parameter
|
|
||||||
const unsigned PRECISION = 5;
|
|
||||||
|
|
||||||
const bool COLLECT_STATS = false;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
namespace BpOptions
|
|
||||||
{
|
|
||||||
enum Schedule {
|
|
||||||
SEQ_FIXED,
|
|
||||||
SEQ_RANDOM,
|
|
||||||
PARALLEL,
|
|
||||||
MAX_RESIDUAL
|
|
||||||
};
|
|
||||||
extern Schedule schedule;
|
|
||||||
extern double accuracy;
|
|
||||||
extern unsigned maxIter;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // HORUS_HORUS_H
|
|
||||||
|
|
@ -1,171 +0,0 @@
|
|||||||
#include <cstdlib>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "FactorGraph.h"
|
|
||||||
#include "VarElimSolver.h"
|
|
||||||
#include "BpSolver.h"
|
|
||||||
#include "CbpSolver.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
void processArguments (FactorGraph&, int, const char* []);
|
|
||||||
void runSolver (const FactorGraph&, const VarIds&);
|
|
||||||
|
|
||||||
const string USAGE = "usage: \
|
|
||||||
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
main (int argc, const char* argv[])
|
|
||||||
{
|
|
||||||
if (argc <= 1) {
|
|
||||||
cerr << "error: no solver specified" << endl;
|
|
||||||
cerr << "error: no graphical model specified" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
if (argc <= 2) {
|
|
||||||
cerr << "error: no graphical model specified" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
string solver (argv[1]);
|
|
||||||
if (solver == "ve") {
|
|
||||||
Globals::infAlgorithm = InfAlgorithms::VE;
|
|
||||||
} else if (solver == "bp") {
|
|
||||||
Globals::infAlgorithm = InfAlgorithms::BP;
|
|
||||||
} else if (solver == "cbp") {
|
|
||||||
Globals::infAlgorithm = InfAlgorithms::CBP;
|
|
||||||
} else {
|
|
||||||
cerr << "error: unknow solver `" << solver << "'" << endl ;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
string fileName (argv[2]);
|
|
||||||
string extension = fileName.substr (
|
|
||||||
fileName.find_last_of ('.') + 1);
|
|
||||||
FactorGraph fg;
|
|
||||||
if (extension == "uai") {
|
|
||||||
fg.readFromUaiFormat (fileName.c_str());
|
|
||||||
} else if (extension == "fg") {
|
|
||||||
fg.readFromLibDaiFormat (fileName.c_str());
|
|
||||||
} else {
|
|
||||||
cerr << "error: the graphical model must be defined either " ;
|
|
||||||
cerr << "in a UAI or libDAI file" << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
processArguments (fg, argc, argv);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|
||||||
{
|
|
||||||
VarIds queryIds;
|
|
||||||
for (int i = 3; i < argc; i++) {
|
|
||||||
const string& arg = argv[i];
|
|
||||||
if (arg.find ('=') == std::string::npos) {
|
|
||||||
if (!Util::isInteger (arg)) {
|
|
||||||
cerr << "error: `" << arg << "' " ;
|
|
||||||
cerr << "is not a valid variable id" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
VarId vid;
|
|
||||||
stringstream ss;
|
|
||||||
ss << arg;
|
|
||||||
ss >> vid;
|
|
||||||
VarNode* queryVar = fg.getVarNode (vid);
|
|
||||||
if (queryVar) {
|
|
||||||
queryIds.push_back (vid);
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable with " ;
|
|
||||||
cerr << "`" << vid << "' as id" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
size_t pos = arg.find ('=');
|
|
||||||
if (arg.substr (0, pos).empty()) {
|
|
||||||
cerr << "error: missing left argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
if (arg.substr (pos + 1).empty()) {
|
|
||||||
cerr << "error: missing right argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
if (!Util::isInteger (arg.substr (0, pos))) {
|
|
||||||
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
|
||||||
cerr << "is not a variable id" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
VarId vid;
|
|
||||||
stringstream ss;
|
|
||||||
ss << arg.substr (0, pos);
|
|
||||||
ss >> vid;
|
|
||||||
VarNode* var = fg.getVarNode (vid);
|
|
||||||
if (var) {
|
|
||||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
|
||||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
|
||||||
cerr << "is not a state index" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
int stateIndex;
|
|
||||||
stringstream ss;
|
|
||||||
ss << arg.substr (pos + 1);
|
|
||||||
ss >> stateIndex;
|
|
||||||
if (var->isValidState (stateIndex)) {
|
|
||||||
var->setEvidence (stateIndex);
|
|
||||||
} else {
|
|
||||||
cerr << "error: `" << stateIndex << "' " ;
|
|
||||||
cerr << "is not a valid state index for variable " ;
|
|
||||||
cerr << "`" << var->varId() << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable with " ;
|
|
||||||
cerr << "`" << vid << "' as id" ;
|
|
||||||
cerr << endl;
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
runSolver (fg, queryIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
|
||||||
{
|
|
||||||
Solver* solver = 0;
|
|
||||||
switch (Globals::infAlgorithm) {
|
|
||||||
case InfAlgorithms::VE:
|
|
||||||
solver = new VarElimSolver (fg);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::BP:
|
|
||||||
solver = new BpSolver (fg);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::CBP:
|
|
||||||
solver = new CbpSolver (fg);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
assert (false);
|
|
||||||
}
|
|
||||||
if (queryIds.size() == 0) {
|
|
||||||
solver->printAllPosterioris();
|
|
||||||
} else {
|
|
||||||
solver->printAnswer (queryIds);
|
|
||||||
}
|
|
||||||
delete solver;
|
|
||||||
}
|
|
||||||
|
|
@ -1,296 +0,0 @@
|
|||||||
#ifndef HORUS_STATESINDEXER_H
|
|
||||||
#define HORUS_STATESINDEXER_H
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <numeric>
|
|
||||||
#include <functional>
|
|
||||||
|
|
||||||
#include <sstream>
|
|
||||||
#include <iomanip>
|
|
||||||
|
|
||||||
#include "Var.h"
|
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class StatesIndexer
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
|
|
||||||
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
|
|
||||||
{
|
|
||||||
size_ = 1;
|
|
||||||
indices_.resize (ranges.size(), 0);
|
|
||||||
ranges_ = ranges;
|
|
||||||
for (unsigned i = 0; i < ranges.size(); i++) {
|
|
||||||
size_ *= ranges[i];
|
|
||||||
}
|
|
||||||
li_ = 0;
|
|
||||||
if (calcOffsets) {
|
|
||||||
calculateOffsets();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
StatesIndexer (const Vars& vars, bool calcOffsets = true)
|
|
||||||
{
|
|
||||||
size_ = 1;
|
|
||||||
indices_.resize (vars.size(), 0);
|
|
||||||
ranges_.reserve (vars.size());
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
ranges_.push_back (vars[i]->range());
|
|
||||||
size_ *= vars[i]->range();
|
|
||||||
}
|
|
||||||
li_ = 0;
|
|
||||||
if (calcOffsets) {
|
|
||||||
calculateOffsets();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void increment (void)
|
|
||||||
{
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
indices_[i] ++;
|
|
||||||
if (indices_[i] != ranges_[i]) {
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
indices_[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
li_ ++;
|
|
||||||
}
|
|
||||||
|
|
||||||
void increment (unsigned dim)
|
|
||||||
{
|
|
||||||
assert (dim < ranges_.size());
|
|
||||||
assert (ranges_.size() == offsets_.size());
|
|
||||||
assert (indices_[dim] < ranges_[dim]);
|
|
||||||
indices_[dim] ++;
|
|
||||||
li_ += offsets_[dim];
|
|
||||||
}
|
|
||||||
|
|
||||||
void incrementExcluding (unsigned skipDim)
|
|
||||||
{
|
|
||||||
assert (ranges_.size() == offsets_.size());
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
if (i != (int)skipDim) {
|
|
||||||
indices_[i] ++;
|
|
||||||
li_ += offsets_[i];
|
|
||||||
if (indices_[i] != ranges_[i]) {
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
indices_[i] = 0;
|
|
||||||
li_ -= offsets_[i] * ranges_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
li_ = size_;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned linearIndex (void) const
|
|
||||||
{
|
|
||||||
return li_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const vector<unsigned>& indices (void) const
|
|
||||||
{
|
|
||||||
return indices_;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatesIndexer& operator ++ (void)
|
|
||||||
{
|
|
||||||
increment();
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
operator unsigned (void) const
|
|
||||||
{
|
|
||||||
return li_;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned operator[] (unsigned dim) const
|
|
||||||
{
|
|
||||||
assert (valid());
|
|
||||||
assert (dim < ranges_.size());
|
|
||||||
return indices_[dim];
|
|
||||||
}
|
|
||||||
|
|
||||||
bool valid (void) const
|
|
||||||
{
|
|
||||||
return li_ < size_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void reset (void)
|
|
||||||
{
|
|
||||||
std::fill (indices_.begin(), indices_.end(), 0);
|
|
||||||
li_ = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void reset (unsigned dim)
|
|
||||||
{
|
|
||||||
indices_[dim] = 0;
|
|
||||||
li_ -= offsets_[dim] * ranges_[dim];
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned size (void) const
|
|
||||||
{
|
|
||||||
return size_ ;
|
|
||||||
}
|
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
|
|
||||||
{
|
|
||||||
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
|
||||||
os << idx.indices_;
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
void calculateOffsets (void)
|
|
||||||
{
|
|
||||||
unsigned prod = 1;
|
|
||||||
offsets_.resize (ranges_.size());
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
offsets_[i] = prod;
|
|
||||||
prod *= ranges_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned li_;
|
|
||||||
unsigned size_;
|
|
||||||
vector<unsigned> indices_;
|
|
||||||
vector<unsigned> ranges_;
|
|
||||||
vector<unsigned> offsets_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MapIndexer
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
MapIndexer (const Ranges& ranges, const vector<bool>& mapDims)
|
|
||||||
{
|
|
||||||
assert (ranges.size() == mapDims.size());
|
|
||||||
unsigned prod = 1;
|
|
||||||
offsets_.resize (ranges.size());
|
|
||||||
for (int i = ranges.size() - 1; i >= 0; i--) {
|
|
||||||
if (mapDims[i]) {
|
|
||||||
offsets_[i] = prod;
|
|
||||||
prod *= ranges[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
indices_.resize (ranges.size(), 0);
|
|
||||||
ranges_ = ranges;
|
|
||||||
index_ = 0;
|
|
||||||
valid_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
MapIndexer (const Ranges& ranges, unsigned ignoreDim)
|
|
||||||
{
|
|
||||||
unsigned prod = 1;
|
|
||||||
offsets_.resize (ranges.size());
|
|
||||||
for (int i = ranges.size() - 1; i >= 0; i--) {
|
|
||||||
if (i != (int)ignoreDim) {
|
|
||||||
offsets_[i] = prod;
|
|
||||||
prod *= ranges[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
indices_.resize (ranges.size(), 0);
|
|
||||||
ranges_ = ranges;
|
|
||||||
index_ = 0;
|
|
||||||
valid_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
MapIndexer (
|
|
||||||
const VarIds& loopVids,
|
|
||||||
const Ranges& loopRanges,
|
|
||||||
const VarIds& mapVids,
|
|
||||||
const Ranges& mapRanges)
|
|
||||||
{
|
|
||||||
unsigned prod = 1;
|
|
||||||
vector<unsigned> offsets (mapRanges.size());
|
|
||||||
for (int i = mapRanges.size() - 1; i >= 0; i--) {
|
|
||||||
offsets[i] = prod;
|
|
||||||
prod *= mapRanges[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
offsets_.reserve (loopVids.size());
|
|
||||||
for (unsigned i = 0; i < loopVids.size(); i++) {
|
|
||||||
VarIds::const_iterator it =
|
|
||||||
std::find (mapVids.begin(), mapVids.end(), loopVids[i]);
|
|
||||||
if (it != mapVids.end()) {
|
|
||||||
offsets_.push_back (offsets[it - mapVids.begin()]);
|
|
||||||
} else {
|
|
||||||
offsets_.push_back (0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
indices_.resize (loopVids.size(), 0);
|
|
||||||
ranges_ = loopRanges;
|
|
||||||
index_ = 0;
|
|
||||||
size_ = prod;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
MapIndexer& operator ++ (void)
|
|
||||||
{
|
|
||||||
assert (valid_);
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
indices_[i] ++;
|
|
||||||
index_ += offsets_[i];
|
|
||||||
if (indices_[i] != ranges_[i]) {
|
|
||||||
return *this;
|
|
||||||
} else {
|
|
||||||
indices_[i] = 0;
|
|
||||||
index_ -= offsets_[i] * ranges_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
valid_ = false;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned mappedIndex (void) const
|
|
||||||
{
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
|
|
||||||
operator unsigned (void) const
|
|
||||||
{
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned operator[] (unsigned dim) const
|
|
||||||
{
|
|
||||||
assert (valid());
|
|
||||||
assert (dim < ranges_.size());
|
|
||||||
return indices_[dim];
|
|
||||||
}
|
|
||||||
|
|
||||||
bool valid (void) const
|
|
||||||
{
|
|
||||||
return valid_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void reset (void)
|
|
||||||
{
|
|
||||||
std::fill (indices_.begin(), indices_.end(), 0);
|
|
||||||
index_ = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
|
|
||||||
{
|
|
||||||
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
|
||||||
os << idx.indices_;
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
unsigned index_;
|
|
||||||
bool valid_;
|
|
||||||
vector<unsigned> ranges_;
|
|
||||||
vector<unsigned> indices_;
|
|
||||||
vector<unsigned> offsets_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
#endif // HORUS_STATESINDEXER_H
|
|
||||||
|
|
@ -1,685 +0,0 @@
|
|||||||
|
|
||||||
#include "Parfactor.h"
|
|
||||||
#include "Histogram.h"
|
|
||||||
#include "Indexer.h"
|
|
||||||
#include "Util.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (
|
|
||||||
const ProbFormulas& formulas,
|
|
||||||
const Params& params,
|
|
||||||
const Tuples& tuples,
|
|
||||||
unsigned distId)
|
|
||||||
{
|
|
||||||
args_ = formulas;
|
|
||||||
params_ = params;
|
|
||||||
distId_ = distId;
|
|
||||||
|
|
||||||
LogVars logVars;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
ranges_.push_back (args_[i].range());
|
|
||||||
const LogVars& lvs = args_[i].logVars();
|
|
||||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
|
||||||
if (Util::contains (logVars, lvs[j]) == false) {
|
|
||||||
logVars.push_back (lvs[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
constr_ = new ConstraintTree (logVars, tuples);
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
|
||||||
{
|
|
||||||
args_ = g->arguments();
|
|
||||||
params_ = g->params();
|
|
||||||
ranges_ = g->ranges();
|
|
||||||
distId_ = g->distId();
|
|
||||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
|
||||||
{
|
|
||||||
args_ = g->arguments();
|
|
||||||
params_ = g->params();
|
|
||||||
ranges_ = g->ranges();
|
|
||||||
distId_ = g->distId();
|
|
||||||
constr_ = constr;
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (const Parfactor& g)
|
|
||||||
{
|
|
||||||
args_ = g.arguments();
|
|
||||||
params_ = g.params();
|
|
||||||
ranges_ = g.ranges();
|
|
||||||
distId_ = g.distId();
|
|
||||||
constr_ = new ConstraintTree (*g.constr());
|
|
||||||
assert (params_.size() == Util::expectedSize (ranges_));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::~Parfactor (void)
|
|
||||||
{
|
|
||||||
delete constr_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LogVarSet
|
|
||||||
Parfactor::countedLogVars (void) const
|
|
||||||
{
|
|
||||||
LogVarSet set;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (args_[i].isCounting()) {
|
|
||||||
set.insert (args_[i].countedLogVar());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return set;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LogVarSet
|
|
||||||
Parfactor::uncountedLogVars (void) const
|
|
||||||
{
|
|
||||||
return constr_->logVarSet() - countedLogVars();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LogVarSet
|
|
||||||
Parfactor::elimLogVars (void) const
|
|
||||||
{
|
|
||||||
LogVarSet requiredToElim = constr_->logVarSet();
|
|
||||||
requiredToElim -= constr_->singletons();
|
|
||||||
requiredToElim -= countedLogVars();
|
|
||||||
return requiredToElim;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LogVarSet
|
|
||||||
Parfactor::exclusiveLogVars (unsigned fIdx) const
|
|
||||||
{
|
|
||||||
assert (fIdx < args_.size());
|
|
||||||
LogVarSet remaining;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (i != fIdx) {
|
|
||||||
remaining |= args_[i].logVarSet();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return args_[fIdx].logVarSet() - remaining;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::setConstraintTree (ConstraintTree* newTree)
|
|
||||||
{
|
|
||||||
delete constr_;
|
|
||||||
constr_ = newTree;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::sumOut (unsigned fIdx)
|
|
||||||
{
|
|
||||||
assert (fIdx < args_.size());
|
|
||||||
assert (args_[fIdx].contains (elimLogVars()));
|
|
||||||
|
|
||||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
|
||||||
if (args_[fIdx].isCounting()) {
|
|
||||||
LogAware::pow (params_, constr_->getConditionalCount (
|
|
||||||
excl - args_[fIdx].countedLogVar()));
|
|
||||||
} else {
|
|
||||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (args_[fIdx].isCounting()) {
|
|
||||||
unsigned N = constr_->getConditionalCount (
|
|
||||||
args_[fIdx].countedLogVar());
|
|
||||||
unsigned R = args_[fIdx].range();
|
|
||||||
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
|
||||||
StatesIndexer sindexer (ranges_, fIdx);
|
|
||||||
while (sindexer.valid()) {
|
|
||||||
unsigned h = sindexer[fIdx];
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
params_[sindexer] += numAssigns[h];
|
|
||||||
} else {
|
|
||||||
params_[sindexer] *= numAssigns[h];
|
|
||||||
}
|
|
||||||
++ sindexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
|
|
||||||
MapIndexer indexer (ranges_, fIdx);
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < copy.size(); i++) {
|
|
||||||
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < copy.size(); i++) {
|
|
||||||
params_[indexer] += copy[i];
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
args_.erase (args_.begin() + fIdx);
|
|
||||||
ranges_.erase (ranges_.begin() + fIdx);
|
|
||||||
constr_->remove (excl);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::multiply (Parfactor& g)
|
|
||||||
{
|
|
||||||
alignAndExponentiate (this, &g);
|
|
||||||
TFactor<ProbFormula>::multiply (g);
|
|
||||||
constr_->join (g.constr(), true);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::countConvert (LogVar X)
|
|
||||||
{
|
|
||||||
int fIdx = indexOfLogVar (X);
|
|
||||||
assert (fIdx != -1);
|
|
||||||
assert (constr_->isCountNormalized (X));
|
|
||||||
assert (constr_->getConditionalCount (X) > 1);
|
|
||||||
assert (constr_->isCarteesianProduct (countedLogVars() | X));
|
|
||||||
|
|
||||||
unsigned N = constr_->getConditionalCount (X);
|
|
||||||
unsigned R = ranges_[fIdx];
|
|
||||||
unsigned H = HistogramSet::nrHistograms (N, R);
|
|
||||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
|
||||||
|
|
||||||
StatesIndexer indexer (ranges_);
|
|
||||||
vector<Params> sumout (params_.size() / R);
|
|
||||||
unsigned count = 0;
|
|
||||||
while (indexer.valid()) {
|
|
||||||
sumout[count].reserve (R);
|
|
||||||
for (unsigned r = 0; r < R; r++) {
|
|
||||||
sumout[count].push_back (params_[indexer]);
|
|
||||||
indexer.increment (fIdx);
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
indexer.reset (fIdx);
|
|
||||||
indexer.incrementExcluding (fIdx);
|
|
||||||
}
|
|
||||||
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (sumout.size() * H);
|
|
||||||
|
|
||||||
ranges_[fIdx] = H;
|
|
||||||
MapIndexer mapIndexer (ranges_, fIdx);
|
|
||||||
while (mapIndexer.valid()) {
|
|
||||||
double prod = LogAware::multIdenty();
|
|
||||||
unsigned i = mapIndexer.mappedIndex();
|
|
||||||
unsigned h = mapIndexer[fIdx];
|
|
||||||
for (unsigned r = 0; r < R; r++) {
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
|
||||||
} else {
|
|
||||||
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
params_.push_back (prod);
|
|
||||||
++ mapIndexer;
|
|
||||||
}
|
|
||||||
args_[fIdx].setCountedLogVar (X);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
|
|
||||||
{
|
|
||||||
int fIdx = indexOfLogVar (X);
|
|
||||||
assert (fIdx != -1);
|
|
||||||
assert (args_[fIdx].isCounting());
|
|
||||||
|
|
||||||
unsigned N1 = constr_->getConditionalCount (X_new1);
|
|
||||||
unsigned N2 = constr_->getConditionalCount (X_new2);
|
|
||||||
unsigned N = N1 + N2;
|
|
||||||
unsigned R = args_[fIdx].range();
|
|
||||||
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
|
||||||
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
|
||||||
|
|
||||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
|
||||||
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
|
||||||
vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
|
|
||||||
|
|
||||||
vector<unsigned> sumIndexes;
|
|
||||||
sumIndexes.reserve (H1 * H2);
|
|
||||||
for (unsigned i = 0; i < H1; i++) {
|
|
||||||
for (unsigned j = 0; j < H2; j++) {
|
|
||||||
Histogram hist = histograms1[i];
|
|
||||||
std::transform (
|
|
||||||
hist.begin(), hist.end(),
|
|
||||||
histograms2[j].begin(),
|
|
||||||
hist.begin(),
|
|
||||||
plus<int>());
|
|
||||||
sumIndexes.push_back (HistogramSet::findIndex (hist, histograms));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
expandPotential (fIdx, H1 * H2, sumIndexes);
|
|
||||||
|
|
||||||
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
|
|
||||||
args_[fIdx].rename (X, X_new1);
|
|
||||||
args_[fIdx + 1].rename (X, X_new2);
|
|
||||||
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
|
|
||||||
ranges_[fIdx] = H1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::fullExpand (LogVar X)
|
|
||||||
{
|
|
||||||
int fIdx = indexOfLogVar (X);
|
|
||||||
assert (fIdx != -1);
|
|
||||||
assert (args_[fIdx].isCounting());
|
|
||||||
|
|
||||||
unsigned N = constr_->getConditionalCount (X);
|
|
||||||
unsigned R = args_[fIdx].range();
|
|
||||||
|
|
||||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
|
||||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
|
||||||
|
|
||||||
vector<unsigned> sumIndexes;
|
|
||||||
sumIndexes.reserve (N * R);
|
|
||||||
|
|
||||||
Ranges expandRanges (N, R);
|
|
||||||
StatesIndexer indexer (expandRanges);
|
|
||||||
while (indexer.valid()) {
|
|
||||||
vector<unsigned> hist (R, 0);
|
|
||||||
for (unsigned n = 0; n < N; n++) {
|
|
||||||
std::transform (
|
|
||||||
hist.begin(), hist.end(),
|
|
||||||
expandHists[indexer[n]].begin(),
|
|
||||||
hist.begin(),
|
|
||||||
plus<int>());
|
|
||||||
}
|
|
||||||
sumIndexes.push_back (HistogramSet::findIndex (hist, originHists));
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
|
|
||||||
expandPotential (fIdx, std::pow (R, N), sumIndexes);
|
|
||||||
|
|
||||||
ProbFormula f = args_[fIdx];
|
|
||||||
args_.erase (args_.begin() + fIdx);
|
|
||||||
ranges_.erase (ranges_.begin() + fIdx);
|
|
||||||
LogVars newLvs = constr_->expand (X);
|
|
||||||
assert (newLvs.size() == N);
|
|
||||||
for (unsigned i = 0 ; i < N; i++) {
|
|
||||||
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
|
|
||||||
newFormula.rename (X, newLvs[i]);
|
|
||||||
args_.insert (args_.begin() + fIdx + i, newFormula);
|
|
||||||
ranges_.insert (ranges_.begin() + fIdx + i, R);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::reorderAccordingGrounds (const Grounds& grounds)
|
|
||||||
{
|
|
||||||
ProbFormulas newFormulas;
|
|
||||||
for (unsigned i = 0; i < grounds.size(); i++) {
|
|
||||||
for (unsigned j = 0; j < args_.size(); j++) {
|
|
||||||
if (grounds[i].functor() == args_[j].functor() &&
|
|
||||||
grounds[i].arity() == args_[j].arity()) {
|
|
||||||
constr_->moveToTop (args_[j].logVars());
|
|
||||||
if (constr_->containsTuple (grounds[i].args())) {
|
|
||||||
newFormulas.push_back (args_[j]);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (newFormulas.size() == i + 1);
|
|
||||||
}
|
|
||||||
reorderArguments (newFormulas);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
|
||||||
{
|
|
||||||
int fIdx = indexOf (formula);
|
|
||||||
assert (fIdx != -1);
|
|
||||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
|
||||||
assert (args_[fIdx].isCounting() == false);
|
|
||||||
assert (constr_->isCountNormalized (excl));
|
|
||||||
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
|
||||||
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
|
||||||
constr_->remove (excl);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::setNewGroups (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
args_[i].setGroup (ProbFormula::getNewGroup());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::applySubstitution (const Substitution& theta)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
LogVars& lvs = args_[i].logVars();
|
|
||||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
|
||||||
lvs[j] = theta.newNameFor (lvs[j]);
|
|
||||||
}
|
|
||||||
if (args_[i].isCounting()) {
|
|
||||||
LogVar clv = args_[i].countedLogVar();
|
|
||||||
args_[i].setCountedLogVar (theta.newNameFor (clv));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
constr_->applySubstitution (theta);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
Parfactor::findGroup (const Ground& ground) const
|
|
||||||
{
|
|
||||||
int group = -1;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (args_[i].functor() == ground.functor() &&
|
|
||||||
args_[i].arity() == ground.arity()) {
|
|
||||||
constr_->moveToTop (args_[i].logVars());
|
|
||||||
if (constr_->containsTuple (ground.args())) {
|
|
||||||
group = args_[i].group();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return group;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
Parfactor::containsGround (const Ground& ground) const
|
|
||||||
{
|
|
||||||
return findGroup (ground) != -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
Parfactor::containsGroup (unsigned group) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (args_[i].group() == group) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
Parfactor::nrFormulas (LogVar X) const
|
|
||||||
{
|
|
||||||
unsigned count = 0;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (args_[i].contains (X)) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
Parfactor::indexOfLogVar (LogVar X) const
|
|
||||||
{
|
|
||||||
int idx = -1;
|
|
||||||
assert (nrFormulas (X) == 1);
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (args_[i].contains (X)) {
|
|
||||||
idx = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
Parfactor::indexOfGroup (unsigned group) const
|
|
||||||
{
|
|
||||||
int pos = -1;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (args_[i].group() == group) {
|
|
||||||
pos = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<unsigned>
|
|
||||||
Parfactor::getAllGroups (void) const
|
|
||||||
{
|
|
||||||
vector<unsigned> groups (args_.size());
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
groups[i] = args_[i].group();
|
|
||||||
}
|
|
||||||
return groups;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
Parfactor::getLabel (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << "phi(" ;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (i != 0) ss << "," ;
|
|
||||||
ss << args_[i];
|
|
||||||
}
|
|
||||||
ss << ")" ;
|
|
||||||
ConstraintTree copy (*constr_);
|
|
||||||
copy.moveToTop (copy.logVarSet().elements());
|
|
||||||
ss << "|" << copy.tupleSet();
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::print (bool printParams) const
|
|
||||||
{
|
|
||||||
cout << "Formulas: " ;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (i != 0) cout << ", " ;
|
|
||||||
cout << args_[i];
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
if (args_[0].group() != Util::maxUnsigned()) {
|
|
||||||
vector<string> groups;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
|
||||||
}
|
|
||||||
cout << "Groups: " << groups << endl;
|
|
||||||
}
|
|
||||||
cout << "LogVars: " << constr_->logVarSet() << endl;
|
|
||||||
cout << "Ranges: " << ranges_ << endl;
|
|
||||||
if (printParams == false) {
|
|
||||||
cout << "Params: " << params_ << endl;
|
|
||||||
}
|
|
||||||
ConstraintTree copy (*constr_);
|
|
||||||
copy.moveToTop (copy.logVarSet().elements());
|
|
||||||
cout << "Tuples: " << copy.tupleSet() << endl;
|
|
||||||
if (printParams) {
|
|
||||||
vector<string> jointStrings;
|
|
||||||
StatesIndexer indexer (ranges_);
|
|
||||||
while (indexer.valid()) {
|
|
||||||
stringstream ss;
|
|
||||||
for (unsigned i = 0; i < args_.size(); i++) {
|
|
||||||
if (i != 0) ss << ", " ;
|
|
||||||
if (args_[i].isCounting()) {
|
|
||||||
unsigned N = constr_->getConditionalCount (
|
|
||||||
args_[i].countedLogVar());
|
|
||||||
HistogramSet hs (N, args_[i].range());
|
|
||||||
unsigned c = 0;
|
|
||||||
while (c < indexer[i]) {
|
|
||||||
hs.nextHistogram();
|
|
||||||
c ++;
|
|
||||||
}
|
|
||||||
ss << hs;
|
|
||||||
} else {
|
|
||||||
ss << indexer[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
jointStrings.push_back (ss.str());
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
|
||||||
cout << "f(" << jointStrings[i] << ")" ;
|
|
||||||
cout << " = " << params_[i] << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::expandPotential (
|
|
||||||
int fIdx,
|
|
||||||
unsigned newRange,
|
|
||||||
const vector<unsigned>& sumIndexes)
|
|
||||||
{
|
|
||||||
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
|
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (size);
|
|
||||||
|
|
||||||
unsigned prod = 1;
|
|
||||||
vector<unsigned> offsets_ (ranges_.size());
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
offsets_[i] = prod;
|
|
||||||
prod *= ranges_[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned index = 0;
|
|
||||||
ranges_[fIdx] = newRange;
|
|
||||||
vector<unsigned> indices (ranges_.size(), 0);
|
|
||||||
for (unsigned k = 0; k < size; k++) {
|
|
||||||
params_.push_back (copy[index]);
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
indices[i] ++;
|
|
||||||
if (i == fIdx) {
|
|
||||||
assert (indices[i] - 1 < sumIndexes.size());
|
|
||||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index += offsets_[i];
|
|
||||||
}
|
|
||||||
if (indices[i] != ranges_[i]) {
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
if (i == fIdx) {
|
|
||||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index -= offsets_[i] * ranges_[i];
|
|
||||||
}
|
|
||||||
indices[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
|
||||||
{
|
|
||||||
LogVars X_1, X_2;
|
|
||||||
const ProbFormulas& formulas1 = g1->arguments();
|
|
||||||
const ProbFormulas& formulas2 = g2->arguments();
|
|
||||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
|
||||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
|
||||||
if (formulas1[i].group() == formulas2[j].group()) {
|
|
||||||
Util::addToVector (X_1, formulas1[i].logVars());
|
|
||||||
Util::addToVector (X_2, formulas2[j].logVars());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
|
||||||
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
|
||||||
assert (g1->constr()->isCountNormalized (Y_1));
|
|
||||||
assert (g2->constr()->isCountNormalized (Y_2));
|
|
||||||
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
|
||||||
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
|
||||||
LogAware::pow (g1->params(), 1.0 / condCount2);
|
|
||||||
LogAware::pow (g2->params(), 1.0 / condCount1);
|
|
||||||
// this must be done in the end or else X_1 and X_2
|
|
||||||
// will refer the old log var names in the code above
|
|
||||||
align (g1, X_1, g2, X_2);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::align (
|
|
||||||
Parfactor* g1, const LogVars& alignLvs1,
|
|
||||||
Parfactor* g2, const LogVars& alignLvs2)
|
|
||||||
{
|
|
||||||
LogVar freeLogVar = 0;
|
|
||||||
Substitution theta1;
|
|
||||||
Substitution theta2;
|
|
||||||
const LogVarSet& allLvs1 = g1->logVarSet();
|
|
||||||
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
|
||||||
theta1.add (allLvs1[i], freeLogVar);
|
|
||||||
++ freeLogVar;
|
|
||||||
}
|
|
||||||
|
|
||||||
const LogVarSet& allLvs2 = g2->logVarSet();
|
|
||||||
for (unsigned i = 0; i < allLvs2.size(); i++) {
|
|
||||||
theta2.add (allLvs2[i], freeLogVar);
|
|
||||||
++ freeLogVar;
|
|
||||||
}
|
|
||||||
|
|
||||||
assert (alignLvs1.size() == alignLvs2.size());
|
|
||||||
for (unsigned i = 0; i < alignLvs1.size(); i++) {
|
|
||||||
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
|
|
||||||
}
|
|
||||||
g1->applySubstitution (theta1);
|
|
||||||
g2->applySubstitution (theta2);
|
|
||||||
}
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
|||||||
#include "Solver.h"
|
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Solver::printAnswer (const VarIds& vids)
|
|
||||||
{
|
|
||||||
Vars unobservedVars;
|
|
||||||
VarIds unobservedVids;
|
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
|
||||||
VarNode* vn = fg.getVarNode (vids[i]);
|
|
||||||
if (vn->hasEvidence() == false) {
|
|
||||||
unobservedVars.push_back (vn);
|
|
||||||
unobservedVids.push_back (vids[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Params res = solveQuery (unobservedVids);
|
|
||||||
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
|
||||||
for (unsigned i = 0; i < res.size(); i++) {
|
|
||||||
cout << "P(" << stateLines[i] << ") = " ;
|
|
||||||
cout << std::setprecision (Constants::PRECISION) << res[i];
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Solver::printAllPosterioris (void)
|
|
||||||
{
|
|
||||||
const VarNodes& vars = fg.varNodes();
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
printAnswer ({vars[i]->varId()});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,4 +0,0 @@
|
|||||||
TODO
|
|
||||||
- add a way to calculate combinations and factorials with large numbers
|
|
||||||
- refactor sumOut in parfactor -> is really ugly code
|
|
||||||
- Indexer: start receiving ranges as constant reference
|
|
@ -1,200 +0,0 @@
|
|||||||
#ifndef HORUS_TINYSET_H
|
|
||||||
#define HORUS_TINYSET_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class TinySet
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
TinySet (void) {}
|
|
||||||
|
|
||||||
TinySet (const T& t)
|
|
||||||
{
|
|
||||||
elements_.push_back (t);
|
|
||||||
}
|
|
||||||
|
|
||||||
TinySet (const vector<T>& elements)
|
|
||||||
{
|
|
||||||
elements_.reserve (elements.size());
|
|
||||||
for (unsigned i = 0; i < elements.size(); i++) {
|
|
||||||
insert (elements[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TinySet (const TinySet<T>& s) : elements_(s.elements_) { }
|
|
||||||
|
|
||||||
void insert (const T& t)
|
|
||||||
{
|
|
||||||
typename vector<T>::iterator it =
|
|
||||||
std::lower_bound (elements_.begin(), elements_.end(), t);
|
|
||||||
if (it == elements_.end() || *it != t) {
|
|
||||||
elements_.insert (it, t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void remove (const T& t)
|
|
||||||
{
|
|
||||||
typename vector<T>::iterator it =
|
|
||||||
std::lower_bound (elements_.begin(), elements_.end(), t);
|
|
||||||
if (it != elements_.end()) {
|
|
||||||
elements_.erase (it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* set union */
|
|
||||||
TinySet operator| (const TinySet& s) const
|
|
||||||
{
|
|
||||||
TinySet res;
|
|
||||||
std::set_union (
|
|
||||||
elements_.begin(),
|
|
||||||
elements_.end(),
|
|
||||||
s.elements_.begin(),
|
|
||||||
s.elements_.end(),
|
|
||||||
std::back_inserter (res.elements_));
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* set intersection */
|
|
||||||
TinySet operator& (const TinySet& s) const
|
|
||||||
{
|
|
||||||
TinySet res;
|
|
||||||
std::set_intersection (
|
|
||||||
elements_.begin(),
|
|
||||||
elements_.end(),
|
|
||||||
s.elements_.begin(),
|
|
||||||
s.elements_.end(),
|
|
||||||
std::back_inserter (res.elements_));
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* set difference */
|
|
||||||
TinySet operator- (const TinySet& s) const
|
|
||||||
{
|
|
||||||
TinySet res;
|
|
||||||
std::set_difference (
|
|
||||||
elements_.begin(),
|
|
||||||
elements_.end(),
|
|
||||||
s.elements_.begin(),
|
|
||||||
s.elements_.end(),
|
|
||||||
std::back_inserter (res.elements_));
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
TinySet& operator|= (const TinySet& s)
|
|
||||||
{
|
|
||||||
return *this = (*this | s);
|
|
||||||
}
|
|
||||||
|
|
||||||
TinySet& operator&= (const TinySet& s)
|
|
||||||
{
|
|
||||||
return *this = (*this & s);
|
|
||||||
}
|
|
||||||
|
|
||||||
TinySet& operator-= (const TinySet& s)
|
|
||||||
{
|
|
||||||
return *this = (*this - s);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool contains (const T& t) const
|
|
||||||
{
|
|
||||||
return std::binary_search (
|
|
||||||
elements_.begin(), elements_.end(), t);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool contains (const TinySet& s) const
|
|
||||||
{
|
|
||||||
return std::includes (
|
|
||||||
elements_.begin(),
|
|
||||||
elements_.end(),
|
|
||||||
s.elements_.begin(),
|
|
||||||
s.elements_.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool in (const TinySet& s) const
|
|
||||||
{
|
|
||||||
return std::includes (
|
|
||||||
s.elements_.begin(),
|
|
||||||
s.elements_.end(),
|
|
||||||
elements_.begin(),
|
|
||||||
elements_.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool intersects (const TinySet& s) const
|
|
||||||
{
|
|
||||||
return (*this & s).size() > 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
T operator[] (unsigned i) const
|
|
||||||
{
|
|
||||||
return elements_[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
const vector<T>& elements (void) const
|
|
||||||
{
|
|
||||||
return elements_;
|
|
||||||
}
|
|
||||||
|
|
||||||
T front (void) const
|
|
||||||
{
|
|
||||||
return elements_.front();
|
|
||||||
}
|
|
||||||
|
|
||||||
T back (void) const
|
|
||||||
{
|
|
||||||
return elements_.back();
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned size (void) const
|
|
||||||
{
|
|
||||||
return elements_.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool empty (void) const
|
|
||||||
{
|
|
||||||
return elements_.size() == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef typename std::vector<T>::const_iterator const_iterator;
|
|
||||||
|
|
||||||
const_iterator begin (void) const
|
|
||||||
{
|
|
||||||
return elements_.begin();
|
|
||||||
}
|
|
||||||
|
|
||||||
const_iterator end (void) const
|
|
||||||
{
|
|
||||||
return elements_.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
friend bool operator== (const TinySet& s1, const TinySet& s2)
|
|
||||||
{
|
|
||||||
return s1.elements_ == s2.elements_;
|
|
||||||
}
|
|
||||||
|
|
||||||
friend bool operator!= (const TinySet& s1, const TinySet& s2)
|
|
||||||
{
|
|
||||||
return s1.elements_ != s2.elements_;
|
|
||||||
}
|
|
||||||
|
|
||||||
friend std::ostream& operator << (std::ostream& out, const TinySet<T>& s)
|
|
||||||
{
|
|
||||||
out << "{" ;
|
|
||||||
for (unsigned i = 0; i < s.size(); i++) {
|
|
||||||
out << ((i != 0) ? "," : "") << s.elements()[i];
|
|
||||||
}
|
|
||||||
out << "}" ;
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
vector<T> elements_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
#endif // HORUS_TINYSET_H
|
|
||||||
|
|
@ -1,502 +0,0 @@
|
|||||||
#include <limits>
|
|
||||||
|
|
||||||
#include <sstream>
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
#include "Util.h"
|
|
||||||
#include "Indexer.h"
|
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
|
||||||
bool logDomain = false;
|
|
||||||
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
|
||||||
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace BpOptions {
|
|
||||||
Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
|
||||||
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
|
||||||
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
|
||||||
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
|
||||||
double accuracy = 0.0001;
|
|
||||||
unsigned maxIter = 1000;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<NetInfo> Statistics::netInfo_;
|
|
||||||
vector<CompressInfo> Statistics::compressInfo_;
|
|
||||||
unsigned Statistics::primaryNetCount_;
|
|
||||||
|
|
||||||
|
|
||||||
namespace Util {
|
|
||||||
|
|
||||||
void
|
|
||||||
toLog (Params& v)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = log (v[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
fromLog (Params& v)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = exp (v[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
|
||||||
factorial (double num)
|
|
||||||
{
|
|
||||||
double result = 1.0;
|
|
||||||
for (int i = 1; i <= num; i++) {
|
|
||||||
result *= i;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
nrCombinations (unsigned n, unsigned r)
|
|
||||||
{
|
|
||||||
assert (n >= r);
|
|
||||||
unsigned prod = 1;
|
|
||||||
for (int i = (int)n; i > (int)(n - r); i--) {
|
|
||||||
prod *= i;
|
|
||||||
}
|
|
||||||
return (prod / factorial (r));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
expectedSize (const Ranges& ranges)
|
|
||||||
{
|
|
||||||
unsigned prod = 1;
|
|
||||||
for (unsigned i = 0; i < ranges.size(); i++) {
|
|
||||||
prod *= ranges[i];
|
|
||||||
}
|
|
||||||
return prod;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
getNumberOfDigits (int number)
|
|
||||||
{
|
|
||||||
unsigned count = 1;
|
|
||||||
while (number >= 10) {
|
|
||||||
number /= 10;
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
return count;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
isInteger (const string& s)
|
|
||||||
{
|
|
||||||
stringstream ss1 (s);
|
|
||||||
stringstream ss2;
|
|
||||||
int integer;
|
|
||||||
ss1 >> integer;
|
|
||||||
ss2 << integer;
|
|
||||||
return (ss1.str() == ss2.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
parametersToString (const Params& v, unsigned precision)
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss.precision (precision);
|
|
||||||
ss << "[" ;
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
if (i != 0) ss << ", " ;
|
|
||||||
ss << v[i];
|
|
||||||
}
|
|
||||||
ss << "]" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<string>
|
|
||||||
getStateLines (const Vars& vars)
|
|
||||||
{
|
|
||||||
StatesIndexer idx (vars);
|
|
||||||
vector<string> jointStrings;
|
|
||||||
while (idx.valid()) {
|
|
||||||
stringstream ss;
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
if (i != 0) ss << ", " ;
|
|
||||||
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
|
|
||||||
}
|
|
||||||
jointStrings.push_back (ss.str());
|
|
||||||
++ idx;
|
|
||||||
}
|
|
||||||
return jointStrings;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
printHeader (string header, std::ostream& os)
|
|
||||||
{
|
|
||||||
printAsteriskLine (os);
|
|
||||||
os << header << endl;
|
|
||||||
printAsteriskLine (os);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
printSubHeader (string header, std::ostream& os)
|
|
||||||
{
|
|
||||||
printDashedLine (os);
|
|
||||||
os << header << endl;
|
|
||||||
printDashedLine (os);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
printAsteriskLine (std::ostream& os)
|
|
||||||
{
|
|
||||||
os << "********************************" ;
|
|
||||||
os << "********************************" ;
|
|
||||||
os << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
printDashedLine (std::ostream& os)
|
|
||||||
{
|
|
||||||
os << "--------------------------------" ;
|
|
||||||
os << "--------------------------------" ;
|
|
||||||
os << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace LogAware {
|
|
||||||
|
|
||||||
void
|
|
||||||
normalize (Params& v)
|
|
||||||
{
|
|
||||||
double sum;
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
sum = LogAware::addIdenty();
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
sum = Util::logSum (sum, v[i]);
|
|
||||||
}
|
|
||||||
assert (sum != -numeric_limits<double>::infinity());
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] -= sum;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sum = 0.0;
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
sum += v[i];
|
|
||||||
}
|
|
||||||
assert (sum != 0.0);
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] /= sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
|
||||||
getL1Distance (const Params& v1, const Params& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
double dist = 0.0;
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
dist += abs (exp(v1[i]) - exp(v2[i]));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
dist += abs (v1[i] - v2[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
|
||||||
getMaxNorm (const Params& v1, const Params& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
double max = 0.0;
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
double diff = abs (exp(v1[i]) - exp(v2[i]));
|
|
||||||
if (diff > max) {
|
|
||||||
max = diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
double diff = abs (v1[i] - v2[i]);
|
|
||||||
if (diff > max) {
|
|
||||||
max = diff;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return max;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
double
|
|
||||||
pow (double p, unsigned expoent)
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
|
||||||
pow (double p, double expoent)
|
|
||||||
{
|
|
||||||
// assumes that `expoent' is never in log domain
|
|
||||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
pow (Params& v, unsigned expoent)
|
|
||||||
{
|
|
||||||
if (expoent == 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] *= expoent;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = std::pow (v[i], expoent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
pow (Params& v, double expoent)
|
|
||||||
{
|
|
||||||
// assumes that `expoent' is never in log domain
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] *= expoent;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = std::pow (v[i], expoent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
Statistics::getSolvedNetworksCounting (void)
|
|
||||||
{
|
|
||||||
return netInfo_.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Statistics::incrementPrimaryNetworksCounting (void)
|
|
||||||
{
|
|
||||||
primaryNetCount_ ++;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
Statistics::getPrimaryNetworksCounting (void)
|
|
||||||
{
|
|
||||||
return primaryNetCount_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Statistics::updateStatistics (
|
|
||||||
unsigned size,
|
|
||||||
bool loopy,
|
|
||||||
unsigned nIters,
|
|
||||||
double time)
|
|
||||||
{
|
|
||||||
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Statistics::printStatistics (void)
|
|
||||||
{
|
|
||||||
cout << getStatisticString();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Statistics::writeStatistics (const char* fileName)
|
|
||||||
{
|
|
||||||
ofstream out (fileName);
|
|
||||||
if (!out.is_open()) {
|
|
||||||
cerr << "error: cannot open file to write at " ;
|
|
||||||
cerr << "Statistics::writeStats()" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
out << getStatisticString();
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Statistics::updateCompressingStatistics (
|
|
||||||
unsigned nrGroundVars,
|
|
||||||
unsigned nrGroundFactors,
|
|
||||||
unsigned nrClusterVars,
|
|
||||||
unsigned nrClusterFactors,
|
|
||||||
unsigned nrNeighborless) {
|
|
||||||
compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
|
|
||||||
nrClusterVars, nrClusterFactors, nrNeighborless));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
Statistics::getStatisticString (void)
|
|
||||||
{
|
|
||||||
stringstream ss2, ss3, ss4, ss1;
|
|
||||||
ss1 << "running mode: " ;
|
|
||||||
switch (Globals::infAlgorithm) {
|
|
||||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
|
||||||
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
|
|
||||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
|
||||||
}
|
|
||||||
ss1 << "message schedule: " ;
|
|
||||||
switch (BpOptions::schedule) {
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
|
||||||
ss1 << "sequential fixed" << endl;
|
|
||||||
break;
|
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
|
||||||
ss1 << "sequential random" << endl;
|
|
||||||
break;
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
|
||||||
ss1 << "parallel" << endl;
|
|
||||||
break;
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
|
||||||
ss1 << "max residual" << endl;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
ss1 << "max iterations: " << BpOptions::maxIter << endl;
|
|
||||||
ss1 << "accuracy " << BpOptions::accuracy << endl;
|
|
||||||
ss1 << endl << endl;
|
|
||||||
Util::printSubHeader ("Network information", ss2);
|
|
||||||
ss2 << left;
|
|
||||||
ss2 << setw (15) << "Network Size" ;
|
|
||||||
ss2 << setw (9) << "Loopy" ;
|
|
||||||
ss2 << setw (15) << "Iterations" ;
|
|
||||||
ss2 << setw (15) << "Solving Time" ;
|
|
||||||
ss2 << endl;
|
|
||||||
unsigned nLoopyNets = 0;
|
|
||||||
unsigned nUnconvergedRuns = 0;
|
|
||||||
double totalSolvingTime = 0.0;
|
|
||||||
for (unsigned i = 0; i < netInfo_.size(); i++) {
|
|
||||||
ss2 << setw (15) << netInfo_[i].size;
|
|
||||||
if (netInfo_[i].loopy) {
|
|
||||||
ss2 << setw (9) << "yes";
|
|
||||||
nLoopyNets ++;
|
|
||||||
} else {
|
|
||||||
ss2 << setw (9) << "no";
|
|
||||||
}
|
|
||||||
if (netInfo_[i].nIters == 0) {
|
|
||||||
ss2 << setw (15) << "n/a" ;
|
|
||||||
} else {
|
|
||||||
ss2 << setw (15) << netInfo_[i].nIters;
|
|
||||||
if (netInfo_[i].nIters > BpOptions::maxIter) {
|
|
||||||
nUnconvergedRuns ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ss2 << setw (15) << netInfo_[i].time;
|
|
||||||
totalSolvingTime += netInfo_[i].time;
|
|
||||||
ss2 << endl;
|
|
||||||
}
|
|
||||||
ss2 << endl << endl;
|
|
||||||
|
|
||||||
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
|
|
||||||
if (compressInfo_.size() > 0) {
|
|
||||||
Util::printSubHeader ("Compress information", ss3);
|
|
||||||
ss3 << left;
|
|
||||||
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
|
||||||
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
|
||||||
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
|
||||||
ss3 << setw (9) << compressInfo_[i].nrGroundVars;
|
|
||||||
ss3 << setw (10) << compressInfo_[i].nrClusterVars;
|
|
||||||
ss3 << setw (10) << compressInfo_[i].nrGroundFactors;
|
|
||||||
ss3 << setw (10) << compressInfo_[i].nrClusterFactors;
|
|
||||||
ss3 << setw (10) << compressInfo_[i].nrNeighborless;
|
|
||||||
ss3 << endl;
|
|
||||||
c1 += compressInfo_[i].nrGroundVars - compressInfo_[i].nrNeighborless;
|
|
||||||
c2 += compressInfo_[i].nrClusterVars;
|
|
||||||
c3 += compressInfo_[i].nrGroundFactors - compressInfo_[i].nrNeighborless;
|
|
||||||
c4 += compressInfo_[i].nrClusterFactors;
|
|
||||||
if (compressInfo_[i].nrNeighborless != 0) {
|
|
||||||
c2 --;
|
|
||||||
c4 --;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ss3 << endl << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
ss4 << "primary networks: " << primaryNetCount_ << endl;
|
|
||||||
ss4 << "solved networks: " << netInfo_.size() << endl;
|
|
||||||
ss4 << "loopy networks: " << nLoopyNets << endl;
|
|
||||||
ss4 << "unconverged runs: " << nUnconvergedRuns << endl;
|
|
||||||
ss4 << "total solving time: " << totalSolvingTime << endl;
|
|
||||||
if (compressInfo_.size() > 0) {
|
|
||||||
double pc1 = (1.0 - (c2 / (double)c1)) * 100.0;
|
|
||||||
double pc2 = (1.0 - (c4 / (double)c3)) * 100.0;
|
|
||||||
ss4 << setprecision (5);
|
|
||||||
ss4 << "variable compression: " << pc1 << "%" << endl;
|
|
||||||
ss4 << "factor compression: " << pc2 << "%" << endl;
|
|
||||||
}
|
|
||||||
ss4 << endl << endl;
|
|
||||||
|
|
||||||
ss1 << ss4.str() << ss2.str() << ss3.str();
|
|
||||||
return ss1.str();
|
|
||||||
}
|
|
||||||
|
|
@ -1,376 +0,0 @@
|
|||||||
#ifndef HORUS_UTIL_H
|
|
||||||
#define HORUS_UTIL_H
|
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <cassert>
|
|
||||||
#include <limits>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <vector>
|
|
||||||
#include <set>
|
|
||||||
#include <queue>
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include <sstream>
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
namespace Util {
|
|
||||||
|
|
||||||
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
|
||||||
|
|
||||||
template <typename T> void addToSet (set<T>&, const vector<T>&);
|
|
||||||
|
|
||||||
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
|
||||||
|
|
||||||
template <typename T> bool contains (const vector<T>&, const T&);
|
|
||||||
|
|
||||||
template <typename T> bool contains (const set<T>&, const T&);
|
|
||||||
|
|
||||||
template <typename K, typename V> bool contains (
|
|
||||||
const unordered_map<K, V>&, const K&);
|
|
||||||
|
|
||||||
template <typename T> std::string toString (const T&);
|
|
||||||
|
|
||||||
void toLog (Params&);
|
|
||||||
|
|
||||||
void fromLog (Params&);
|
|
||||||
|
|
||||||
double logSum (double, double);
|
|
||||||
|
|
||||||
void multiply (Params&, const Params&);
|
|
||||||
|
|
||||||
void multiply (Params&, const Params&, unsigned);
|
|
||||||
|
|
||||||
void add (Params&, const Params&);
|
|
||||||
|
|
||||||
void add (Params&, const Params&, unsigned);
|
|
||||||
|
|
||||||
double factorial (double);
|
|
||||||
|
|
||||||
unsigned nrCombinations (unsigned, unsigned);
|
|
||||||
|
|
||||||
unsigned expectedSize (const Ranges&);
|
|
||||||
|
|
||||||
unsigned getNumberOfDigits (int);
|
|
||||||
|
|
||||||
bool isInteger (const string&);
|
|
||||||
|
|
||||||
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
|
||||||
|
|
||||||
vector<string> getStateLines (const Vars&);
|
|
||||||
|
|
||||||
void printHeader (string, std::ostream& os = std::cout);
|
|
||||||
|
|
||||||
void printSubHeader (string, std::ostream& os = std::cout);
|
|
||||||
|
|
||||||
void printAsteriskLine (std::ostream& os = std::cout);
|
|
||||||
|
|
||||||
void printDashedLine (std::ostream& os = std::cout);
|
|
||||||
|
|
||||||
unsigned maxUnsigned (void);
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
|
||||||
Util::addToVector (vector<T>& v, const vector<T>& elements)
|
|
||||||
{
|
|
||||||
v.insert (v.end(), elements.begin(), elements.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
|
||||||
Util::addToSet (set<T>& s, const vector<T>& elements)
|
|
||||||
{
|
|
||||||
s.insert (elements.begin(), elements.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T> void
|
|
||||||
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < elements.size(); i++) {
|
|
||||||
q.push (elements[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T> bool
|
|
||||||
Util::contains (const vector<T>& v, const T& e)
|
|
||||||
{
|
|
||||||
return std::find (v.begin(), v.end(), e) != v.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T> bool
|
|
||||||
Util::contains (const set<T>& s, const T& e)
|
|
||||||
{
|
|
||||||
return s.find (e) != s.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename K, typename V> bool
|
|
||||||
Util::contains (const unordered_map<K, V>& m, const K& k)
|
|
||||||
{
|
|
||||||
return m.find (k) != m.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T> std::string
|
|
||||||
Util::toString (const T& t)
|
|
||||||
{
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << t;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::ostream& operator << (std::ostream& os, const vector<T>& v)
|
|
||||||
{
|
|
||||||
os << "[" ;
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
os << ((i != 0) ? ", " : "") << v[i];
|
|
||||||
}
|
|
||||||
os << "]" ;
|
|
||||||
return os;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
const double INF = -numeric_limits<double>::infinity();
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::logSum (double x, double y)
|
|
||||||
{
|
|
||||||
return log (exp (x) + exp (y));
|
|
||||||
assert (isfinite (x) && isfinite (y));
|
|
||||||
// If one value is much smaller than the other, keep the larger value.
|
|
||||||
if (x < (y - log (1e200))) {
|
|
||||||
return y;
|
|
||||||
}
|
|
||||||
if (y < (x - log (1e200))) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
double diff = x - y;
|
|
||||||
assert (isfinite (diff) && isfinite (x) && isfinite (y));
|
|
||||||
if (!isfinite (exp (diff))) {
|
|
||||||
// difference is too large
|
|
||||||
return x > y ? x : y;
|
|
||||||
}
|
|
||||||
// otherwise return the sum.
|
|
||||||
return y + log (static_cast<double>(1.0) + exp (diff));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::multiply (Params& v1, const Params& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
v1[i] *= v2[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
|
|
||||||
{
|
|
||||||
for (unsigned count = 0; count < v1.size(); ) {
|
|
||||||
for (unsigned i = 0; i < v2.size(); i++) {
|
|
||||||
for (unsigned r = 0; r < repetitions; r++) {
|
|
||||||
v1[count] *= v2[i];
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::add (Params& v1, const Params& v2)
|
|
||||||
{
|
|
||||||
assert (v1.size() == v2.size());
|
|
||||||
for (unsigned i = 0; i < v1.size(); i++) {
|
|
||||||
v1[i] += v2[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline void
|
|
||||||
Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
|
||||||
{
|
|
||||||
for (unsigned count = 0; count < v1.size(); ) {
|
|
||||||
for (unsigned i = 0; i < v2.size(); i++) {
|
|
||||||
for (unsigned r = 0; r < repetitions; r++) {
|
|
||||||
v1[count] += v2[i];
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline unsigned
|
|
||||||
Util::maxUnsigned (void)
|
|
||||||
{
|
|
||||||
return numeric_limits<unsigned>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace LogAware {
|
|
||||||
|
|
||||||
inline double
|
|
||||||
one()
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? 0.0 : 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
zero() {
|
|
||||||
return Globals::logDomain ? INF : 0.0 ;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
addIdenty()
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? INF : 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
multIdenty()
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? 0.0 : 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
withEvidence()
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? 0.0 : 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
noEvidence() {
|
|
||||||
return Globals::logDomain ? INF : 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
tl (double v)
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? log (v) : v;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
|
||||||
fl (double v)
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? exp (v) : v;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void normalize (Params&);
|
|
||||||
|
|
||||||
double getL1Distance (const Params&, const Params&);
|
|
||||||
|
|
||||||
double getMaxNorm (const Params&, const Params&);
|
|
||||||
|
|
||||||
double pow (double, unsigned);
|
|
||||||
|
|
||||||
double pow (double, double);
|
|
||||||
|
|
||||||
void pow (Params&, unsigned);
|
|
||||||
|
|
||||||
void pow (Params&, double);
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
struct NetInfo
|
|
||||||
{
|
|
||||||
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
|
|
||||||
{
|
|
||||||
this->size = size;
|
|
||||||
this->loopy = loopy;
|
|
||||||
this->nIters = nIters;
|
|
||||||
this->time = time;
|
|
||||||
}
|
|
||||||
unsigned size;
|
|
||||||
bool loopy;
|
|
||||||
unsigned nIters;
|
|
||||||
double time;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
struct CompressInfo
|
|
||||||
{
|
|
||||||
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
|
|
||||||
{
|
|
||||||
nrGroundVars = a;
|
|
||||||
nrGroundFactors = b;
|
|
||||||
nrClusterVars = c;
|
|
||||||
nrClusterFactors = d;
|
|
||||||
nrNeighborless = e;
|
|
||||||
}
|
|
||||||
unsigned nrGroundVars;
|
|
||||||
unsigned nrGroundFactors;
|
|
||||||
unsigned nrClusterVars;
|
|
||||||
unsigned nrClusterFactors;
|
|
||||||
unsigned nrNeighborless;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class Statistics
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
static unsigned getSolvedNetworksCounting (void);
|
|
||||||
|
|
||||||
static void incrementPrimaryNetworksCounting (void);
|
|
||||||
|
|
||||||
static unsigned getPrimaryNetworksCounting (void);
|
|
||||||
|
|
||||||
static void updateStatistics (unsigned, bool, unsigned, double);
|
|
||||||
|
|
||||||
static void printStatistics (void);
|
|
||||||
|
|
||||||
static void writeStatistics (const char*);
|
|
||||||
|
|
||||||
static void updateCompressingStatistics (
|
|
||||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
|
||||||
|
|
||||||
private:
|
|
||||||
static string getStatisticString (void);
|
|
||||||
|
|
||||||
static vector<NetInfo> netInfo_;
|
|
||||||
static vector<CompressInfo> compressInfo_;
|
|
||||||
static unsigned primaryNetCount_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_UTIL_H
|
|
||||||
|
|
@ -1,35 +0,0 @@
|
|||||||
|
|
||||||
if [ $1 ] && [ $1 == "clear" ]; then
|
|
||||||
rm *~
|
|
||||||
rm -f school/*.log school/*~
|
|
||||||
rm -f city/*.log city/*~
|
|
||||||
rm -f workshop_attrs/*.log workshop_attrs/*~
|
|
||||||
fi
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
constraint=$1
|
|
||||||
solver_flag=true
|
|
||||||
if [ -n "$2" ]; then
|
|
||||||
if [ $SOLVER = hve ]; then
|
|
||||||
extra_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
|
|
||||||
elif [ $SOLVER = bp ]; then
|
|
||||||
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
|
||||||
elif [ $SOLVER = cbp ]; then
|
|
||||||
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
|
||||||
else
|
|
||||||
echo "unknow flag $2"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $LOG_FILE -a -f "real:%E\tuser:%U\tsys:%S" \
|
|
||||||
$YAP << EOF >> $LOG_FILE 2>> ignore.$LOG_FILE
|
|
||||||
[$NETWORK].
|
|
||||||
[$constraint].
|
|
||||||
clpbn_horus:set_solver($SOLVER).
|
|
||||||
clpbn_horus:set_horus_flag(use_logarithms, true).
|
|
||||||
$solver_flag.
|
|
||||||
$QUERY.
|
|
||||||
open("$LOG_FILE", 'append', S), format(S, '$constraint: ~15+ ', []), close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
source city.sh
|
|
||||||
source ../benchs.sh
|
|
||||||
|
|
||||||
SOLVER="bp"
|
|
||||||
|
|
||||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
|
||||||
|
|
||||||
LOG_FILE=$SOLVER.log
|
|
||||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
|
||||||
|
|
||||||
rm -f $LOG_FILE
|
|
||||||
rm -f ignore.$LOG_FILE
|
|
||||||
|
|
||||||
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
source city.sh
|
|
||||||
source ../benchs.sh
|
|
||||||
|
|
||||||
SOLVER="cbp"
|
|
||||||
|
|
||||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
|
||||||
|
|
||||||
LOG_FILE=$SOLVER.log
|
|
||||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
|
||||||
|
|
||||||
rm -f $LOG_FILE
|
|
||||||
rm -f ignore.$LOG_FILE
|
|
||||||
|
|
||||||
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
NETWORK="'../../examples/city'"
|
|
||||||
SHORTNAME="city"
|
|
||||||
QUERY="is_joe_guilty(X)"
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
cp ~/bin/yap $YAP
|
|
||||||
echo -n "**********************************" >> $LOG_FILE
|
|
||||||
echo "**********************************" >> $LOG_FILE
|
|
||||||
echo "results for solver $1" >> $LOG_FILE
|
|
||||||
echo -n "**********************************" >> $LOG_FILE
|
|
||||||
echo "**********************************" >> $LOG_FILE
|
|
||||||
run_solver city_5 $2
|
|
||||||
#run_solver city_1000 $2
|
|
||||||
#run_solver city_5000 $2
|
|
||||||
#run_solver city_10000 $2
|
|
||||||
#run_solver city_50000 $2
|
|
||||||
#run_solver city_100000 $2
|
|
||||||
#run_solver city_500000 $2
|
|
||||||
#run_solver city_1000000 $2
|
|
||||||
}
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
source city.sh
|
|
||||||
source ../benchs.sh
|
|
||||||
|
|
||||||
SOLVER="fove"
|
|
||||||
|
|
||||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
|
||||||
|
|
||||||
LOG_FILE=$SOLVER.log
|
|
||||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
|
||||||
|
|
||||||
rm -f $LOG_FILE
|
|
||||||
rm -f ignore.$LOG_FILEE
|
|
||||||
|
|
||||||
run_all_graphs "fove "
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
source city.sh
|
|
||||||
source ../benchs.sh
|
|
||||||
|
|
||||||
SOLVER="hve"
|
|
||||||
|
|
||||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
|
||||||
|
|
||||||
LOG_FILE=$SOLVER.log
|
|
||||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
|
||||||
|
|
||||||
rm -f $LOG_FILE
|
|
||||||
rm -f ignore.$LOG_FILE
|
|
||||||
|
|
||||||
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
|
||||||
|
|
@ -1,11 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
source wa.sh
|
|
||||||
source ../benchs.sh
|
|
||||||
|
|
||||||
SOLVER="bp"
|
|
||||||
|
|
||||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
|
||||||
|
|
||||||
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
|
||||||
|
|
@ -1,11 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
source wa.sh
|
|
||||||
source ../benchs.sh
|
|
||||||
|
|
||||||
SOLVER="cbp"
|
|
||||||
|
|
||||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
|
||||||
|
|
||||||
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
|
||||||
|
|
@ -1,12 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
source wa.sh
|
|
||||||
source ../benchs.sh
|
|
||||||
|
|
||||||
SOLVER="fove"
|
|
||||||
|
|
||||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
|
||||||
|
|
||||||
run_all_graphs "fove "
|
|
||||||
|
|
||||||
|
|
@ -1,11 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
source wa.sh
|
|
||||||
source ../benchs.sh
|
|
||||||
|
|
||||||
SOLVER="hve"
|
|
||||||
|
|
||||||
YAP=~/bin/$SHORTNAME-$SOLVER
|
|
||||||
|
|
||||||
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
NETWORK="'../../examples/workshop_attrs'"
|
|
||||||
SHORTNAME="wa"
|
|
||||||
QUERY="series(X)"
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
LOG_FILE=$SOLVER.log
|
|
||||||
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
|
||||||
|
|
||||||
rm -f $LOG_FILE
|
|
||||||
rm -f ignore.$LOG_FILE
|
|
||||||
|
|
||||||
cp ~/bin/yap $YAP
|
|
||||||
|
|
||||||
echo -n "**********************************" >> $LOG_FILE
|
|
||||||
echo "**********************************" >> $LOG_FILE
|
|
||||||
echo "results for solver $1" >> $LOG_FILE
|
|
||||||
echo -n "**********************************" >> $LOG_FILE
|
|
||||||
echo "**********************************" >> $LOG_FILE
|
|
||||||
run_solver pop_10 $2
|
|
||||||
#run_solver pop_1000 $2
|
|
||||||
#run_solver pop_5000 $2
|
|
||||||
#run_solver pop_10000 $2
|
|
||||||
#run_solver pop_50000 $2
|
|
||||||
#run_solver pop_100000 $2
|
|
||||||
#run_solver pop_500000 $2
|
|
||||||
#run_solver pop_1000000 $2
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,41 +0,0 @@
|
|||||||
|
|
||||||
:- use_module(library(pfl)).
|
|
||||||
|
|
||||||
%:- set_pfl_flag(solver,ve).
|
|
||||||
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
|
|
||||||
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
|
|
||||||
%:- set_pfl_flag(solver,fove).
|
|
||||||
|
|
||||||
% :- yap_flag(write_strings, off).
|
|
||||||
|
|
||||||
|
|
||||||
bayes burglary::[b1,b3] ; [0.001, 0.999] ; [].
|
|
||||||
|
|
||||||
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
|
|
||||||
|
|
||||||
bayes alarm::[a1,a2] , burglary, earthquake ; [0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ; [].
|
|
||||||
|
|
||||||
bayes john_calls::[j1,j2] , alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
|
|
||||||
|
|
||||||
bayes mary_calls::[m1,m2] , alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
|
|
||||||
|
|
||||||
|
|
||||||
b_cpt([0.001, 0.999]).
|
|
||||||
|
|
||||||
e_cpt([0.002, 0.998]).
|
|
||||||
|
|
||||||
a_cpt([0.95, 0.94, 0.29, 0.001,
|
|
||||||
0.05, 0.06, 0.71, 0.999]).
|
|
||||||
|
|
||||||
jc_cpt([0.9, 0.05,
|
|
||||||
0.1, 0.95]).
|
|
||||||
|
|
||||||
mc_cpt([0.7, 0.01,
|
|
||||||
0.3, 0.99]).
|
|
||||||
|
|
||||||
% ?- alarm(A).
|
|
||||||
?- john_calls(J), mary_calls(m1).
|
|
||||||
%?- john_calls(J), mary_calls(m1), alarm(a1).
|
|
||||||
%?- john_calls(J), alarm(a1).
|
|
||||||
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
|||||||
# example in counting belief propagation paper
|
|
||||||
|
|
||||||
MARKOV
|
|
||||||
3
|
|
||||||
2 2 2
|
|
||||||
2
|
|
||||||
2 0 1
|
|
||||||
2 2 1
|
|
||||||
|
|
||||||
4
|
|
||||||
1.2 1.4 2.0 0.4
|
|
||||||
|
|
||||||
4
|
|
||||||
1.2 1.4 2.0 0.4
|
|
||||||
|
|
@ -1,31 +0,0 @@
|
|||||||
:- use_module(library(pfl)).
|
|
||||||
|
|
||||||
:- clpbn_horus:set_solver(fove).
|
|
||||||
%:- clpbn_horus:set_solver(hve).
|
|
||||||
%:- clpbn_horus:set_solver(bp).
|
|
||||||
%:- clpbn_horus:set_solver(cbp).
|
|
||||||
|
|
||||||
:- yap_flag(write_strings, off).
|
|
||||||
|
|
||||||
c(p1,w1).
|
|
||||||
c(p1,w2).
|
|
||||||
c(p1,w3).
|
|
||||||
c(p2,w1).
|
|
||||||
c(p2,w2).
|
|
||||||
c(p2,w3).
|
|
||||||
c(p3,w1).
|
|
||||||
c(p3,w2).
|
|
||||||
c(p3,w3).
|
|
||||||
c(p4,w1).
|
|
||||||
c(p4,w2).
|
|
||||||
c(p4,w3).
|
|
||||||
c(p5,w1).
|
|
||||||
c(p5,w2).
|
|
||||||
c(p5,w3).
|
|
||||||
|
|
||||||
markov attends(P)::[t,f] , hot(W)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P,W)].
|
|
||||||
|
|
||||||
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
|
|
||||||
|
|
||||||
% ?- series(X).
|
|
||||||
|
|
@ -1,24 +0,0 @@
|
|||||||
:- use_module(library(pfl)).
|
|
||||||
|
|
||||||
:- clpbn_horus:set_solver(fove).
|
|
||||||
%:- clpbn_horus:set_solver(hve).
|
|
||||||
%:- clpbn_horus:set_solver(bp).
|
|
||||||
%:- clpbn_horus:set_solver(cbp).
|
|
||||||
|
|
||||||
:- yap_flag(write_strings, off).
|
|
||||||
|
|
||||||
|
|
||||||
friends(P1, P2) :-
|
|
||||||
people(P1),
|
|
||||||
people(P2),
|
|
||||||
P1 \= P2.
|
|
||||||
|
|
||||||
people @ 3.
|
|
||||||
|
|
||||||
markov smokes(P)::[t,f], cancer(P)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [people(P)].
|
|
||||||
|
|
||||||
markov friend(P1,P2)::[t,f], smokes(P1)::[t,f], smokes(P2)::[t,f] ;
|
|
||||||
[0.5, 0.6, 0.7, 0.8, 0.5, 0.6, 0.7, 0.8] ; [friends(P1, P2)].
|
|
||||||
|
|
||||||
% ?- smokes(p1, t), smokes(p2, f), friend(p1, p2, X).
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
|||||||
:- use_module(library(pfl)).
|
|
||||||
|
|
||||||
%:- clpbn_horus:set_solver(fove).
|
|
||||||
%:- clpbn_horus:set_solver(hve).
|
|
||||||
:- clpbn_horus:set_solver(bp).
|
|
||||||
%:- clpbn_horus:set_solver(cbp).
|
|
||||||
|
|
||||||
:- yap_flag(write_strings, off).
|
|
||||||
|
|
||||||
people @ 3.
|
|
||||||
|
|
||||||
markov attends(P)::[t,f], attr1::[t,f] ; [0.11, 0.2, 0.3, 0.4] ; [people(P)].
|
|
||||||
|
|
||||||
markov attends(P)::[t,f], attr2::[t,f] ; [0.1, 0.22, 0.3, 0.4] ; [people(P)].
|
|
||||||
|
|
||||||
markov attends(P)::[t,f], attr3::[t,f] ; [0.1, 0.2, 0.33, 0.4] ; [people(P)].
|
|
||||||
|
|
||||||
markov attends(P)::[t,f], attr4::[t,f] ; [0.1, 0.2, 0.3, 0.44] ; [people(P)].
|
|
||||||
|
|
||||||
markov attends(P)::[t,f], attr5::[t,f] ; [0.1, 0.2, 0.3, 0.45] ; [people(P)].
|
|
||||||
|
|
||||||
markov attends(P)::[t,f], attr6::[t,f] ; [0.1, 0.2, 0.3, 0.46] ; [people(P)].
|
|
||||||
|
|
||||||
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.87] ; [people(P)].
|
|
||||||
|
|
||||||
% ?- series(X).
|
|
||||||
|
|
@ -41,20 +41,30 @@ do_network([], _, _, _) :- !.
|
|||||||
do_network(QueryVars, EVars, Keys, Factors) :-
|
do_network(QueryVars, EVars, Keys, Factors) :-
|
||||||
retractall(currently_defined(_)),
|
retractall(currently_defined(_)),
|
||||||
retractall(f(_,_,_,_)),
|
retractall(f(_,_,_,_)),
|
||||||
writeln(keys:Keys),
|
|
||||||
run_through_factors(QueryVars),
|
run_through_factors(QueryVars),
|
||||||
run_through_factors(EVars),
|
run_through_factors(EVars),
|
||||||
findall(K, currently_defined(K), Keys),
|
findall(K, currently_defined(K), Keys),
|
||||||
writeln(keys2:Keys),
|
ground_all_keys(QueryVars, Keys),
|
||||||
|
ground_all_keys(EVars, Keys),
|
||||||
findall(f(FType,FId,FKeys,FCPT), f(FType,FId,FKeys,FCPT), Factors).
|
findall(f(FType,FId,FKeys,FCPT), f(FType,FId,FKeys,FCPT), Factors).
|
||||||
|
|
||||||
match([], _Keys).
|
run_through_factors([]).
|
||||||
match([V|GVars], Keys) :-
|
run_through_factors([Var|_QueryVars]) :-
|
||||||
clpbn:get_atts(V,[key(GKey)]), !,
|
clpbn:get_atts(Var,[key(K)]),
|
||||||
member(GKey, Keys), ground(GKey),
|
find_factors(K),
|
||||||
match(GVars, Keys).
|
fail.
|
||||||
match([_V|GVars], Keys) :-
|
run_through_factors([_|QueryVars]) :-
|
||||||
match(GVars, Keys).
|
run_through_factors(QueryVars).
|
||||||
|
|
||||||
|
|
||||||
|
ground_all_keys([], _).
|
||||||
|
ground_all_keys([V|GVars], AllKeys) :-
|
||||||
|
clpbn:get_atts(V,[key(Key)]),
|
||||||
|
\+ ground(Key), !,
|
||||||
|
member(Key, AllKeys),
|
||||||
|
ground_all_keys(GVars, AllKeys).
|
||||||
|
ground_all_keys([_V|GVars], AllKeys) :-
|
||||||
|
ground_all_keys(GVars, AllKeys).
|
||||||
|
|
||||||
|
|
||||||
%
|
%
|
||||||
@ -99,6 +109,7 @@ keys([Var|QueryVars], [Key|QueryKeys]) :-
|
|||||||
initialize_evidence([]).
|
initialize_evidence([]).
|
||||||
initialize_evidence([V|EVars]) :-
|
initialize_evidence([V|EVars]) :-
|
||||||
clpbn:get_atts(V, [key(K)]),
|
clpbn:get_atts(V, [key(K)]),
|
||||||
|
ground(K),
|
||||||
assert(currently_defined(K)),
|
assert(currently_defined(K)),
|
||||||
initialize_evidence(EVars).
|
initialize_evidence(EVars).
|
||||||
|
|
||||||
@ -106,7 +117,7 @@ initialize_evidence([V|EVars]) :-
|
|||||||
% gets key K, and collects factors that define it
|
% gets key K, and collects factors that define it
|
||||||
find_factors(K) :-
|
find_factors(K) :-
|
||||||
\+ currently_defined(K),
|
\+ currently_defined(K),
|
||||||
assert(currently_defined(K)),
|
( ground(K) -> assert(currently_defined(K)) ; true),
|
||||||
defined_in_factor(K, ParFactor),
|
defined_in_factor(K, ParFactor),
|
||||||
add_factor(ParFactor, Ks),
|
add_factor(ParFactor, Ks),
|
||||||
member(K1, Ks),
|
member(K1, Ks),
|
||||||
|
@ -1,61 +1,63 @@
|
|||||||
|
|
||||||
/*******************************************************
|
/*******************************************************
|
||||||
|
|
||||||
Interface with C++
|
Horus Interface
|
||||||
|
|
||||||
********************************************************/
|
********************************************************/
|
||||||
|
|
||||||
:- module(clpbn_horus,
|
:- module(clpbn_horus,
|
||||||
[set_solver/1,
|
[set_solver/1,
|
||||||
create_lifted_network/3,
|
set_horus_flag/1,
|
||||||
create_ground_network/4,
|
cpp_create_lifted_network/3,
|
||||||
set_parfactors_params/2,
|
cpp_create_ground_network/4,
|
||||||
set_factors_params/2,
|
cpp_set_parfactors_params/2,
|
||||||
run_lifted_solver/3,
|
cpp_set_factors_params/2,
|
||||||
run_ground_solver/3,
|
cpp_run_lifted_solver/3,
|
||||||
set_vars_information/2,
|
cpp_run_ground_solver/3,
|
||||||
set_horus_flag/2,
|
cpp_set_vars_information/2,
|
||||||
free_parfactors/1,
|
cpp_set_horus_flag/2,
|
||||||
free_ground_network/1
|
cpp_free_parfactors/1,
|
||||||
|
cpp_free_ground_network/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
:- use_module(library(clpbn),
|
||||||
|
[set_clpbn_flag/2]).
|
||||||
|
|
||||||
|
|
||||||
patch_things_up :-
|
patch_things_up :-
|
||||||
assert_static(clpbn_horus:set_horus_flag(_,_)).
|
assert_static(clpbn_horus:cpp_set_horus_flag(_,_)).
|
||||||
|
|
||||||
|
|
||||||
warning :-
|
warning :-
|
||||||
format(user_error,"Horus library not installed: cannot use bp, fove~n.",[]).
|
format(user_error,"Horus library not installed: cannot use bp, fove~n.",[]).
|
||||||
|
|
||||||
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up) -> true ; warning.
|
|
||||||
|
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up)
|
||||||
|
-> true ; warning.
|
||||||
|
|
||||||
|
|
||||||
set_solver(ve) :- pfl:set_pfl_flag(solver,ve).
|
set_solver(ve) :- set_clpbn_flag(solver,ve).
|
||||||
set_solver(jt) :- pfl:set_pfl_flag(solver,jt).
|
set_solver(jt) :- set_clpbn_flag(solver,jt).
|
||||||
set_solver(gibbs) :- pfl:set_pfl_flag(solver,gibbs).
|
set_solver(gibbs) :- set_clpbn_flag(solver,gibbs).
|
||||||
set_solver(fove) :- pfl:set_pfl_flag(solver,fove).
|
set_solver(fove) :- set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, fove).
|
||||||
set_solver(hve) :- pfl:set_pfl_flag(solver,bp), set_horus_flag(inf_alg, ve).
|
set_solver(lbp) :- set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lbp).
|
||||||
set_solver(bp) :- pfl:set_pfl_flag(solver,bp), set_horus_flag(inf_alg, bp).
|
set_solver(hve) :- set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, ve).
|
||||||
set_solver(cbp) :- pfl:set_pfl_flag(solver,bp), set_horus_flag(inf_alg, cbp).
|
set_solver(bp) :- set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, bp).
|
||||||
|
set_solver(cbp) :- set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, cbp).
|
||||||
set_solver(S) :- throw(error('unknow solver ', S)).
|
set_solver(S) :- throw(error('unknow solver ', S)).
|
||||||
|
|
||||||
|
|
||||||
%:- set_horus_flag(inf_alg, ve).
|
set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).
|
||||||
%:- set_horus_flag(inf_alg, bp).
|
|
||||||
%: -set_horus_flag(inf_alg, cbp).
|
|
||||||
|
|
||||||
:- set_horus_flag(schedule, seq_fixed).
|
|
||||||
%:- set_horus_flag(schedule, seq_random).
|
|
||||||
%:- set_horus_flag(schedule, parallel).
|
|
||||||
%:- set_horus_flag(schedule, max_residual).
|
|
||||||
|
|
||||||
:- set_horus_flag(accuracy, 0.0001).
|
:- cpp_set_horus_flag(schedule, seq_fixed).
|
||||||
|
%:- cpp_set_horus_flag(schedule, seq_random).
|
||||||
|
%:- cpp_set_horus_flag(schedule, parallel).
|
||||||
|
%:- cpp_set_horus_flag(schedule, max_residual).
|
||||||
|
|
||||||
:- set_horus_flag(max_iter, 1000).
|
:- cpp_set_horus_flag(accuracy, 0.0001).
|
||||||
|
|
||||||
:- set_horus_flag(order_factor_variables, false).
|
:- cpp_set_horus_flag(max_iter, 1000).
|
||||||
%:- set_horus_flag(order_factor_variables, true).
|
|
||||||
|
|
||||||
:- set_horus_flag(use_logarithms, false).
|
:- cpp_set_horus_flag(use_logarithms, false).
|
||||||
% :- set_horus_flag(use_logarithms, true).
|
% :- cpp_set_horus_flag(use_logarithms, true).
|
||||||
|
|
||||||
|
@ -1,19 +1,27 @@
|
|||||||
|
|
||||||
/*******************************************************
|
/*******************************************************
|
||||||
|
|
||||||
Belief Propagation and Variable Elimination Interface
|
Interface to Horus Ground Solvers. Used by:
|
||||||
|
- Variable Elimination
|
||||||
|
- Belief Propagation
|
||||||
|
- Counting Belief Propagation
|
||||||
|
|
||||||
********************************************************/
|
********************************************************/
|
||||||
|
|
||||||
:- module(clpbn_bp,
|
:- module(clpbn_horus_ground,
|
||||||
[bp/3,
|
[call_horus_ground_solver/6,
|
||||||
check_if_bp_done/1,
|
check_if_horus_ground_solver_done/1,
|
||||||
init_bp_solver/4,
|
init_horus_ground_solver/4,
|
||||||
run_bp_solver/3,
|
run_horus_ground_solver/3,
|
||||||
call_bp_ground/6,
|
finalize_horus_ground_solver/1
|
||||||
finalize_bp_solver/1
|
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
:- use_module(horus,
|
||||||
|
[cpp_create_ground_network/4,
|
||||||
|
cpp_set_factors_params/2,
|
||||||
|
cpp_run_ground_solver/3,
|
||||||
|
cpp_set_vars_information/2,
|
||||||
|
cpp_free_ground_network/1
|
||||||
|
]).
|
||||||
|
|
||||||
:- use_module(library('clpbn/dists'),
|
:- use_module(library('clpbn/dists'),
|
||||||
[dist/4,
|
[dist/4,
|
||||||
@ -22,25 +30,20 @@
|
|||||||
get_dist_params/2
|
get_dist_params/2
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library('clpbn/display'),
|
:- use_module(library('clpbn/display'),
|
||||||
[clpbn_bind_vals/3]).
|
[clpbn_bind_vals/3]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library('clpbn/aggregates'),
|
:- use_module(library('clpbn/aggregates'),
|
||||||
[check_for_agg_vars/2]).
|
[check_for_agg_vars/2]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library(charsio),
|
:- use_module(library(charsio),
|
||||||
[term_to_atom/2]).
|
[term_to_atom/2]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library(pfl),
|
:- use_module(library(pfl),
|
||||||
[skolem/2,
|
[skolem/2,
|
||||||
get_pfl_parameters/2
|
get_pfl_parameters/2
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library(lists)).
|
:- use_module(library(lists)).
|
||||||
|
|
||||||
:- use_module(library(atts)).
|
:- use_module(library(atts)).
|
||||||
@ -48,45 +51,36 @@
|
|||||||
:- use_module(library(bhash)).
|
:- use_module(library(bhash)).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(horus,
|
call_horus_ground_solver(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
||||||
[create_ground_network/4,
|
|
||||||
set_factors_params/2,
|
|
||||||
run_ground_solver/3,
|
|
||||||
set_vars_information/2,
|
|
||||||
free_ground_network/1
|
|
||||||
]).
|
|
||||||
|
|
||||||
|
|
||||||
call_bp_ground(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
|
||||||
writeln(here:Factors),
|
|
||||||
b_hash_new(Hash0),
|
b_hash_new(Hash0),
|
||||||
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
||||||
get_factors_type(Factors, Type),
|
get_factors_type(Factors, Type),
|
||||||
evidence_to_ids(Evidence, Hash, EvidenceIds),
|
evidence_to_ids(Evidence, Hash, EvidenceIds),
|
||||||
factors_to_ids(Factors, Hash, FactorIds),
|
factors_to_ids(Factors, Hash, FactorIds),
|
||||||
writeln(type:Type), writeln(''),
|
%writeln(type:Type), writeln(''),
|
||||||
writeln(allKeys:AllKeys), writeln(''),
|
%writeln(allKeys:AllKeys), writeln(''),
|
||||||
writeln(factors:Factors), writeln(''),
|
%sort(AllKeys,SKeys),writeln(allKeys:SKeys), writeln(''),
|
||||||
writeln(factorIds:FactorIds), writeln(''),
|
%writeln(factors:Factors), writeln(''),
|
||||||
writeln(evidence:Evidence), writeln(''),
|
%writeln(factorIds:FactorIds), writeln(''),
|
||||||
writeln(evidenceIds:EvidenceIds), writeln(''),
|
%writeln(evidence:Evidence), writeln(''),
|
||||||
create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
%writeln(evidenceIds:EvidenceIds), writeln(''),
|
||||||
|
cpp_create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
||||||
%get_vars_information(AllKeys, StatesNames),
|
%get_vars_information(AllKeys, StatesNames),
|
||||||
%set_vars_information(AllKeys, StatesNames),
|
%terms_to_atoms(AllKeys, KeysAtoms),
|
||||||
|
%cpp_set_vars_information(KeysAtoms, StatesNames),
|
||||||
run_solver(ground(Network,Hash), QueryKeys, Solutions),
|
run_solver(ground(Network,Hash), QueryKeys, Solutions),
|
||||||
writeln(answer:Solutions),
|
|
||||||
clpbn_bind_vals([QueryVars], Solutions, Output),
|
clpbn_bind_vals([QueryVars], Solutions, Output),
|
||||||
free_ground_network(Network).
|
cpp_free_ground_network(Network).
|
||||||
|
|
||||||
|
|
||||||
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||||
%get_dists_parameters(DistIds, DistsParams),
|
%get_dists_parameters(DistIds, DistsParams),
|
||||||
%set_factors_params(Network, DistsParams),
|
%cpp_set_factors_params(Network, DistsParams),
|
||||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||||
writeln(queryKeys:QueryKeys), writeln(''),
|
%writeln(queryKeys:QueryKeys), writeln(''),
|
||||||
writeln(queryIds:QueryIds), writeln(''),
|
%writeln(queryIds:QueryIds), writeln(''),
|
||||||
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||||
run_ground_solver(Network, [QueryIds], Solutions).
|
cpp_run_ground_solver(Network, [QueryIds], Solutions).
|
||||||
|
|
||||||
|
|
||||||
keys_to_ids([], _, Hash, Hash).
|
keys_to_ids([], _, Hash, Hash).
|
||||||
@ -132,31 +126,40 @@ get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
|
|||||||
get_vars_information(QueryKeys, StatesNames).
|
get_vars_information(QueryKeys, StatesNames).
|
||||||
|
|
||||||
|
|
||||||
finalize_bp_solver(bp(Network, _)) :-
|
terms_to_atoms([], []).
|
||||||
free_ground_network(Network).
|
terms_to_atoms(K.Ks, Atom.As) :-
|
||||||
|
term_to_atom(K,Atom),
|
||||||
|
terms_to_atoms(Ks,As).
|
||||||
|
|
||||||
|
|
||||||
bp([[]],_,_) :- !.
|
finalize_horus_ground_solver(bp(Network, _)) :-
|
||||||
bp([QueryVars], AllVars, Output) :-
|
cpp_free_ground_network(Network).
|
||||||
init_bp_solver(_, AllVars, _, Network),
|
|
||||||
run_bp_solver([QueryVars], LPs, Network),
|
|
||||||
finalize_bp_solver(Network),
|
|
||||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
|
||||||
|
|
||||||
|
|
||||||
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
init_horus_ground_solver(_, _AllVars0, _, bp(_BayesNet, _DistIds)) :- !.
|
||||||
%check_for_agg_vars(AllVars0, AllVars),
|
|
||||||
get_vars_info(AllVars0, VarsInfo, DistIds0),
|
|
||||||
sort(DistIds0, DistIds),
|
|
||||||
create_ground_network(VarsInfo, BayesNet),
|
|
||||||
true.
|
|
||||||
|
|
||||||
|
run_horus_ground_solver(_QueryVars, _Solutions, bp(_Network, _DistIds)) :- !.
|
||||||
|
|
||||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
%bp([[]],_,_) :- !.
|
||||||
get_dists_parameters(DistIds, DistsParams),
|
%bp([QueryVars], AllVars, Output) :-
|
||||||
set_factors_params(Network, DistsParams),
|
% init_horus_ground_solver(_, AllVars, _, Network),
|
||||||
vars_to_ids(QueryVars, QueryVarsIds),
|
% run_horus_ground_solver([QueryVars], LPs, Network),
|
||||||
run_ground_solver(Network, QueryVarsIds, Solutions).
|
% finalize_horus_ground_solver(Network),
|
||||||
|
% clpbn_bind_vals([QueryVars], LPs, Output).
|
||||||
|
%
|
||||||
|
%init_horus_ground_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||||
|
% %check_for_agg_vars(AllVars0, AllVars),
|
||||||
|
% get_vars_info(AllVars0, VarsInfo, DistIds0),
|
||||||
|
% sort(DistIds0, DistIds),
|
||||||
|
% cpp_create_ground_network(VarsInfo, BayesNet),
|
||||||
|
% true.
|
||||||
|
%
|
||||||
|
%
|
||||||
|
%run_horus_ground_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||||
|
% get_dists_parameters(DistIds, DistsParams),
|
||||||
|
% cpp_set_factors_params(Network, DistsParams),
|
||||||
|
% vars_to_ids(QueryVars, QueryVarsIds),
|
||||||
|
% cpp_run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||||
|
|
||||||
|
|
||||||
get_dists_parameters([],[]).
|
get_dists_parameters([],[]).
|
@ -1,27 +1,31 @@
|
|||||||
|
|
||||||
/*******************************************************
|
/*******************************************************
|
||||||
|
|
||||||
First Order Variable Elimination Interface
|
Interface to Horus Lifted Solvers. Used by:
|
||||||
|
- Lifted Variable Elimination
|
||||||
|
|
||||||
********************************************************/
|
********************************************************/
|
||||||
|
|
||||||
:- module(clpbn_fove,
|
:- module(clpbn_horus_lifted,
|
||||||
[fove/3,
|
[call_horus_lifted_solver/3,
|
||||||
check_if_fove_done/1,
|
check_if_horus_lifted_solver_done/1,
|
||||||
init_fove_solver/4,
|
init_horus_lifted_solver/4,
|
||||||
run_fove_solver/3,
|
run_horus_lifted_solver/3,
|
||||||
finalize_fove_solver/1
|
finalize_horus_lifted_solver/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
:- use_module(horus,
|
||||||
|
[cpp_create_lifted_network/3,
|
||||||
|
cpp_set_parfactors_params/2,
|
||||||
|
cpp_run_lifted_solver/3,
|
||||||
|
cpp_free_parfactors/1
|
||||||
|
]).
|
||||||
|
|
||||||
:- use_module(library('clpbn/display'),
|
:- use_module(library('clpbn/display'),
|
||||||
[clpbn_bind_vals/3]).
|
[clpbn_bind_vals/3]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library('clpbn/dists'),
|
:- use_module(library('clpbn/dists'),
|
||||||
[get_dist_params/2]).
|
[get_dist_params/2]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library(pfl),
|
:- use_module(library(pfl),
|
||||||
[factor/6,
|
[factor/6,
|
||||||
skolem/2,
|
skolem/2,
|
||||||
@ -29,30 +33,22 @@
|
|||||||
]).
|
]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(horus,
|
call_horus_lifted_solver([[]], _, _) :- !.
|
||||||
[create_lifted_network/3,
|
call_horus_lifted_solver([QueryVars], AllVars, Output) :-
|
||||||
set_parfactors_params/2,
|
init_horus_lifted_solver(_, AllVars, _, ParfactorList),
|
||||||
run_lifted_solver/3,
|
run_horus_lifted_solver([QueryVars], LPs, ParfactorList),
|
||||||
free_parfactors/1
|
finalize_horus_lifted_solver(ParfactorList),
|
||||||
]).
|
|
||||||
|
|
||||||
|
|
||||||
fove([[]], _, _) :- !.
|
|
||||||
fove([QueryVars], AllVars, Output) :-
|
|
||||||
init_fove_solver(_, AllVars, _, ParfactorList),
|
|
||||||
run_fove_solver([QueryVars], LPs, ParfactorList),
|
|
||||||
finalize_fove_solver(ParfactorList),
|
|
||||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||||
|
|
||||||
|
|
||||||
init_fove_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
|
init_horus_lifted_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
|
||||||
get_parfactors(Parfactors),
|
get_parfactors(Parfactors),
|
||||||
get_dist_ids(Parfactors, DistIds0),
|
get_dist_ids(Parfactors, DistIds0),
|
||||||
sort(DistIds0, DistIds),
|
sort(DistIds0, DistIds),
|
||||||
get_observed_vars(AllAttVars, ObservedVars),
|
get_observed_vars(AllAttVars, ObservedVars),
|
||||||
writeln(parfactors:Parfactors:'\n'),
|
%writeln(parfactors:Parfactors:'\n'),
|
||||||
writeln(evidence:ObservedVars:'\n'),
|
%writeln(evidence:ObservedVars:'\n'),
|
||||||
create_lifted_network(Parfactors,ObservedVars,ParfactorList).
|
cpp_create_lifted_network(Parfactors,ObservedVars,ParfactorList).
|
||||||
|
|
||||||
|
|
||||||
:- table get_parfactors/1.
|
:- table get_parfactors/1.
|
||||||
@ -138,15 +134,15 @@ get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
|||||||
get_dists_parameters(Ids, DistsInfo).
|
get_dists_parameters(Ids, DistsInfo).
|
||||||
|
|
||||||
|
|
||||||
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
|
run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
|
||||||
get_query_vars(QueryVarsAtts, QueryVars),
|
get_query_vars(QueryVarsAtts, QueryVars),
|
||||||
writeln(queryVars:QueryVars), writeln(''),
|
%writeln(queryVars:QueryVars), writeln(''),
|
||||||
get_dists_parameters(DistIds, DistsParams),
|
get_dists_parameters(DistIds, DistsParams),
|
||||||
writeln(dists:DistsParams), writeln(''),
|
%writeln(dists:DistsParams), writeln(''),
|
||||||
set_parfactors_params(ParfactorList, DistsParams),
|
cpp_set_parfactors_params(ParfactorList, DistsParams),
|
||||||
run_lifted_solver(ParfactorList, QueryVars, Solutions).
|
cpp_run_lifted_solver(ParfactorList, QueryVars, Solutions).
|
||||||
|
|
||||||
|
|
||||||
finalize_fove_solver(fove(ParfactorList, _)) :-
|
finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
|
||||||
free_parfactors(ParfactorList).
|
cpp_free_parfactors(ParfactorList).
|
||||||
|
|
@ -1,25 +0,0 @@
|
|||||||
|
|
||||||
%conservative_city(nyc, t).
|
|
||||||
|
|
||||||
hair_color(joe, t).
|
|
||||||
|
|
||||||
car_color(joe, t).
|
|
||||||
|
|
||||||
shoe_size(joe, f).
|
|
||||||
|
|
||||||
/* Steps:
|
|
||||||
|
|
||||||
1. generate N facts lives(I, nyc), 0 <= I < N.
|
|
||||||
|
|
||||||
2. generate evidence on descn for N people, *** except for 1 ***
|
|
||||||
|
|
||||||
3. Run query ?- guilty(joe, Guilty), witness(joe, t), descn(2,t), descn(3, f), descn(4, f).
|
|
||||||
|
|
||||||
query(Guilty) :-
|
|
||||||
guilty(joe, Guilty),
|
|
||||||
witness(joe, t),
|
|
||||||
descn(2,t),
|
|
||||||
descn(3,f),
|
|
||||||
descn(4,f), ....
|
|
||||||
|
|
||||||
*/
|
|
@ -1,60 +0,0 @@
|
|||||||
|
|
||||||
/* base file for school database. Supposed to be called from school_*.yap */
|
|
||||||
|
|
||||||
conservative_city(City, Cons) :-
|
|
||||||
cons_table(City, ConsDist),
|
|
||||||
{ Cons = cons(City) with p([y,n], ConsDist) }.
|
|
||||||
|
|
||||||
gender(X, Gender) :-
|
|
||||||
gender_table(City, GenderDist),
|
|
||||||
{ Gender = gender(City) with p([m,f], GenderDist) }.
|
|
||||||
|
|
||||||
hair_color(X, Color) :-
|
|
||||||
lives(X, City),
|
|
||||||
conservative_city(City, Cons),
|
|
||||||
gender(X, Gender),
|
|
||||||
color_table(X,ColorTable),
|
|
||||||
{ Color = color(X) with
|
|
||||||
p([t,f], ColorTable,[Gender,Cons]) }.
|
|
||||||
|
|
||||||
car_color(X, Color) :-
|
|
||||||
hair_color(City, HColor),
|
|
||||||
ccolor_table(X,CColorTable),
|
|
||||||
{ Color = ccolor(X) with
|
|
||||||
p([t,f], CColorTable,[HColor]) }.
|
|
||||||
|
|
||||||
height(X, Height) :-
|
|
||||||
gender(X, Gender),
|
|
||||||
height_table(X,HeightTable),
|
|
||||||
{ Height = height(X) with
|
|
||||||
p([t,f], HeightTable,[Gender]) }.
|
|
||||||
|
|
||||||
shoe_size(X, Shoesize) :-
|
|
||||||
height(X, Height),
|
|
||||||
shoe_size_table(X,ShoesizeTable),
|
|
||||||
{ Shoesize = shoe_size(X) with
|
|
||||||
p([t,f], ShoesizeTable,[Height]) }.
|
|
||||||
|
|
||||||
guilty(X, Guilt) :-
|
|
||||||
guilt_table(X, GuiltDist),
|
|
||||||
{ Guilt = guilt(X) with p([y,n], GuiltDist) }.
|
|
||||||
|
|
||||||
|
|
||||||
descn(X, Descn) :-
|
|
||||||
car_color(X, Car),
|
|
||||||
hair_color(X, Hair),
|
|
||||||
height(X, Height),
|
|
||||||
guilty(X, Guilt),
|
|
||||||
descn_table(X, DescTable),
|
|
||||||
{ Descn = descn(X) with
|
|
||||||
p([t,f], DescTable,[Car,Hair,Height,Guilt]) }.
|
|
||||||
|
|
||||||
witness(City, Witness) :-
|
|
||||||
descn(joe, DescnJ),
|
|
||||||
descn(1, Descn1),
|
|
||||||
wit_table(WitTable),
|
|
||||||
{ Witness = wit(City) with
|
|
||||||
p([t,f], WitTable,[DescnJ, Descn1]) }.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,35 +0,0 @@
|
|||||||
|
|
||||||
cons_table(amsterdam,[0.2,
|
|
||||||
0.8]) :- !.
|
|
||||||
cons_table(_, [0.8,
|
|
||||||
0.2]).
|
|
||||||
|
|
||||||
color_table(_,
|
|
||||||
/* tm tf fm ff */
|
|
||||||
[ 0.05, 0.1, 0.3, 0.5 ,
|
|
||||||
0.95, 0.9, 0.7, 0.5 ]).
|
|
||||||
|
|
||||||
ccolor_table(_,
|
|
||||||
/* t f */
|
|
||||||
[ 0.9, 0.2 ,
|
|
||||||
0.1, 0.8 ]).
|
|
||||||
|
|
||||||
height_table(_,
|
|
||||||
/* m f */
|
|
||||||
[ 0.6, 0.4 ,
|
|
||||||
0.4, 0.6 ]).
|
|
||||||
|
|
||||||
shoe_size_table(_,
|
|
||||||
/* t f */
|
|
||||||
[ 0.9, 0.1 ,
|
|
||||||
0.1, 0.9 ]).
|
|
||||||
|
|
||||||
A: professor's ability;
|
|
||||||
B: student's grade (for course registration).
|
|
||||||
*/
|
|
||||||
descn_table(_,
|
|
||||||
/* color, hair, height, guilt */
|
|
||||||
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
|
|
||||||
/*t*/ [0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99,
|
|
||||||
/*f*/ 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01 ]).
|
|
||||||
|
|
23
packages/CLPBN/examples/burglary-alarm.yap
Normal file
23
packages/CLPBN/examples/burglary-alarm.yap
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
%:- set_solver(fove).
|
||||||
|
%:- set_solver(hve).
|
||||||
|
%:- set_solver(bp).
|
||||||
|
%:- set_solver(cbp).
|
||||||
|
|
||||||
|
:- yap_flag(write_strings, off).
|
||||||
|
|
||||||
|
bayes burglary::[b1,b2] ; [0.001, 0.999] ; [].
|
||||||
|
|
||||||
|
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
|
||||||
|
|
||||||
|
bayes alarm::[a1,a2], burglary, earthquake ;
|
||||||
|
[0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ;
|
||||||
|
[].
|
||||||
|
|
||||||
|
bayes john_calls::[j1,j2], alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
|
||||||
|
|
||||||
|
bayes mary_calls::[m1,m2], alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
|
||||||
|
|
||||||
|
% ?- john_calls(J), mary_calls(m1).
|
||||||
|
|
@ -1,20 +1,23 @@
|
|||||||
:- use_module(library(pfl)).
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
:- clpbn_horus:set_solver(fove).
|
%:- set_solver(fove).
|
||||||
%:- clpbn_horus:set_solver(hve).
|
%:- set_solver(hve).
|
||||||
:- clpbn_horus:set_solver(bp).
|
%:- set_solver(bp).
|
||||||
%:- clpbn_horus:set_solver(cbp).
|
%:- set_solver(cbp).
|
||||||
|
|
||||||
|
:- multifile people/2.
|
||||||
|
:- multifile ev/1.
|
||||||
|
|
||||||
people(joe,nyc).
|
people(joe,nyc).
|
||||||
people(p2, nyc).
|
people(p2, nyc).
|
||||||
people(p3, nyc).
|
people(p3, nyc).
|
||||||
|
people(p4, nyc).
|
||||||
|
people(p5, nyc).
|
||||||
|
|
||||||
ev(descn(p2, t)).
|
ev(descn(p2, t)).
|
||||||
ev(descn(p3, t)).
|
ev(descn(p3, t)).
|
||||||
|
ev(descn(p4, t)).
|
||||||
% :- [city_7].
|
ev(descn(p5, t)).
|
||||||
|
|
||||||
bayes city_conservativeness(C)::[y,n] ; cons_table(C) ; [people(_,C)].
|
bayes city_conservativeness(C)::[y,n] ; cons_table(C) ; [people(_,C)].
|
||||||
|
|
||||||
@ -34,12 +37,12 @@ bayes descn(P)::[t,f], car_color(P), hair_color(P), height(P), guilty(P) ; descn
|
|||||||
|
|
||||||
bayes witness(C)::[t,f], descn(Joe), descn(P2) ; wit_table ; [people(_,C), Joe=joe, P2=p2].
|
bayes witness(C)::[t,f], descn(Joe), descn(P2) ; wit_table ; [people(_,C), Joe=joe, P2=p2].
|
||||||
|
|
||||||
|
% FIXME
|
||||||
cons_table(amsterdam, [0.2, 0.8]) :- !.
|
%cons_table(amsterdam, [0.2, 0.8]) :- !.
|
||||||
cons_table(_, [0.8, 0.2]).
|
cons_table(_, [0.8, 0.2]).
|
||||||
|
|
||||||
|
|
||||||
gender_table(_, [0.55, 0.44]).
|
gender_table(_, [0.55, 0.45]).
|
||||||
|
|
||||||
|
|
||||||
hair_color_table(_,
|
hair_color_table(_,
|
||||||
@ -73,8 +76,8 @@ guilty_table(_, [0.23, 0.77]).
|
|||||||
descn_table(_,
|
descn_table(_,
|
||||||
/* color, hair, height, guilt */
|
/* color, hair, height, guilt */
|
||||||
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
|
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
|
||||||
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
|
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.23, 0.91, 0.95, 0.92,
|
||||||
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
|
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.71, 0.28, 0.77, 0.09, 0.05, 0.08]).
|
||||||
|
|
||||||
|
|
||||||
wit_table([0.2, 0.45, 0.24, 0.34,
|
wit_table([0.2, 0.45, 0.24, 0.34,
|
33
packages/CLPBN/examples/comp_workshops.yap
Normal file
33
packages/CLPBN/examples/comp_workshops.yap
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
%:- set_solver(fove).
|
||||||
|
%:- set_solver(hve).
|
||||||
|
%:- set_solver(bp).
|
||||||
|
%:- set_solver(cbp).
|
||||||
|
|
||||||
|
:- yap_flag(write_strings, off).
|
||||||
|
|
||||||
|
:- multifile c/2.
|
||||||
|
|
||||||
|
c(p1,w1).
|
||||||
|
c(p1,w2).
|
||||||
|
c(p1,w3).
|
||||||
|
c(p2,w1).
|
||||||
|
c(p2,w2).
|
||||||
|
c(p2,w3).
|
||||||
|
c(p3,w1).
|
||||||
|
c(p3,w2).
|
||||||
|
c(p3,w3).
|
||||||
|
c(p4,w1).
|
||||||
|
c(p4,w2).
|
||||||
|
c(p4,w3).
|
||||||
|
c(p5,w1).
|
||||||
|
c(p5,w2).
|
||||||
|
c(p5,w3).
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], hot(W)::[t,f] ; [0.2, 0.8, 0.8, 0.8] ; [c(P,W)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], series::[t,f] ; [0.501, 0.499, 0.499, 0.499] ; [c(P,_)].
|
||||||
|
|
||||||
|
% ?- series(X).
|
||||||
|
|
21
packages/CLPBN/examples/fail2.yap
Normal file
21
packages/CLPBN/examples/fail2.yap
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
:- set_solver(fove).
|
||||||
|
%:- set_solver(hve).
|
||||||
|
%:- set_solver(bp).
|
||||||
|
%:- set_solver(cbp).
|
||||||
|
|
||||||
|
:- yap_flag(write_strings, off).
|
||||||
|
|
||||||
|
:- clpbn_horus:set_horus_flag(verbosity,5).
|
||||||
|
|
||||||
|
people(p1,p1).
|
||||||
|
people(p1,p2).
|
||||||
|
people(p2,p1).
|
||||||
|
people(p2,p2).
|
||||||
|
|
||||||
|
markov p(A,A)::[t,f] ; [1.0,4.5] ; [people(A,_)].
|
||||||
|
|
||||||
|
markov p(A,B)::[t,f] ; [1.0,4.5] ; [people(A,B)].
|
||||||
|
|
||||||
|
?- p(p1,p1,X).
|
31
packages/CLPBN/examples/social_domain1.yap
Normal file
31
packages/CLPBN/examples/social_domain1.yap
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
%:- set_solver(fove).
|
||||||
|
%:- set_solver(hve).
|
||||||
|
%:- set_solver(bp).
|
||||||
|
%:- set_solver(cbp).
|
||||||
|
|
||||||
|
:- yap_flag(write_strings, off).
|
||||||
|
|
||||||
|
:- multifile people/1.
|
||||||
|
|
||||||
|
people @ 5.
|
||||||
|
|
||||||
|
people(X,Y) :-
|
||||||
|
people(X),
|
||||||
|
people(Y),
|
||||||
|
X \== Y.
|
||||||
|
|
||||||
|
markov smokes(X)::[t,f]; [1.0, 4.0552]; [people(X)].
|
||||||
|
|
||||||
|
markov cancer(X)::[t,f]; [1.0, 9.9742]; [people(X)].
|
||||||
|
|
||||||
|
markov friends(X,Y)::[t,f] ; [1.0, 99.48432] ; [people(X,Y)].
|
||||||
|
|
||||||
|
markov smokes(X)::[t,f], cancer(X)::[t,f] ; [4.48169, 4.48169, 1.0, 4.48169] ; [people(X)].
|
||||||
|
|
||||||
|
markov friends(X,Y)::[t,f], smokes(X)::[t,f], smokes(Y)::[t,f] ;
|
||||||
|
[3.004166, 3.004166, 3.004166, 3.004166, 3.004166, 1.0, 1.0, 3.004166] ; [people(X,Y)].
|
||||||
|
|
||||||
|
% ?- friends(p1,p2,X).
|
||||||
|
|
31
packages/CLPBN/examples/social_domain2.yap
Normal file
31
packages/CLPBN/examples/social_domain2.yap
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
%:- set_solver(fove).
|
||||||
|
%:- set_solver(hve).
|
||||||
|
%:- set_solver(bp).
|
||||||
|
%:- set_solver(cbp).
|
||||||
|
|
||||||
|
:- yap_flag(write_strings, off).
|
||||||
|
|
||||||
|
:- multifile people/1.
|
||||||
|
|
||||||
|
people @ 5.
|
||||||
|
|
||||||
|
people(X,Y) :-
|
||||||
|
people(X),
|
||||||
|
people(Y).
|
||||||
|
% X \== Y.
|
||||||
|
|
||||||
|
markov smokes(X)::[t,f]; [1.0, 4.0552]; [people(X)].
|
||||||
|
|
||||||
|
markov asthma(X)::[t,f]; [1.0, 9.9742] ; [people(X)].
|
||||||
|
|
||||||
|
markov friends(X,Y)::[t,f]; [1.0, 99.48432] ; [people(X,Y)].
|
||||||
|
|
||||||
|
markov asthma(X)::[t,f], smokes(X)::[t,f]; [4.48169, 4.48169, 1.0, 4.48169] ; [people(X)].
|
||||||
|
|
||||||
|
markov asthma(X)::[t,f], friends(X,Y)::[t,f], smokes(Y)::[t,f];
|
||||||
|
[3.004166, 3.004166, 3.004166, 3.004166, 3.004166, 1.0, 1.0, 3.004166] ; [people(X,Y)].
|
||||||
|
|
||||||
|
% ?- smokes(p1,t), smokes(p2,t), friends(p1,p2,X)
|
||||||
|
|
29
packages/CLPBN/examples/workshop_attrs.yap
Normal file
29
packages/CLPBN/examples/workshop_attrs.yap
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
:- use_module(library(pfl)).
|
||||||
|
|
||||||
|
%:- set_solver(fove).
|
||||||
|
%:- set_solver(hve).
|
||||||
|
%:- set_solver(bp).
|
||||||
|
%:- set_solver(cbp).
|
||||||
|
|
||||||
|
:- yap_flag(write_strings, off).
|
||||||
|
|
||||||
|
:- multifile people/1.
|
||||||
|
|
||||||
|
people @ 5.
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], attr1::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], attr2::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], attr3::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], attr4::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], attr5::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], attr6::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
|
||||||
|
|
||||||
|
markov attends(P)::[t,f], series::[t,f] ; [0.501, 0.499, 0.499, 0.499] ; [people(P)].
|
||||||
|
|
||||||
|
% ?- series(X).
|
||||||
|
|
@ -9,14 +9,12 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
FactorGraph*
|
||||||
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||||
{
|
{
|
||||||
assert (fg_.isFromBayesNetwork());
|
assert (fg_.bayesianFactors());
|
||||||
|
|
||||||
Scheduling scheduling;
|
Scheduling scheduling;
|
||||||
for (unsigned i = 0; i < queryIds.size(); i++) {
|
for (size_t i = 0; i < queryIds.size(); i++) {
|
||||||
assert (dag_.getNode (queryIds[i]));
|
assert (dag_.getNode (queryIds[i]));
|
||||||
DAGraphNode* n = dag_.getNode (queryIds[i]);
|
DAGraphNode* n = dag_.getNode (queryIds[i]);
|
||||||
scheduling.push (ScheduleInfo (n, false, true));
|
scheduling.push (ScheduleInfo (n, false, true));
|
||||||
@ -60,11 +58,11 @@ void
|
|||||||
BayesBall::constructGraph (FactorGraph* fg) const
|
BayesBall::constructGraph (FactorGraph* fg) const
|
||||||
{
|
{
|
||||||
const FacNodes& facNodes = fg_.facNodes();
|
const FacNodes& facNodes = fg_.facNodes();
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
const DAGraphNode* n = dag_.getNode (
|
const DAGraphNode* n = dag_.getNode (
|
||||||
facNodes[i]->factor().argument (0));
|
facNodes[i]->factor().argument (0));
|
||||||
if (n->isMarkedOnTop()) {
|
if (n->isMarkedOnTop()) {
|
||||||
fg->addFactor (Factor (facNodes[i]->factor()));
|
fg->addFactor (facNodes[i]->factor());
|
||||||
} else if (n->hasEvidence() && n->isVisited()) {
|
} else if (n->hasEvidence() && n->isVisited()) {
|
||||||
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||||
Ranges ranges = { facNodes[i]->factor().range (0) };
|
Ranges ranges = { facNodes[i]->factor().range (0) };
|
||||||
@ -73,5 +71,14 @@ BayesBall::constructGraph (FactorGraph* fg) const
|
|||||||
fg->addFactor (Factor (varIds, ranges, params));
|
fg->addFactor (Factor (varIds, ranges, params));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const VarNodes& varNodes = fg_.varNodes();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
if (varNodes[i]->hasEvidence()) {
|
||||||
|
VarNode* vn = fg->getVarNode (varNodes[i]->varId());
|
||||||
|
if (vn) {
|
||||||
|
vn->setEvidence (varNodes[i]->getEvidence());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -64,7 +64,7 @@ BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
|||||||
{
|
{
|
||||||
const vector<DAGraphNode*>& ps = n->parents();
|
const vector<DAGraphNode*>& ps = n->parents();
|
||||||
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
|
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
|
||||||
it != ps.end(); it++) {
|
it != ps.end(); ++it) {
|
||||||
sch.push (ScheduleInfo (*it, false, true));
|
sch.push (ScheduleInfo (*it, false, true));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -76,7 +76,7 @@ BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
|
|||||||
{
|
{
|
||||||
const vector<DAGraphNode*>& cs = n->childs();
|
const vector<DAGraphNode*>& cs = n->childs();
|
||||||
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
|
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
|
||||||
it != cs.end(); it++) {
|
it != cs.end(); ++it) {
|
||||||
sch.push (ScheduleInfo (*it, true, false));
|
sch.push (ScheduleInfo (*it, true, false));
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -57,7 +57,7 @@ DAGraph::getNode (VarId vid)
|
|||||||
void
|
void
|
||||||
DAGraph::setIndexes (void)
|
DAGraph::setIndexes (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
nodes_[i]->setIndex (i);
|
nodes_[i]->setIndex (i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -67,7 +67,7 @@ DAGraph::setIndexes (void)
|
|||||||
void
|
void
|
||||||
DAGraph::clear (void)
|
DAGraph::clear (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
nodes_[i]->clear();
|
nodes_[i]->clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -85,7 +85,7 @@ DAGraph::exportToGraphViz (const char* fileName)
|
|||||||
}
|
}
|
||||||
out << "digraph {" << endl;
|
out << "digraph {" << endl;
|
||||||
out << "ranksep=1" << endl;
|
out << "ranksep=1" << endl;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
out << nodes_[i]->varId() ;
|
out << nodes_[i]->varId() ;
|
||||||
out << " [" ;
|
out << " [" ;
|
||||||
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||||
@ -94,9 +94,9 @@ DAGraph::exportToGraphViz (const char* fileName)
|
|||||||
}
|
}
|
||||||
out << "]" << endl;
|
out << "]" << endl;
|
||||||
}
|
}
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||||
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
for (size_t j = 0; j < childs.size(); j++) {
|
||||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||||
out << " [style=bold]" << endl ;
|
out << " [style=bold]" << endl ;
|
||||||
}
|
}
|
502
packages/CLPBN/horus/BpSolver.cpp
Normal file
502
packages/CLPBN/horus/BpSolver.cpp
Normal file
@ -0,0 +1,502 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "BpSolver.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "Indexer.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||||
|
{
|
||||||
|
runned_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
BpSolver::~BpSolver (void)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < varsI_.size(); i++) {
|
||||||
|
delete varsI_[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facsI_.size(); i++) {
|
||||||
|
delete facsI_[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
delete links_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BpSolver::solveQuery (VarIds queryVids)
|
||||||
|
{
|
||||||
|
assert (queryVids.empty() == false);
|
||||||
|
return queryVids.size() == 1
|
||||||
|
? getPosterioriOf (queryVids[0])
|
||||||
|
: getJointDistributionOf (queryVids);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "belief propagation [" ;
|
||||||
|
ss << "schedule=" ;
|
||||||
|
typedef BpOptions::Schedule Sch;
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
|
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
|
case Sch::PARALLEL: ss << "parallel"; break;
|
||||||
|
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
|
}
|
||||||
|
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
|
||||||
|
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
|
||||||
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BpSolver::getPosterioriOf (VarId vid)
|
||||||
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
|
assert (fg.getVarNode (vid));
|
||||||
|
VarNode* var = fg.getVarNode (vid);
|
||||||
|
Params probs;
|
||||||
|
if (var->hasEvidence()) {
|
||||||
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
|
} else {
|
||||||
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
|
const BpLinks& links = ninf(var)->getLinks();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
probs += links[i]->message();
|
||||||
|
}
|
||||||
|
LogAware::normalize (probs);
|
||||||
|
Util::exp (probs);
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
probs *= links[i]->message();
|
||||||
|
}
|
||||||
|
LogAware::normalize (probs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||||
|
{
|
||||||
|
if (runned_ == false) {
|
||||||
|
runSolver();
|
||||||
|
}
|
||||||
|
VarNode* vn = fg.getVarNode (jointVarIds[0]);
|
||||||
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
|
size_t idx = facNodes.size();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
if (facNodes[i]->factor().contains (jointVarIds)) {
|
||||||
|
idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (idx == facNodes.size()) {
|
||||||
|
return getJointByConditioning (jointVarIds);
|
||||||
|
} else {
|
||||||
|
Factor res (facNodes[idx]->factor());
|
||||||
|
const BpLinks& links = ninf(facNodes[idx])->getLinks();
|
||||||
|
for (size_t i = 0; i < links.size(); i++) {
|
||||||
|
Factor msg ({links[i]->varNode()->varId()},
|
||||||
|
{links[i]->varNode()->range()},
|
||||||
|
getVarToFactorMsg (links[i]));
|
||||||
|
res.multiply (msg);
|
||||||
|
}
|
||||||
|
res.sumOutAllExcept (jointVarIds);
|
||||||
|
res.reorderArguments (jointVarIds);
|
||||||
|
res.normalize();
|
||||||
|
Params jointDist = res.params();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::exp (jointDist);
|
||||||
|
}
|
||||||
|
return jointDist;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::runSolver (void)
|
||||||
|
{
|
||||||
|
initializeSolver();
|
||||||
|
nIters_ = 0;
|
||||||
|
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||||
|
nIters_ ++;
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||||
|
}
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
|
std::random_shuffle (links_.begin(), links_.end());
|
||||||
|
// no break
|
||||||
|
case BpOptions::Schedule::SEQ_FIXED:
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
calculateAndUpdateMessage (links_[i]);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::PARALLEL:
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
calculateMessage (links_[i]);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
updateMessage(links_[i]);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||||
|
maxResidualSchedule();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 0) {
|
||||||
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
|
cout << "Sum-Product converged in " ;
|
||||||
|
cout << nIters_ << " iterations" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
runned_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::createLinks (void)
|
||||||
|
{
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||||
|
for (size_t j = 0; j < neighbors.size(); j++) {
|
||||||
|
links_.push_back (new BpLink (facNodes[i], neighbors[j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::maxResidualSchedule (void)
|
||||||
|
{
|
||||||
|
if (nIters_ == 1) {
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
calculateMessage (links_[i]);
|
||||||
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t c = 0; c < links_.size(); c++) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "current residuals:" << endl;
|
||||||
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
it != sortedOrder_.end(); ++it) {
|
||||||
|
cout << " " << setw (30) << left << (*it)->toString();
|
||||||
|
cout << "residual = " << (*it)->residual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
|
BpLink* link = *it;
|
||||||
|
if (link->residual() < BpOptions::accuracy) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
updateMessage (link);
|
||||||
|
link->clearResidual();
|
||||||
|
sortedOrder_.erase (it);
|
||||||
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
|
// update the messages that depend on message source --> destin
|
||||||
|
const FacNodes& factorNeighbors = link->varNode()->neighbors();
|
||||||
|
for (size_t i = 0; i < factorNeighbors.size(); i++) {
|
||||||
|
if (factorNeighbors[i] != link->facNode()) {
|
||||||
|
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
|
for (size_t j = 0; j < links.size(); j++) {
|
||||||
|
if (links[j]->varNode() != link->varNode()) {
|
||||||
|
calculateMessage (links[j]);
|
||||||
|
BpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||||
|
sortedOrder_.erase (iter->second);
|
||||||
|
iter->second = sortedOrder_.insert (links[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
Util::printDashedLine();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::calcFactorToVarMsg (BpLink* link)
|
||||||
|
{
|
||||||
|
FacNode* src = link->facNode();
|
||||||
|
const VarNode* dst = link->varNode();
|
||||||
|
const BpLinks& links = ninf(src)->getLinks();
|
||||||
|
// calculate the product of messages that were sent
|
||||||
|
// to factor `src', except from var `dst'
|
||||||
|
unsigned reps = 1;
|
||||||
|
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
|
||||||
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
|
if (links[i]->varNode() != dst) {
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::plus<double>());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = links.size(); i-- > 0; ) {
|
||||||
|
if (links[i]->varNode() != dst) {
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message from " << links[i]->varNode()->label();
|
||||||
|
cout << ": " ;
|
||||||
|
}
|
||||||
|
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
|
||||||
|
reps, std::multiplies<double>());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reps *= links[i]->varNode()->range();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Factor result (src->factor().arguments(),
|
||||||
|
src->factor().ranges(), msgProduct);
|
||||||
|
result.multiply (src->factor());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " message product: " << msgProduct << endl;
|
||||||
|
cout << " original factor: " << src->factor().params() << endl;
|
||||||
|
cout << " factor product: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
result.sumOutAllExcept (dst->varId());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " marginalized: " << result.params() << endl;
|
||||||
|
}
|
||||||
|
link->nextMessage() = result.params();
|
||||||
|
LogAware::normalize (link->nextMessage());
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " curr msg: " << link->message() << endl;
|
||||||
|
cout << " next msg: " << link->nextMessage() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BpSolver::getVarToFactorMsg (const BpLink* link) const
|
||||||
|
{
|
||||||
|
const VarNode* src = link->varNode();
|
||||||
|
Params msg;
|
||||||
|
if (src->hasEvidence()) {
|
||||||
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
|
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||||
|
} else {
|
||||||
|
msg.resize (src->range(), LogAware::one());
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << msg;
|
||||||
|
}
|
||||||
|
BpLinks::const_iterator it;
|
||||||
|
const BpLinks& links = ninf (src)->getLinks();
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
|
msg += (*it)->message();
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " x " << (*it)->message();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msg -= link->message();
|
||||||
|
} else {
|
||||||
|
for (it = links.begin(); it != links.end(); ++it) {
|
||||||
|
msg *= (*it)->message();
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " x " << (*it)->message();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msg /= link->message();
|
||||||
|
}
|
||||||
|
if (Constants::SHOW_BP_CALCS) {
|
||||||
|
cout << " = " << msg;
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
|
{
|
||||||
|
VarNodes jointVars;
|
||||||
|
for (size_t i = 0; i < jointVarIds.size(); i++) {
|
||||||
|
assert (fg.getVarNode (jointVarIds[i]));
|
||||||
|
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* tempFg = new FactorGraph (fg);
|
||||||
|
BpSolver solver (*tempFg);
|
||||||
|
solver.runSolver();
|
||||||
|
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||||
|
|
||||||
|
VarIds observedVids = {jointVars[0]->varId()};
|
||||||
|
|
||||||
|
for (size_t i = 1; i < jointVarIds.size(); i++) {
|
||||||
|
assert (jointVars[i]->hasEvidence() == false);
|
||||||
|
Params newBeliefs;
|
||||||
|
Vars observedVars;
|
||||||
|
Ranges observedRanges;
|
||||||
|
for (size_t j = 0; j < observedVids.size(); j++) {
|
||||||
|
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
|
||||||
|
observedRanges.push_back (observedVars.back()->range());
|
||||||
|
}
|
||||||
|
Indexer indexer (observedRanges, false);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
for (size_t j = 0; j < observedVars.size(); j++) {
|
||||||
|
observedVars[j]->setEvidence (indexer[j]);
|
||||||
|
}
|
||||||
|
BpSolver solver (*tempFg);
|
||||||
|
solver.runSolver();
|
||||||
|
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||||
|
for (size_t k = 0; k < beliefs.size(); k++) {
|
||||||
|
newBeliefs.push_back (beliefs[k]);
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
|
||||||
|
int count = -1;
|
||||||
|
for (size_t j = 0; j < newBeliefs.size(); j++) {
|
||||||
|
if (j % jointVars[i]->range() == 0) {
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
|
}
|
||||||
|
prevBeliefs = newBeliefs;
|
||||||
|
observedVids.push_back (jointVars[i]->varId());
|
||||||
|
}
|
||||||
|
return prevBeliefs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::initializeSolver (void)
|
||||||
|
{
|
||||||
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
|
varsI_.reserve (varNodes.size());
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
varsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
facsI_.reserve (facNodes.size());
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
facsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
createLinks();
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
FacNode* src = links_[i]->facNode();
|
||||||
|
VarNode* dst = links_[i]->varNode();
|
||||||
|
ninf (dst)->addBpLink (links_[i]);
|
||||||
|
ninf (src)->addBpLink (links_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BpSolver::converged (void)
|
||||||
|
{
|
||||||
|
if (links_.size() == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (nIters_ == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 2) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
if (nIters_ == 1) {
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << "no residuals" << endl << endl;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool converged = true;
|
||||||
|
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||||
|
double maxResidual = (*(sortedOrder_.begin()))->residual();
|
||||||
|
if (maxResidual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
} else {
|
||||||
|
converged = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
double residual = links_[i]->residual();
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
|
}
|
||||||
|
if (residual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
if (Globals::verbosity < 2) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Globals::verbosity > 1) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return converged;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::printLinkInformation (void) const
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < links_.size(); i++) {
|
||||||
|
BpLink* l = links_[i];
|
||||||
|
cout << l->toString() << ":" << endl;
|
||||||
|
cout << " curr msg = " ;
|
||||||
|
cout << l->message() << endl;
|
||||||
|
cout << " next msg = " ;
|
||||||
|
cout << l->nextMessage() << endl;
|
||||||
|
cout << " residual = " << l->residual() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -13,34 +13,31 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
class SpLink
|
class BpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
SpLink (FacNode* fn, VarNode* vn)
|
BpLink (FacNode* fn, VarNode* vn)
|
||||||
{
|
{
|
||||||
fac_ = fn;
|
fac_ = fn;
|
||||||
var_ = vn;
|
var_ = vn;
|
||||||
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||||
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
|
||||||
currMsg_ = &v1_;
|
currMsg_ = &v1_;
|
||||||
nextMsg_ = &v2_;
|
nextMsg_ = &v2_;
|
||||||
msgSended_ = false;
|
|
||||||
residual_ = 0.0;
|
residual_ = 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual ~SpLink (void) { };
|
virtual ~BpLink (void) { };
|
||||||
|
|
||||||
FacNode* getFactor (void) const { return fac_; }
|
FacNode* facNode (void) const { return fac_; }
|
||||||
|
|
||||||
VarNode* getVariable (void) const { return var_; }
|
VarNode* varNode (void) const { return var_; }
|
||||||
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
const Params& message (void) const { return *currMsg_; }
|
||||||
|
|
||||||
Params& getNextMessage (void) { return *nextMsg_; }
|
Params& nextMessage (void) { return *nextMsg_; }
|
||||||
|
|
||||||
bool messageWasSended (void) const { return msgSended_; }
|
double residual (void) const { return residual_; }
|
||||||
|
|
||||||
double getResidual (void) const { return residual_; }
|
|
||||||
|
|
||||||
void clearResidual (void) { residual_ = 0.0; }
|
void clearResidual (void) { residual_ = 0.0; }
|
||||||
|
|
||||||
@ -52,7 +49,6 @@ class SpLink
|
|||||||
virtual void updateMessage (void)
|
virtual void updateMessage (void)
|
||||||
{
|
{
|
||||||
swap (currMsg_, nextMsg_);
|
swap (currMsg_, nextMsg_);
|
||||||
msgSended_ = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
string toString (void) const
|
string toString (void) const
|
||||||
@ -71,20 +67,19 @@ class SpLink
|
|||||||
Params v2_;
|
Params v2_;
|
||||||
Params* currMsg_;
|
Params* currMsg_;
|
||||||
Params* nextMsg_;
|
Params* nextMsg_;
|
||||||
bool msgSended_;
|
|
||||||
double residual_;
|
double residual_;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef vector<SpLink*> SpLinkSet;
|
typedef vector<BpLink*> BpLinks;
|
||||||
|
|
||||||
|
|
||||||
class SPNodeInfo
|
class SPNodeInfo
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
void addBpLink (BpLink* link) { links_.push_back (link); }
|
||||||
const SpLinkSet& getLinks (void) { return links_; }
|
const BpLinks& getLinks (void) { return links_; }
|
||||||
private:
|
private:
|
||||||
SpLinkSet links_;
|
BpLinks links_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -97,6 +92,8 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
Params solveQuery (VarIds);
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
|
virtual void printSolverFlags (void) const;
|
||||||
|
|
||||||
virtual Params getPosterioriOf (VarId);
|
virtual Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
virtual Params getJointDistributionOf (const VarIds&);
|
virtual Params getJointDistributionOf (const VarIds&);
|
||||||
@ -108,9 +105,9 @@ class BpSolver : public Solver
|
|||||||
|
|
||||||
virtual void maxResidualSchedule (void);
|
virtual void maxResidualSchedule (void);
|
||||||
|
|
||||||
virtual void calculateFactor2VariableMsg (SpLink*);
|
virtual void calcFactorToVarMsg (BpLink*);
|
||||||
|
|
||||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
virtual Params getVarToFactorMsg (const BpLink*) const;
|
||||||
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
@ -124,64 +121,63 @@ class BpSolver : public Solver
|
|||||||
return facsI_[fac->getIndex()];
|
return facsI_[fac->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
||||||
{
|
{
|
||||||
if (Constants::DEBUG >= 3) {
|
if (Globals::verbosity > 2) {
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
cout << "calculating & updating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateFactor2VariableMsg (link);
|
calcFactorToVarMsg (link);
|
||||||
if (calcResidual) {
|
if (calcResidual) {
|
||||||
link->updateResidual();
|
link->updateResidual();
|
||||||
}
|
}
|
||||||
link->updateMessage();
|
link->updateMessage();
|
||||||
}
|
}
|
||||||
|
|
||||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
void calculateMessage (BpLink* link, bool calcResidual = true)
|
||||||
{
|
{
|
||||||
if (Constants::DEBUG >= 3) {
|
if (Globals::verbosity > 2) {
|
||||||
cout << "calculating " << link->toString() << endl;
|
cout << "calculating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateFactor2VariableMsg (link);
|
calcFactorToVarMsg (link);
|
||||||
if (calcResidual) {
|
if (calcResidual) {
|
||||||
link->updateResidual();
|
link->updateResidual();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void updateMessage (SpLink* link)
|
void updateMessage (BpLink* link)
|
||||||
{
|
{
|
||||||
link->updateMessage();
|
link->updateMessage();
|
||||||
if (Constants::DEBUG >= 3) {
|
if (Globals::verbosity > 2) {
|
||||||
cout << "updating " << link->toString() << endl;
|
cout << "updating " << link->toString() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CompareResidual
|
struct CompareResidual
|
||||||
{
|
{
|
||||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
inline bool operator() (const BpLink* link1, const BpLink* link2)
|
||||||
{
|
{
|
||||||
return link1->getResidual() > link2->getResidual();
|
return link1->residual() > link2->residual();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
SpLinkSet links_;
|
BpLinks links_;
|
||||||
unsigned nIters_;
|
unsigned nIters_;
|
||||||
vector<SPNodeInfo*> varsI_;
|
vector<SPNodeInfo*> varsI_;
|
||||||
vector<SPNodeInfo*> facsI_;
|
vector<SPNodeInfo*> facsI_;
|
||||||
bool runned_;
|
bool runned_;
|
||||||
const FactorGraph* fg_;
|
|
||||||
|
|
||||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
typedef multiset<BpLink*, CompareResidual> SortedOrder;
|
||||||
SortedOrder sortedOrder_;
|
SortedOrder sortedOrder_;
|
||||||
|
|
||||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||||
SpLinkMap linkMap_;
|
BpLinkMap linkMap_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void initializeSolver (void);
|
void initializeSolver (void);
|
||||||
|
|
||||||
bool converged (void);
|
bool converged (void);
|
||||||
|
|
||||||
void printLinkInformation (void) const;
|
virtual void printLinkInformation (void) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_BPSOLVER_H
|
#endif // HORUS_BPSOLVER_H
|
400
packages/CLPBN/horus/CbpSolver.cpp
Normal file
400
packages/CLPBN/horus/CbpSolver.cpp
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
#include "CbpSolver.h"
|
||||||
|
#include "WeightedBpSolver.h"
|
||||||
|
|
||||||
|
|
||||||
|
bool CbpSolver::checkForIdenticalFactors = true;
|
||||||
|
|
||||||
|
|
||||||
|
CbpSolver::CbpSolver (const FactorGraph& fg)
|
||||||
|
: Solver (fg), freeColor_(0)
|
||||||
|
{
|
||||||
|
findIdenticalFactors();
|
||||||
|
setInitialColors();
|
||||||
|
createGroups();
|
||||||
|
compressedFg_ = getCompressedFactorGraph();
|
||||||
|
solver_ = new WeightedBpSolver (*compressedFg_, getWeights());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
CbpSolver::~CbpSolver (void)
|
||||||
|
{
|
||||||
|
delete solver_;
|
||||||
|
delete compressedFg_;
|
||||||
|
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||||
|
delete varClusters_[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
delete facClusters_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CbpSolver::printSolverFlags (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << "counting bp [" ;
|
||||||
|
ss << "schedule=" ;
|
||||||
|
typedef BpOptions::Schedule Sch;
|
||||||
|
switch (BpOptions::schedule) {
|
||||||
|
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
|
||||||
|
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
|
||||||
|
case Sch::PARALLEL: ss << "parallel"; break;
|
||||||
|
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
|
||||||
|
}
|
||||||
|
ss << ",max_iter=" << BpOptions::maxIter;
|
||||||
|
ss << ",accuracy=" << BpOptions::accuracy;
|
||||||
|
ss << ",log_domain=" << Util::toString (Globals::logDomain);
|
||||||
|
ss << ",chkif=" <<
|
||||||
|
Util::toString (CbpSolver::checkForIdenticalFactors);
|
||||||
|
ss << "]" ;
|
||||||
|
cout << ss.str() << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Params
|
||||||
|
CbpSolver::solveQuery (VarIds queryVids)
|
||||||
|
{
|
||||||
|
assert (queryVids.empty() == false);
|
||||||
|
Params res;
|
||||||
|
if (queryVids.size() == 1) {
|
||||||
|
res = solver_->getPosterioriOf (getRepresentative (queryVids[0]));
|
||||||
|
} else {
|
||||||
|
VarNode* vn = fg.getVarNode (queryVids[0]);
|
||||||
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
|
size_t idx = facNodes.size();
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
if (facNodes[i]->factor().contains (queryVids)) {
|
||||||
|
idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
if (idx == facNodes.size()) {
|
||||||
|
cerr << "error: only joint distributions on variables of some " ;
|
||||||
|
cerr << "clique are supported with the current solver" ;
|
||||||
|
cerr << endl;
|
||||||
|
exit (1);
|
||||||
|
}
|
||||||
|
VarIds representatives;
|
||||||
|
for (size_t i = 0; i < queryVids.size(); i++) {
|
||||||
|
representatives.push_back (getRepresentative (queryVids[i]));
|
||||||
|
}
|
||||||
|
res = solver_->getJointDistributionOf (representatives);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CbpSolver::findIdenticalFactors()
|
||||||
|
{
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
if (checkForIdenticalFactors == false ||
|
||||||
|
facNodes.size() == 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
||||||
|
}
|
||||||
|
unsigned groupCount = 1;
|
||||||
|
for (size_t i = 0; i < facNodes.size() - 1; i++) {
|
||||||
|
Factor& f1 = facNodes[i]->factor();
|
||||||
|
if (f1.distId() != Util::maxUnsigned()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
f1.setDistId (groupCount);
|
||||||
|
for (size_t j = i + 1; j < facNodes.size(); j++) {
|
||||||
|
Factor& f2 = facNodes[j]->factor();
|
||||||
|
if (f2.distId() != Util::maxUnsigned()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (f1.size() == f2.size() &&
|
||||||
|
f1.ranges() == f2.ranges() &&
|
||||||
|
f1.params() == f2.params()) {
|
||||||
|
f2.setDistId (groupCount);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
groupCount ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CbpSolver::setInitialColors (void)
|
||||||
|
{
|
||||||
|
varColors_.resize (fg.nrVarNodes());
|
||||||
|
facColors_.resize (fg.nrFacNodes());
|
||||||
|
// create the initial variable colors
|
||||||
|
VarColorMap colorMap;
|
||||||
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
unsigned range = varNodes[i]->range();
|
||||||
|
VarColorMap::iterator it = colorMap.find (range);
|
||||||
|
if (it == colorMap.end()) {
|
||||||
|
it = colorMap.insert (make_pair (
|
||||||
|
range, Colors (range + 1, -1))).first;
|
||||||
|
}
|
||||||
|
unsigned idx = varNodes[i]->hasEvidence()
|
||||||
|
? varNodes[i]->getEvidence()
|
||||||
|
: range;
|
||||||
|
Colors& stateColors = it->second;
|
||||||
|
if (stateColors[idx] == -1) {
|
||||||
|
stateColors[idx] = getNewColor();
|
||||||
|
}
|
||||||
|
setColor (varNodes[i], stateColors[idx]);
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
// create the initial factor colors
|
||||||
|
DistColorMap distColors;
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
|
DistColorMap::iterator it = distColors.find (distId);
|
||||||
|
if (it == distColors.end()) {
|
||||||
|
it = distColors.insert (make_pair (distId, getNewColor())).first;
|
||||||
|
}
|
||||||
|
setColor (facNodes[i], it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CbpSolver::createGroups (void)
|
||||||
|
{
|
||||||
|
VarSignMap varGroups;
|
||||||
|
FacSignMap facGroups;
|
||||||
|
unsigned nIters = 0;
|
||||||
|
bool groupsHaveChanged = true;
|
||||||
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
|
|
||||||
|
while (groupsHaveChanged || nIters == 1) {
|
||||||
|
nIters ++;
|
||||||
|
|
||||||
|
// set a new color to the variables with the same signature
|
||||||
|
size_t prevVarGroupsSize = varGroups.size();
|
||||||
|
varGroups.clear();
|
||||||
|
for (size_t i = 0; i < varNodes.size(); i++) {
|
||||||
|
const VarSignature& signature = getSignature (varNodes[i]);
|
||||||
|
VarSignMap::iterator it = varGroups.find (signature);
|
||||||
|
if (it == varGroups.end()) {
|
||||||
|
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (varNodes[i]);
|
||||||
|
}
|
||||||
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); ++it) {
|
||||||
|
Color newColor = getNewColor();
|
||||||
|
VarNodes& groupMembers = it->second;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t prevFactorGroupsSize = facGroups.size();
|
||||||
|
facGroups.clear();
|
||||||
|
// set a new color to the factors with the same signature
|
||||||
|
for (size_t i = 0; i < facNodes.size(); i++) {
|
||||||
|
const FacSignature& signature = getSignature (facNodes[i]);
|
||||||
|
FacSignMap::iterator it = facGroups.find (signature);
|
||||||
|
if (it == facGroups.end()) {
|
||||||
|
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||||
|
}
|
||||||
|
it->second.push_back (facNodes[i]);
|
||||||
|
}
|
||||||
|
for (FacSignMap::iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); ++it) {
|
||||||
|
Color newColor = getNewColor();
|
||||||
|
FacNodes& groupMembers = it->second;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
setColor (groupMembers[i], newColor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||||
|
|| prevFactorGroupsSize != facGroups.size();
|
||||||
|
}
|
||||||
|
// printGroups (varGroups, facGroups);
|
||||||
|
createClusters (varGroups, facGroups);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CbpSolver::createClusters (
|
||||||
|
const VarSignMap& varGroups,
|
||||||
|
const FacSignMap& facGroups)
|
||||||
|
{
|
||||||
|
varClusters_.reserve (varGroups.size());
|
||||||
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); ++it) {
|
||||||
|
const VarNodes& groupVars = it->second;
|
||||||
|
VarCluster* vc = new VarCluster (groupVars);
|
||||||
|
for (size_t i = 0; i < groupVars.size(); i++) {
|
||||||
|
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||||
|
}
|
||||||
|
varClusters_.push_back (vc);
|
||||||
|
}
|
||||||
|
|
||||||
|
facClusters_.reserve (facGroups.size());
|
||||||
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); ++it) {
|
||||||
|
FacNode* groupFactor = it->second[0];
|
||||||
|
const VarNodes& neighs = groupFactor->neighbors();
|
||||||
|
VarClusters varClusters;
|
||||||
|
varClusters.reserve (neighs.size());
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
VarId vid = neighs[i]->varId();
|
||||||
|
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
||||||
|
}
|
||||||
|
facClusters_.push_back (new FacCluster (it->second, varClusters));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarSignature
|
||||||
|
CbpSolver::getSignature (const VarNode* varNode)
|
||||||
|
{
|
||||||
|
const FacNodes& neighs = varNode->neighbors();
|
||||||
|
VarSignature sign;
|
||||||
|
sign.reserve (neighs.size() + 1);
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
sign.push_back (make_pair (
|
||||||
|
getColor (neighs[i]),
|
||||||
|
neighs[i]->factor().indexOf (varNode->varId())));
|
||||||
|
}
|
||||||
|
std::sort (sign.begin(), sign.end());
|
||||||
|
sign.push_back (make_pair (getColor (varNode), 0));
|
||||||
|
return sign;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FacSignature
|
||||||
|
CbpSolver::getSignature (const FacNode* facNode)
|
||||||
|
{
|
||||||
|
const VarNodes& neighs = facNode->neighbors();
|
||||||
|
FacSignature sign;
|
||||||
|
sign.reserve (neighs.size() + 1);
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
sign.push_back (getColor (neighs[i]));
|
||||||
|
}
|
||||||
|
sign.push_back (getColor (facNode));
|
||||||
|
return sign;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
CbpSolver::getCompressedFactorGraph (void)
|
||||||
|
{
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
for (size_t i = 0; i < varClusters_.size(); i++) {
|
||||||
|
VarNode* newVar = new VarNode (varClusters_[i]->first());
|
||||||
|
varClusters_[i]->setRepresentative (newVar);
|
||||||
|
fg->addVarNode (newVar);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
Vars vars;
|
||||||
|
const VarClusters& clusters = facClusters_[i]->varClusters();
|
||||||
|
for (size_t j = 0; j < clusters.size(); j++) {
|
||||||
|
vars.push_back (clusters[j]->representative());
|
||||||
|
}
|
||||||
|
const Factor& groundFac = facClusters_[i]->first()->factor();
|
||||||
|
FacNode* fn = new FacNode (Factor (
|
||||||
|
vars, groundFac.params(), groundFac.distId()));
|
||||||
|
facClusters_[i]->setRepresentative (fn);
|
||||||
|
fg->addFacNode (fn);
|
||||||
|
for (size_t j = 0; j < vars.size(); j++) {
|
||||||
|
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<vector<unsigned>>
|
||||||
|
CbpSolver::getWeights (void) const
|
||||||
|
{
|
||||||
|
vector<vector<unsigned>> weights;
|
||||||
|
weights.reserve (facClusters_.size());
|
||||||
|
for (size_t i = 0; i < facClusters_.size(); i++) {
|
||||||
|
const VarClusters& neighs = facClusters_[i]->varClusters();
|
||||||
|
weights.push_back ({ });
|
||||||
|
weights.back().reserve (neighs.size());
|
||||||
|
for (size_t j = 0; j < neighs.size(); j++) {
|
||||||
|
weights.back().push_back (getWeight (
|
||||||
|
facClusters_[i], neighs[j], j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
CbpSolver::getWeight (
|
||||||
|
const FacCluster* fc,
|
||||||
|
const VarCluster* vc,
|
||||||
|
size_t index) const
|
||||||
|
{
|
||||||
|
unsigned weight = 0;
|
||||||
|
VarId reprVid = vc->representative()->varId();
|
||||||
|
VarNode* groundVar = fg.getVarNode (reprVid);
|
||||||
|
const FacNodes& neighs = groundVar->neighbors();
|
||||||
|
for (size_t i = 0; i < neighs.size(); i++) {
|
||||||
|
FacNodes::const_iterator it;
|
||||||
|
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
|
||||||
|
if (it != fc->members().end() &&
|
||||||
|
(*it)->factor().indexOf (reprVid) == index) {
|
||||||
|
weight ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CbpSolver::printGroups (
|
||||||
|
const VarSignMap& varGroups,
|
||||||
|
const FacSignMap& facGroups) const
|
||||||
|
{
|
||||||
|
unsigned count = 1;
|
||||||
|
cout << "variable groups:" << endl;
|
||||||
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
|
it != varGroups.end(); ++it) {
|
||||||
|
const VarNodes& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << count << ": " ;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->label() << " " ;
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
count = 1;
|
||||||
|
cout << endl << "factor groups:" << endl;
|
||||||
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
|
it != facGroups.end(); ++it) {
|
||||||
|
const FacNodes& groupMembers = it->second;
|
||||||
|
if (groupMembers.size() > 0) {
|
||||||
|
cout << ++count << ": " ;
|
||||||
|
for (size_t i = 0; i < groupMembers.size(); i++) {
|
||||||
|
cout << groupMembers[i]->getLabel() << " " ;
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
183
packages/CLPBN/horus/CbpSolver.h
Normal file
183
packages/CLPBN/horus/CbpSolver.h
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
#ifndef HORUS_CBPSOLVER_H
|
||||||
|
#define HORUS_CBPSOLVER_H
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "Solver.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Util.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
class VarCluster;
|
||||||
|
class FacCluster;
|
||||||
|
class VarSignHash;
|
||||||
|
class FacSignHash;
|
||||||
|
class WeightedBpSolver;
|
||||||
|
|
||||||
|
typedef long Color;
|
||||||
|
typedef vector<Color> Colors;
|
||||||
|
typedef vector<std::pair<Color,unsigned>> VarSignature;
|
||||||
|
typedef vector<Color> FacSignature;
|
||||||
|
|
||||||
|
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||||
|
typedef unordered_map<unsigned, Colors> VarColorMap;
|
||||||
|
|
||||||
|
typedef unordered_map<VarSignature, VarNodes, VarSignHash> VarSignMap;
|
||||||
|
typedef unordered_map<FacSignature, FacNodes, FacSignHash> FacSignMap;
|
||||||
|
|
||||||
|
typedef vector<VarCluster*> VarClusters;
|
||||||
|
typedef vector<FacCluster*> FacClusters;
|
||||||
|
|
||||||
|
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||||
|
|
||||||
|
|
||||||
|
struct VarSignHash
|
||||||
|
{
|
||||||
|
size_t operator() (const VarSignature &sig) const
|
||||||
|
{
|
||||||
|
size_t val = hash<size_t>()(sig.size());
|
||||||
|
for (size_t i = 0; i < sig.size(); i++) {
|
||||||
|
val ^= hash<size_t>()(sig[i].first);
|
||||||
|
val ^= hash<size_t>()(sig[i].second);
|
||||||
|
}
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct FacSignHash
|
||||||
|
{
|
||||||
|
size_t operator() (const FacSignature &sig) const
|
||||||
|
{
|
||||||
|
size_t val = hash<size_t>()(sig.size());
|
||||||
|
for (size_t i = 0; i < sig.size(); i++) {
|
||||||
|
val ^= hash<size_t>()(sig[i]);
|
||||||
|
}
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class VarCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
VarCluster (const VarNodes& vs) : members_(vs) { }
|
||||||
|
|
||||||
|
const VarNode* first (void) const { return members_.front(); }
|
||||||
|
|
||||||
|
const VarNodes& members (void) const { return members_; }
|
||||||
|
|
||||||
|
VarNode* representative (void) const { return repr_; }
|
||||||
|
|
||||||
|
void setRepresentative (VarNode* vn) { repr_ = vn; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
VarNodes members_;
|
||||||
|
VarNode* repr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class FacCluster
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
|
||||||
|
: members_(fcs), varClusters_(vcs) { }
|
||||||
|
|
||||||
|
const FacNode* first (void) const { return members_.front(); }
|
||||||
|
|
||||||
|
const FacNodes& members (void) const { return members_; }
|
||||||
|
|
||||||
|
VarClusters& varClusters (void) { return varClusters_; }
|
||||||
|
|
||||||
|
FacNode* representative (void) const { return repr_; }
|
||||||
|
|
||||||
|
void setRepresentative (FacNode* fn) { repr_ = fn; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
FacNodes members_;
|
||||||
|
VarClusters varClusters_;
|
||||||
|
FacNode* repr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class CbpSolver : public Solver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CbpSolver (const FactorGraph& fg);
|
||||||
|
|
||||||
|
~CbpSolver (void);
|
||||||
|
|
||||||
|
void printSolverFlags (void) const;
|
||||||
|
|
||||||
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
|
static bool checkForIdenticalFactors;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Color getNewColor (void)
|
||||||
|
{
|
||||||
|
++ freeColor_;
|
||||||
|
return freeColor_ - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Color getColor (const VarNode* vn) const
|
||||||
|
{
|
||||||
|
return varColors_[vn->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
|
Color getColor (const FacNode* fn) const
|
||||||
|
{
|
||||||
|
return facColors_[fn->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
|
void setColor (const VarNode* vn, Color c)
|
||||||
|
{
|
||||||
|
varColors_[vn->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setColor (const FacNode* fn, Color c)
|
||||||
|
{
|
||||||
|
facColors_[fn->getIndex()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
void findIdenticalFactors (void);
|
||||||
|
|
||||||
|
void setInitialColors (void);
|
||||||
|
|
||||||
|
void createGroups (void);
|
||||||
|
|
||||||
|
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||||
|
|
||||||
|
VarSignature getSignature (const VarNode*);
|
||||||
|
|
||||||
|
FacSignature getSignature (const FacNode*);
|
||||||
|
|
||||||
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
|
VarId getRepresentative (VarId vid)
|
||||||
|
{
|
||||||
|
assert (Util::contains (vid2VarCluster_, vid));
|
||||||
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
|
return vc->representative()->varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* getCompressedFactorGraph (void);
|
||||||
|
|
||||||
|
vector<vector<unsigned>> getWeights (void) const;
|
||||||
|
|
||||||
|
unsigned getWeight (const FacCluster*,
|
||||||
|
const VarCluster*, size_t index) const;
|
||||||
|
|
||||||
|
|
||||||
|
Color freeColor_;
|
||||||
|
Colors varColors_;
|
||||||
|
Colors facColors_;
|
||||||
|
VarClusters varClusters_;
|
||||||
|
FacClusters facClusters_;
|
||||||
|
VarId2VarCluster vid2VarCluster_;
|
||||||
|
const FactorGraph* compressedFg_;
|
||||||
|
WeightedBpSolver* solver_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_CBPSOLVER_H
|
||||||
|
|
1130
packages/CLPBN/horus/ConstraintTree.cpp
Normal file
1130
packages/CLPBN/horus/ConstraintTree.cpp
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user