Conflicts:
	packages/CLPBN/clpbn/horus.yap
This commit is contained in:
Vitor Santos Costa 2012-06-04 16:29:56 +01:00
commit 3669cb894f
139 changed files with 9203 additions and 7699 deletions

View File

@ -3298,6 +3298,28 @@ code_to_indexcl(yamop *ipc, int is_lu)
return ret; return ret;
} }
static void
increase_expand_depth(yamop *ipc, struct intermediates *cint)
{
yamop *ncode;
cint->term_depth++;
if (/* ipc->opc == Yap_opcode(_switch_on_sub_arg_type) && */
(ncode = ipc->u.sllll.l4)->opc == Yap_opcode(_expand_clauses)) {
if (ncode->u.sssllp.s2 != cint->last_depth_size) {
cint->last_index_new_depth = cint->term_depth;
cint->last_depth_size = ncode->u.sssllp.s2;
}
}
}
static void
zero_expand_depth(PredEntry *ap, struct intermediates *cint)
{
cint->term_depth = cint->last_index_new_depth;
cint->last_depth_size = ap->cs.p_code.NOfClauses;
}
static yamop ** static yamop **
expand_index(struct intermediates *cint) { expand_index(struct intermediates *cint) {
CACHE_REGS CACHE_REGS
@ -3499,6 +3521,7 @@ expand_index(struct intermediates *cint) {
break; break;
/* instructions type e */ /* instructions type e */
case _switch_on_type: case _switch_on_type:
zero_expand_depth(ap, cint);
t = Deref(ARG1); t = Deref(ARG1);
argno = 1; argno = 1;
i = 0; i = 0;
@ -3520,6 +3543,7 @@ expand_index(struct intermediates *cint) {
parentcl = index_jmp(parentcl, parentcl, ipc, is_lu, e_code); parentcl = index_jmp(parentcl, parentcl, ipc, is_lu, e_code);
break; break;
case _switch_list_nl: case _switch_list_nl:
zero_expand_depth(ap, cint);
t = Deref(ARG1); t = Deref(ARG1);
argno = 1; argno = 1;
i = 0; i = 0;
@ -3548,6 +3572,7 @@ expand_index(struct intermediates *cint) {
parentcl = index_jmp(parentcl, parentcl, ipc, is_lu, e_code); parentcl = index_jmp(parentcl, parentcl, ipc, is_lu, e_code);
break; break;
case _switch_on_arg_type: case _switch_on_arg_type:
zero_expand_depth(ap, cint);
argno = arg_from_x(ipc->u.xllll.x); argno = arg_from_x(ipc->u.xllll.x);
i = 0; i = 0;
t = Deref(XREGS[argno]); t = Deref(XREGS[argno]);
@ -3578,12 +3603,14 @@ expand_index(struct intermediates *cint) {
ipc = ipc->u.sllll.l4; ipc = ipc->u.sllll.l4;
i++; i++;
} else if (IsPairTerm(t)) { } else if (IsPairTerm(t)) {
increase_expand_depth(ipc, cint);
s_reg = RepPair(t); s_reg = RepPair(t);
sp = push_stack(sp, -i-1, AbsPair(NULL), TermNil, cint); sp = push_stack(sp, -i-1, AbsPair(NULL), TermNil, cint);
labp = &(ipc->u.sllll.l1); labp = &(ipc->u.sllll.l1);
ipc = ipc->u.sllll.l1; ipc = ipc->u.sllll.l1;
i = 0; i = 0;
} else if (IsApplTerm(t)) { } else if (IsApplTerm(t)) {
increase_expand_depth(ipc, cint);
sp = push_stack(sp, -i-1, AbsAppl((CELL *)FunctorOfTerm(t)), TermNil, cint); sp = push_stack(sp, -i-1, AbsAppl((CELL *)FunctorOfTerm(t)), TermNil, cint);
ipc = ipc->u.sllll.l3; ipc = ipc->u.sllll.l3;
i = 0; i = 0;
@ -3591,6 +3618,7 @@ expand_index(struct intermediates *cint) {
/* We don't push stack here, instead we go over to next argument /* We don't push stack here, instead we go over to next argument
sp = push_stack(sp, -i-1, t, cint); sp = push_stack(sp, -i-1, t, cint);
*/ */
increase_expand_depth(ipc, cint);
sp = push_stack(sp, -i-1, t, TermNil, cint); sp = push_stack(sp, -i-1, t, TermNil, cint);
ipc = ipc->u.sllll.l2; ipc = ipc->u.sllll.l2;
i++; i++;

View File

@ -706,7 +706,7 @@ all: startup.yss
@ENABLE_SEMWEB@ @INSTALL_DLLS@ (cd packages/semweb; $(MAKE)) @ENABLE_SEMWEB@ @INSTALL_DLLS@ (cd packages/semweb; $(MAKE))
@ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE)) @ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE))
@ENABLE_REAL@ (cd packages/real; $(MAKE)) @ENABLE_REAL@ (cd packages/real; $(MAKE))
@ENABLE_CLPBN_BP@ (cd packages/CLPBN/clpbn/bp ; $(MAKE)) @ENABLE_CLPBN_BP@ (cd packages/CLPBN/horus; $(MAKE))
@ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE)) @ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE))
@ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE)) @ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE))
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE)) @ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE))
@ -775,7 +775,7 @@ install_unix: startup.yss libYap.a
@ENABLE_SEMWEB@ @INSTALL_DLLS@ (cd packages/semweb; $(MAKE) install) @ENABLE_SEMWEB@ @INSTALL_DLLS@ (cd packages/semweb; $(MAKE) install)
@ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE) install) @ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE) install)
@ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE) @ZLIB_INSTALL@) @ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE) @ZLIB_INSTALL@)
@ENABLE_CLPBN_BP@ @INSTALL_DLLS@ (cd packages/CLPBN/clpbn/bp ; $(MAKE) install) @ENABLE_CLPBN_BP@ @INSTALL_DLLS@ (cd packages/CLPBN/horus; $(MAKE) install)
@ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE) install) @ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE) install)
@INSTALL_MATLAB@ (cd library/matlab; $(MAKE) install) @INSTALL_MATLAB@ (cd library/matlab; $(MAKE) install)
@ENABLE_REAL@ (cd packages/real; $(MAKE) install) @ENABLE_REAL@ (cd packages/real; $(MAKE) install)
@ -839,7 +839,7 @@ install_win32: startup.yss @ENABLE_WINCONSOLE@ pl-yap@EXEC_SUFFIX@
@ENABLE_SGML@ (cd packages/sgml; $(MAKE) install) @ENABLE_SGML@ (cd packages/sgml; $(MAKE) install)
@ENABLE_ZLIB@ (cd packages/zlib; $(MAKE) @ZLIB_INSTALL@) @ENABLE_ZLIB@ (cd packages/zlib; $(MAKE) @ZLIB_INSTALL@)
(cd packages/CLPBN ; $(MAKE) install) (cd packages/CLPBN ; $(MAKE) install)
@ENABLE_CLPBN_BP@ (cd packages/CLPBN/clpbn/bp ; $(MAKE) install) @ENABLE_CLPBN_BP@ (cd packages/CLPBN/horus; $(MAKE) install)
@ENABLE_JPL@ (cd packages/jpl ; $(MAKE) install) @ENABLE_JPL@ (cd packages/jpl ; $(MAKE) install)
@ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE) install) @ENABLE_MINISAT@ (cd packages/swi-minisat2/C; $(MAKE) install)
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install) @ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install)
@ -904,7 +904,7 @@ clean: clean_docs
@ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE) clean) @ENABLE_SGML@ @INSTALL_DLLS@ (cd packages/sgml; $(MAKE) clean)
@ENABLE_REAL@ (cd packages/real; $(MAKE) clean) @ENABLE_REAL@ (cd packages/real; $(MAKE) clean)
@ENABLE_MINISAT@ (cd packages/swi-minisat2; $(MAKE) clean) @ENABLE_MINISAT@ (cd packages/swi-minisat2; $(MAKE) clean)
@ENABLE_CLPBN_BP@ (cd packages/CLPBN/clpbn/bp; $(MAKE) clean) @ENABLE_CLPBN_BP@ (cd packages/CLPBN/horus; $(MAKE) clean)
@ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE) clean) @ENABLE_ZLIB@ @INSTALL_DLLS@ (cd packages/zlib; $(MAKE) clean)
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) clean) @ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) clean)
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean) @ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean)

8
configure vendored
View File

@ -10519,8 +10519,7 @@ mkdir -p packages/clib/maildrop/rfc822
mkdir -p packages/clib/maildrop/rfc2045 mkdir -p packages/clib/maildrop/rfc2045
mkdir -p packages/CLPBN mkdir -p packages/CLPBN
mkdir -p packages/CLPBN/clpbn mkdir -p packages/CLPBN/clpbn
mkdir -p packages/CLPBN/clpbn/bp mkdir -p packages/CLPBN/horus
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
mkdir -p packages/clpqr mkdir -p packages/clpqr
mkdir -p packages/cplint mkdir -p packages/cplint
mkdir -p packages/cplint/approx mkdir -p packages/cplint/approx
@ -10673,6 +10672,7 @@ ac_config_files="$ac_config_files packages/zlib/Makefile"
fi fi
if test "$ENABLE_CUDD" = ""; then if test "$ENABLE_CUDD" = ""; then
ac_config_files="$ac_config_files packages/bdd/Makefile" ac_config_files="$ac_config_files packages/bdd/Makefile"
@ -10695,7 +10695,7 @@ ac_config_files="$ac_config_files packages/real/Makefile"
fi fi
if test "$ENABLE_CLPBN_BP" = ""; then if test "$ENABLE_CLPBN_BP" = ""; then
ac_config_files="$ac_config_files packages/CLPBN/clpbn/bp/Makefile" ac_config_files="$ac_config_files packages/CLPBN/horus/Makefile"
fi fi
@ -11464,7 +11464,7 @@ do
"packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;; "packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;;
"packages/swi-minisat2/C/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/C/Makefile" ;; "packages/swi-minisat2/C/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/C/Makefile" ;;
"packages/real/Makefile") CONFIG_FILES="$CONFIG_FILES packages/real/Makefile" ;; "packages/real/Makefile") CONFIG_FILES="$CONFIG_FILES packages/real/Makefile" ;;
"packages/CLPBN/clpbn/bp/Makefile") CONFIG_FILES="$CONFIG_FILES packages/CLPBN/clpbn/bp/Makefile" ;; "packages/CLPBN/horus/Makefile") CONFIG_FILES="$CONFIG_FILES packages/CLPBN/horus/Makefile" ;;
"library/gecode/Makefile") CONFIG_FILES="$CONFIG_FILES library/gecode/Makefile" ;; "library/gecode/Makefile") CONFIG_FILES="$CONFIG_FILES library/gecode/Makefile" ;;
"packages/prism/src/c/Makefile") CONFIG_FILES="$CONFIG_FILES packages/prism/src/c/Makefile" ;; "packages/prism/src/c/Makefile") CONFIG_FILES="$CONFIG_FILES packages/prism/src/c/Makefile" ;;
"packages/prism/src/prolog/Makefile") CONFIG_FILES="$CONFIG_FILES packages/prism/src/prolog/Makefile" ;; "packages/prism/src/prolog/Makefile") CONFIG_FILES="$CONFIG_FILES packages/prism/src/prolog/Makefile" ;;

View File

@ -2300,8 +2300,7 @@ mkdir -p packages/clib/maildrop/rfc822
mkdir -p packages/clib/maildrop/rfc2045 mkdir -p packages/clib/maildrop/rfc2045
mkdir -p packages/CLPBN mkdir -p packages/CLPBN
mkdir -p packages/CLPBN/clpbn mkdir -p packages/CLPBN/clpbn
mkdir -p packages/CLPBN/clpbn/bp mkdir -p packages/CLPBN/horus
mkdir -p packages/CLPBN/clpbn/bp/xmlParser
mkdir -p packages/clpqr mkdir -p packages/clpqr
mkdir -p packages/cplint mkdir -p packages/cplint
mkdir -p packages/cplint/approx mkdir -p packages/cplint/approx
@ -2415,6 +2414,7 @@ if test "$ENABLE_ZLIB" = ""; then
AC_CONFIG_FILES([packages/zlib/Makefile]) AC_CONFIG_FILES([packages/zlib/Makefile])
fi fi
if test "$ENABLE_CUDD" = ""; then if test "$ENABLE_CUDD" = ""; then
AC_CONFIG_FILES([packages/bdd/Makefile]) AC_CONFIG_FILES([packages/bdd/Makefile])
AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile]) AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile])
@ -2431,7 +2431,7 @@ AC_CONFIG_FILES([packages/real/Makefile])
fi fi
if test "$ENABLE_CLPBN_BP" = ""; then if test "$ENABLE_CLPBN_BP" = ""; then
AC_CONFIG_FILES([packages/CLPBN/clpbn/bp/Makefile]) AC_CONFIG_FILES([packages/CLPBN/horus/Makefile])
fi fi
if test "$ENABLE_GECODE" = ""; then if test "$ENABLE_GECODE" = ""; then

View File

@ -42,19 +42,19 @@ CLPBN_PROGRAMS= \
$(CLPBN_SRCDIR)/aggregates.yap \ $(CLPBN_SRCDIR)/aggregates.yap \
$(CLPBN_SRCDIR)/bdd.yap \ $(CLPBN_SRCDIR)/bdd.yap \
$(CLPBN_SRCDIR)/bnt.yap \ $(CLPBN_SRCDIR)/bnt.yap \
$(CLPBN_SRCDIR)/bp.yap \
$(CLPBN_SRCDIR)/connected.yap \ $(CLPBN_SRCDIR)/connected.yap \
$(CLPBN_SRCDIR)/discrete_utils.yap \ $(CLPBN_SRCDIR)/discrete_utils.yap \
$(CLPBN_SRCDIR)/display.yap \ $(CLPBN_SRCDIR)/display.yap \
$(CLPBN_SRCDIR)/dists.yap \ $(CLPBN_SRCDIR)/dists.yap \
$(CLPBN_SRCDIR)/evidence.yap \ $(CLPBN_SRCDIR)/evidence.yap \
$(CLPBN_SRCDIR)/fove.yap \
$(CLPBN_SRCDIR)/gibbs.yap \ $(CLPBN_SRCDIR)/gibbs.yap \
$(CLPBN_SRCDIR)/graphs.yap \ $(CLPBN_SRCDIR)/graphs.yap \
$(CLPBN_SRCDIR)/graphviz.yap \ $(CLPBN_SRCDIR)/graphviz.yap \
$(CLPBN_SRCDIR)/ground_factors.yap \ $(CLPBN_SRCDIR)/ground_factors.yap \
$(CLPBN_SRCDIR)/hmm.yap \ $(CLPBN_SRCDIR)/hmm.yap \
$(CLPBN_SRCDIR)/horus.yap \ $(CLPBN_SRCDIR)/horus.yap \
$(CLPBN_SRCDIR)/horus_ground.yap \
$(CLPBN_SRCDIR)/horus_lifted.yap \
$(CLPBN_SRCDIR)/jt.yap \ $(CLPBN_SRCDIR)/jt.yap \
$(CLPBN_SRCDIR)/matrix_cpt_utils.yap \ $(CLPBN_SRCDIR)/matrix_cpt_utils.yap \
$(CLPBN_SRCDIR)/pgrammar.yap \ $(CLPBN_SRCDIR)/pgrammar.yap \
@ -94,8 +94,16 @@ CLPBN_HMMER_EXAMPLES= \
$(CLPBN_EXDIR)/HMMer/score.yap $(CLPBN_EXDIR)/HMMer/score.yap
CLPBN_EXAMPLES= \ CLPBN_EXAMPLES= \
$(CLPBN_EXDIR)/burglary-alarm.fg \
$(CLPBN_EXDIR)/burglary-alarm.yap \
$(CLPBN_EXDIR)/burglary-alarm.uai \
$(CLPBN_EXDIR)/cg.yap \ $(CLPBN_EXDIR)/cg.yap \
$(CLPBN_EXDIR)/sprinkler.yap $(CLPBN_EXDIR)/city.yap \
$(CLPBN_EXDIR)/comp_workshops.yap \
$(CLPBN_EXDIR)/social_domain1.yap \
$(CLPBN_EXDIR)/social_domain2.yap \
$(CLPBN_EXDIR)/sprinkler.yap \
$(CLPBN_EXDIR)/workshop_attrs.yap
install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS) install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS)

View File

@ -0,0 +1,78 @@
function prepare_new_run
{
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
cp ~/bin/yap $YAP
}
function run_solver
{
constraint=$1
solver_flag=true
if [ -n "$2" ]; then
if [ $SOLVER = hve ]; then
solver_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
elif [ $SOLVER = bp ]; then
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
elif [ $SOLVER = cbp ]; then
solver_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
else
echo "unknow flag $2"
fi
fi
/usr/bin/time -o $LOG_FILE -a -f "%U\t%S\t%e\t%M" \
$YAP << EOF >> $LOG_FILE &>> ignore.$LOG_FILE
nogc.
[$NETWORK].
[$constraint].
clpbn_horus:set_solver($SOLVER).
clpbn_horus:set_horus_flag(use_logarithms, true).
clpbn_horus:set_horus_flag(verbosity, 1).
$solver_flag.
$QUERY.
open("$LOG_FILE", 'append', S), format(S, '$constraint ~15+ ', []), close(S).
EOF
}
function clear_log_files
{
rm -f *~
rm -f ../*~
rm -f school/*.log school/*~
rm -f ../school/*.log ../school/*~
rm -f city/*.log city/*~
rm -f ../city/*.log ../city/*~
rm -f workshop_attrs/*.log workshop_attrs/*~
rm -f ../workshop_attrs/*.log ../workshop_attrs/*~
echo all done!
}
function write_header
{
echo -n "****************************************" >> $LOG_FILE
echo "****************************************" >> $LOG_FILE
echo "results for solver $1 user(s) sys(s) real(s), mem(kB)" >> $LOG_FILE
echo -n "****************************************" >> $LOG_FILE
echo "****************************************" >> $LOG_FILE
}
if [ $1 ] && [ $1 == "clean" ]; then
clear_log_files
fi

View File

@ -0,0 +1,37 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="bp"
function run_all_graphs
{
write_header $1
run_solver city1000 $2
run_solver city5000 $2
run_solver city10000 $2
run_solver city15000 $2
run_solver city20000 $2
run_solver city25000 $2
run_solver city30000 $2
run_solver city35000 $2
run_solver city40000 $2
run_solver city45000 $2
run_solver city50000 $2
run_solver city55000 $2
run_solver city60000 $2
run_solver city65000 $2
return
run_solver city70000 $2
run_solver city75000 $2
run_solver city80000 $2
run_solver city85000 $2
run_solver city90000 $2
run_solver city95000 $2
run_solver city100000 $2
}
prepare_new_run
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,36 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="cbp"
function run_all_graphs
{
write_header $1
run_solver city1000 $2
run_solver city5000 $2
run_solver city10000 $2
run_solver city15000 $2
run_solver city20000 $2
run_solver city25000 $2
run_solver city30000 $2
run_solver city35000 $2
run_solver city40000 $2
run_solver city45000 $2
run_solver city50000 $2
run_solver city55000 $2
run_solver city60000 $2
run_solver city65000 $2
run_solver city70000 $2
run_solver city75000 $2
run_solver city80000 $2
run_solver city85000 $2
run_solver city90000 $2
run_solver city95000 $2
run_solver city100000 $2
}
prepare_new_run
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,6 @@
#!/bin/bash
NETWORK="'../../examples/city'"
SHORTNAME="city"
QUERY="is_joe_guilty(X)"

View File

@ -0,0 +1,36 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="fove"
function run_all_graphs
{
write_header $1
run_solver city1000 $2
run_solver city5000 $2
run_solver city10000 $2
run_solver city15000 $2
run_solver city20000 $2
run_solver city25000 $2
run_solver city30000 $2
run_solver city35000 $2
run_solver city40000 $2
run_solver city45000 $2
run_solver city50000 $2
run_solver city55000 $2
run_solver city60000 $2
run_solver city65000 $2
run_solver city70000 $2
run_solver city75000 $2
run_solver city80000 $2
run_solver city85000 $2
run_solver city90000 $2
run_solver city95000 $2
run_solver city100000 $2
}
prepare_new_run
run_all_graphs "fove "

View File

@ -1,21 +1,17 @@
#!/home/tiago/bin/yap -L -- #! /home/tgomes/bin/yap -L --
:- initialization(main). :- initialization(main).
main :- main :-
unix(argv([H])), unix(argv([N])),
generate_town(H). atomic_concat(['city', N, '.yap'], FileName),
generate_town(N) :-
atomic_concat(['city_', N, '.yap'], FileName),
open(FileName, 'write', S), open(FileName, 'write', S),
atom_number(N, N2), atom_number(N, N2),
generate_people(S, N2, 4), generate_people(S, N2, 1),
write(S, '\n'), write(S, '\n'),
generate_query(S, N2, 4), generate_evidence(S, N2, 1),
write(S, '\n'), write(S, '\n'),
close(S). close(S).
@ -28,10 +24,10 @@ generate_people(S, N, Counting) :-
generate_people(S, N, Counting1). generate_people(S, N, Counting1).
generate_query(S, N, Counting) :- generate_evidence(S, N, Counting) :-
Counting > N, !. Counting > N, !.
generate_query(S, N, Counting) :- !, generate_evidence(S, N, Counting) :- !,
format(S, 'ev(descn(p~w, t)).~n', [Counting]), format(S, 'ev(descn(p~w, t)).~n', [Counting]),
Counting1 is Counting + 1, Counting1 is Counting + 1,
generate_query(S, N, Counting1). generate_evidence(S, N, Counting1).

View File

@ -0,0 +1,37 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="hve"
function run_all_graphs
{
write_header $1
run_solver city1000 $2
run_solver city5000 $2
run_solver city10000 $2
run_solver city15000 $2
run_solver city20000 $2
run_solver city25000 $2
run_solver city30000 $2
run_solver city35000 $2
run_solver city40000 $2
run_solver city45000 $2
run_solver city50000 $2
run_solver city55000 $2
run_solver city60000 $2
run_solver city65000 $2
run_solver city70000 $2
run_solver city75000 $2
run_solver city80000 $2
run_solver city85000 $2
run_solver city90000 $2
run_solver city95000 $2
run_solver city100000 $2
}
prepare_new_run
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors

View File

@ -0,0 +1,31 @@
#!/bin/bash
source cw.sh
source ../benchs.sh
SOLVER="bp"
function run_all_graphs
{
write_header $1
run_solver p1000w$N_WORKSHOPS $2
run_solver p5000w$N_WORKSHOPS $2
run_solver p10000w$N_WORKSHOPS $2
run_solver p15000w$N_WORKSHOPS $2
run_solver p20000w$N_WORKSHOPS $2
run_solver p25000w$N_WORKSHOPS $2
return
run_solver p30000w$N_WORKSHOPS $2
run_solver p35000w$N_WORKSHOPS $2
run_solver p40000w$N_WORKSHOPS $2
run_solver p45000w$N_WORKSHOPS $2
run_solver p50000w$N_WORKSHOPS $2
run_solver p55000w$N_WORKSHOPS $2
run_solver p60000w$N_WORKSHOPS $2
run_solver p65000w$N_WORKSHOPS $2
run_solver p70000w$N_WORKSHOPS $2
}
prepare_new_run
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,30 @@
#!/bin/bash
source cw.sh
source ../benchs.sh
SOLVER="cbp"
function run_all_graphs
{
write_header $1
run_solver p1000w$N_WORKSHOPS $2
run_solver p5000w$N_WORKSHOPS $2
run_solver p10000w$N_WORKSHOPS $2
run_solver p15000w$N_WORKSHOPS $2
run_solver p20000w$N_WORKSHOPS $2
run_solver p25000w$N_WORKSHOPS $2
run_solver p30000w$N_WORKSHOPS $2
run_solver p35000w$N_WORKSHOPS $2
run_solver p40000w$N_WORKSHOPS $2
run_solver p45000w$N_WORKSHOPS $2
run_solver p50000w$N_WORKSHOPS $2
run_solver p55000w$N_WORKSHOPS $2
run_solver p60000w$N_WORKSHOPS $2
run_solver p65000w$N_WORKSHOPS $2
run_solver p70000w$N_WORKSHOPS $2
}
prepare_new_run
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,8 @@
#!/bin/bash
NETWORK="'../../examples/comp_workshops'"
SHORTNAME="cw"
QUERY="series(X)"
N_WORKSHOPS=10

View File

@ -0,0 +1,31 @@
#!/bin/bash
source cw.sh
source ../benchs.sh
SOLVER="fove"
function run_all_graphs
{
write_header $1
run_solver p1000w$N_WORKSHOPS $2
run_solver p5000w$N_WORKSHOPS $2
run_solver p10000w$N_WORKSHOPS $2
run_solver p15000w$N_WORKSHOPS $2
run_solver p20000w$N_WORKSHOPS $2
run_solver p25000w$N_WORKSHOPS $2
run_solver p30000w$N_WORKSHOPS $2
run_solver p35000w$N_WORKSHOPS $2
run_solver p40000w$N_WORKSHOPS $2
run_solver p45000w$N_WORKSHOPS $2
run_solver p50000w$N_WORKSHOPS $2
run_solver p55000w$N_WORKSHOPS $2
run_solver p60000w$N_WORKSHOPS $2
run_solver p65000w$N_WORKSHOPS $2
run_solver p70000w$N_WORKSHOPS $2
}
prepare_new_run
run_all_graphs "fove "

View File

@ -0,0 +1,35 @@
#!/home/tgomes/bin/yap -L --
:- use_module(library(lists)).
:- initialization(main).
main :-
unix(argv(Args)),
nth(1, Args, NP), % number of invitees
nth(2, Args, NW), % number of workshops
atomic_concat(['p', NP , 'w', NW, '.yap'], FileName),
open(FileName, 'write', S),
atom_number(NP, NP2),
atom_number(NW, NW2),
gen(S, NP2, NW2, 1),
write(S, '\n'),
close(S).
gen(_, NP, _, Count) :-
Count > NP, !.
gen(S, NP, NW, Count) :-
gen_workshops(S, Count, NW, 1),
Count1 is Count + 1,
gen(S, NP, NW, Count1).
gen_workshops(_, _, NW, Count) :-
Count > NW, !.
gen_workshops(S, P, NW, Count) :-
format(S, 'c(p~w,w~w).~n', [P,Count]),
Count1 is Count + 1,
gen_workshops(S, P, NW, Count1).

View File

@ -0,0 +1,30 @@
#!/bin/bash
source cw.sh
source ../benchs.sh
SOLVER="hve"
function run_all_graphs
{
write_header $1
run_solver p1000w$N_WORKSHOPS $2
run_solver p5000w$N_WORKSHOPS $2
run_solver p10000w$N_WORKSHOPS $2
run_solver p15000w$N_WORKSHOPS $2
run_solver p20000w$N_WORKSHOPS $2
run_solver p25000w$N_WORKSHOPS $2
run_solver p30000w$N_WORKSHOPS $2
run_solver p35000w$N_WORKSHOPS $2
run_solver p40000w$N_WORKSHOPS $2
run_solver p45000w$N_WORKSHOPS $2
run_solver p50000w$N_WORKSHOPS $2
run_solver p55000w$N_WORKSHOPS $2
run_solver p60000w$N_WORKSHOPS $2
run_solver p65000w$N_WORKSHOPS $2
run_solver p70000w$N_WORKSHOPS $2
}
prepare_new_run
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors

View File

@ -0,0 +1,30 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="bp"
function run_all_graphs
{
write_header $1
run_solver pop100 $2
run_solver pop200 $2
run_solver pop300 $2
run_solver pop400 $2
run_solver pop500 $2
run_solver pop600 $2
run_solver pop700 $2
run_solver pop800 $2
run_solver pop900 $2
run_solver pop1000 $2
run_solver pop1100 $2
run_solver pop1200 $2
run_solver pop1300 $2
run_solver pop1400 $2
run_solver pop1500 $2
}
prepare_new_run
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,30 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="cbp"
function run_all_graphs
{
write_header $1
run_solver pop100 $2
run_solver pop200 $2
run_solver pop300 $2
run_solver pop400 $2
run_solver pop500 $2
run_solver pop600 $2
run_solver pop700 $2
run_solver pop800 $2
run_solver pop900 $2
run_solver pop1000 $2
run_solver pop1100 $2
run_solver pop1200 $2
run_solver pop1300 $2
run_solver pop1400 $2
run_solver pop1500 $2
}
prepare_new_run
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,31 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="fove"
function run_all_graphs
{
write_header $1
run_solver pop100 $2
run_solver pop200 $2
run_solver pop300 $2
run_solver pop400 $2
run_solver pop500 $2
run_solver pop600 $2
run_solver pop700 $2
run_solver pop800 $2
run_solver pop900 $2
run_solver pop1000 $2
run_solver pop1100 $2
run_solver pop1200 $2
run_solver pop1300 $2
run_solver pop1400 $2
run_solver pop1500 $2
}
prepare_new_run
run_all_graphs "fove "

View File

@ -1,16 +1,12 @@
#!/home/tiago/bin/yap -L -- #!/home/tgomes/bin/yap -L --
:- initialization(main). :- initialization(main).
main :- main :-
unix(argv([H])), unix(argv([N])),
generate_town(H). atomic_concat(['pop', N, '.yap'], FileName),
generate_town(N) :-
atomic_concat(['pop_', N, '.yap'], FileName),
open(FileName, 'write', S), open(FileName, 'write', S),
atom_number(N, N2), atom_number(N, N2),
generate_people(S, N2, 4), generate_people(S, N2, 4),

View File

@ -0,0 +1,33 @@
#!/bin/bash
source sm.sh
source ../benchs.sh
SOLVER="hve"
function run_all_graphs
{
write_header $1
run_solver pop100 $2
#run_solver pop200 $2
#run_solver pop300 $2
#run_solver pop400 $2
#run_solver pop500 $2
#run_solver pop600 $2
#run_solver pop700 $2
#run_solver pop800 $2
#run_solver pop900 $2
#run_solver pop1000 $2
#run_solver pop1100 $2
#run_solver pop1200 $2
#run_solver pop1300 $2
#run_solver pop1400 $2
#run_solver pop1500 $2
}
prepare_new_run
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
#run_all_graphs "hve(elim_heuristic=min_weight) " min_weight
#run_all_graphs "hve(elim_heuristic=min_fill) " min_fill
#run_all_graphs "hve(elim_heuristic=weighted_min_fill) " weighted_min_fill

View File

@ -0,0 +1,6 @@
#!/bin/bash
NETWORK="'../../examples/social_domain2'"
SHORTNAME="sm"
QUERY="smokes(p1,t), smokes(p2,t), friends(p1,p2,X)"

View File

@ -0,0 +1,37 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="bp"
function run_all_graphs
{
write_header $1
run_solver p1000attrs$N_ATTRS $2
run_solver p5000attrs$N_ATTRS $2
run_solver p10000attrs$N_ATTRS $2
run_solver p15000attrs$N_ATTRS $2
run_solver p20000attrs$N_ATTRS $2
run_solver p25000attrs$N_ATTRS $2
run_solver p30000attrs$N_ATTRS $2
run_solver p35000attrs$N_ATTRS $2
return
run_solver p40000attrs$N_ATTRS $2
run_solver p45000attrs$N_ATTRS $2
run_solver p50000attrs$N_ATTRS $2
run_solver p55000attrs$N_ATTRS $2
run_solver p60000attrs$N_ATTRS $2
run_solver p65000attrs$N_ATTRS $2
run_solver p70000attrs$N_ATTRS $2
run_solver p75000attrs$N_ATTRS $2
run_solver p80000attrs$N_ATTRS $2
run_solver p85000attrs$N_ATTRS $2
run_solver p90000attrs$N_ATTRS $2
run_solver p95000attrs$N_ATTRS $2
run_solver p100000attrs$N_ATTRS $2
}
prepare_new_run
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,36 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="cbp"
function run_all_graphs
{
write_header $1
run_solver p1000attrs$N_ATTRS $2
run_solver p5000attrs$N_ATTRS $2
run_solver p10000attrs$N_ATTRS $2
run_solver p15000attrs$N_ATTRS $2
run_solver p20000attrs$N_ATTRS $2
run_solver p25000attrs$N_ATTRS $2
run_solver p30000attrs$N_ATTRS $2
run_solver p35000attrs$N_ATTRS $2
run_solver p40000attrs$N_ATTRS $2
run_solver p45000attrs$N_ATTRS $2
run_solver p50000attrs$N_ATTRS $2
run_solver p55000attrs$N_ATTRS $2
run_solver p60000attrs$N_ATTRS $2
run_solver p65000attrs$N_ATTRS $2
run_solver p70000attrs$N_ATTRS $2
run_solver p75000attrs$N_ATTRS $2
run_solver p80000attrs$N_ATTRS $2
run_solver p85000attrs$N_ATTRS $2
run_solver p90000attrs$N_ATTRS $2
run_solver p95000attrs$N_ATTRS $2
run_solver p100000attrs$N_ATTRS $2
}
prepare_new_run
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,37 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="fove"
function run_all_graphs
{
write_header $1
run_solver p1000attrs$N_ATTRS $2
run_solver p5000attrs$N_ATTRS $2
run_solver p10000attrs$N_ATTRS $2
run_solver p15000attrs$N_ATTRS $2
run_solver p20000attrs$N_ATTRS $2
run_solver p25000attrs$N_ATTRS $2
run_solver p30000attrs$N_ATTRS $2
run_solver p35000attrs$N_ATTRS $2
run_solver p40000attrs$N_ATTRS $2
run_solver p45000attrs$N_ATTRS $2
run_solver p50000attrs$N_ATTRS $2
run_solver p55000attrs$N_ATTRS $2
run_solver p60000attrs$N_ATTRS $2
run_solver p65000attrs$N_ATTRS $2
run_solver p70000attrs$N_ATTRS $2
run_solver p75000attrs$N_ATTRS $2
run_solver p80000attrs$N_ATTRS $2
run_solver p85000attrs$N_ATTRS $2
run_solver p90000attrs$N_ATTRS $2
run_solver p95000attrs$N_ATTRS $2
run_solver p100000attrs$N_ATTRS $2
}
prepare_new_run
run_all_graphs "fove "

View File

@ -0,0 +1,39 @@
#!/home/tgomes/bin/yap -L --
:- use_module(library(lists)).
:- initialization(main).
main :-
unix(argv(Args)),
nth(1, Args, NP), % number of invitees
nth(2, Args, NA), % number of attributes
atomic_concat(['p', NP , 'attrs', NA, '.yap'], FileName),
open(FileName, 'write', S),
atom_number(NP, NP2),
atom_number(NA, NA2),
generate_people(S, NP2, 1),
write(S, '\n'),
generate_attrs(S, NA2, 7),
write(S, '\n'),
close(S).
generate_people(S, N, Counting) :-
Counting > N, !.
generate_people(S, N, Counting) :-
format(S, 'people(p~w).~n', [Counting]),
Counting1 is Counting + 1,
generate_people(S, N, Counting1).
generate_attrs(S, N, Counting) :-
Counting > N, !.
generate_attrs(S, N, Counting) :-
%format(S, 'people(p~w).~n', [Counting]),
format(S, 'markov attends(P)::[t,f], attr~w::[t,f]', [Counting]),
format(S, '; [0.7, 0.3, 0.3, 0.3] ; [people(P)].~n',[]),
Counting1 is Counting + 1,
generate_attrs(S, N, Counting1).

View File

@ -0,0 +1,36 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="hve"
function run_all_graphs
{
write_header $1
run_solver p1000attrs$N_ATTRS $2
run_solver p5000attrs$N_ATTRS $2
run_solver p10000attrs$N_ATTRS $2
run_solver p15000attrs$N_ATTRS $2
run_solver p20000attrs$N_ATTRS $2
run_solver p25000attrs$N_ATTRS $2
run_solver p30000attrs$N_ATTRS $2
run_solver p35000attrs$N_ATTRS $2
run_solver p40000attrs$N_ATTRS $2
run_solver p45000attrs$N_ATTRS $2
run_solver p50000attrs$N_ATTRS $2
run_solver p55000attrs$N_ATTRS $2
run_solver p60000attrs$N_ATTRS $2
run_solver p65000attrs$N_ATTRS $2
run_solver p70000attrs$N_ATTRS $2
run_solver p75000attrs$N_ATTRS $2
run_solver p80000attrs$N_ATTRS $2
run_solver p85000attrs$N_ATTRS $2
run_solver p90000attrs$N_ATTRS $2
run_solver p95000attrs$N_ATTRS $2
run_solver p100000attrs$N_ATTRS $2
}
prepare_new_run
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors

View File

@ -0,0 +1,9 @@
#!/bin/bash
NETWORK="'../../examples/workshop_attrs'"
SHORTNAME="wa"
QUERY="series(X)"
N_ATTRS=6

View File

@ -1,5 +1,4 @@
:- module(clpbn, [{}/1, :- module(clpbn, [{}/1,
clpbn_flag/2, clpbn_flag/2,
set_clpbn_flag/2, set_clpbn_flag/2,
@ -39,24 +38,22 @@
run_ve_solver/3 run_ve_solver/3
]). ]).
:- use_module('clpbn/bp', :- use_module('clpbn/horus_ground',
[bp/3, [call_horus_ground_solver/6,
check_if_bp_done/1, check_if_horus_ground_solver_done/1,
init_bp_solver/4, init_horus_ground_solver/4,
run_bp_solver/3, run_horus_ground_solver/3,
call_bp_ground/6, finalize_horus_ground_solver/1
finalize_bp_solver/1
]). ]).
:- use_module('clpbn/fove', :- use_module('clpbn/horus_lifted',
[fove/3, [call_horus_lifted_solver/3,
check_if_fove_done/1, check_if_horus_lifted_solver_done/1,
init_fove_solver/4, init_horus_lifted_solver/4,
run_fove_solver/3, run_horus_lifted_solver/3,
finalize_fove_solver/1 finalize_horus_lifted_solver/1
]). ]).
:- use_module('clpbn/jt', :- use_module('clpbn/jt',
[jt/3, [jt/3,
init_jt_solver/4, init_jt_solver/4,
@ -306,18 +303,19 @@ write_out(jt, GVars, AVars, DiffVars) :-
jt(GVars, AVars, DiffVars). jt(GVars, AVars, DiffVars).
write_out(bdd, GVars, AVars, DiffVars) :- write_out(bdd, GVars, AVars, DiffVars) :-
bdd(GVars, AVars, DiffVars). bdd(GVars, AVars, DiffVars).
write_out(bp, GVars, AVars, DiffVars) :- write_out(bp, _GVars, _AVars, _DiffVars) :-
bp(GVars, AVars, DiffVars). writeln('interface not supported anymore').
%bp(GVars, AVars, DiffVars).
write_out(gibbs, GVars, AVars, DiffVars) :- write_out(gibbs, GVars, AVars, DiffVars) :-
gibbs(GVars, AVars, DiffVars). gibbs(GVars, AVars, DiffVars).
write_out(bnt, GVars, AVars, DiffVars) :- write_out(bnt, GVars, AVars, DiffVars) :-
do_bnt(GVars, AVars, DiffVars). do_bnt(GVars, AVars, DiffVars).
write_out(fove, GVars, AVars, DiffVars) :- write_out(fove, GVars, AVars, DiffVars) :-
fove(GVars, AVars, DiffVars). call_horus_lifted_solver(GVars, AVars, DiffVars).
% call a solver with keys, not actual variables % call a solver with keys, not actual variables
call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence, Answ) :- call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence, Answ) :-
call_bp_ground(GVars, GoalKeys, Keys, Factors, Evidence, Answ). call_horus_ground_solver(GVars, GoalKeys, Keys, Factors, Evidence, Answ).
get_bnode(Var, Goal) :- get_bnode(Var, Goal) :-
@ -400,7 +398,7 @@ bind_clpbn(_, Var, _, _, _, _, []) :-
check_if_ve_done(Var), !. check_if_ve_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :- bind_clpbn(_, Var, _, _, _, _, []) :-
use(bp), use(bp),
check_if_bp_done(Var), !. check_if_horus_ground_solver_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :- bind_clpbn(_, Var, _, _, _, _, []) :-
use(jt), use(jt),
check_if_ve_done(Var), !. check_if_ve_done(Var), !.
@ -475,7 +473,7 @@ clpbn_init_solver(gibbs, LVs, Vs0, VarsWithUnboundKeys, State) :-
clpbn_init_solver(ve, LVs, Vs0, VarsWithUnboundKeys, State) :- clpbn_init_solver(ve, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_ve_solver(LVs, Vs0, VarsWithUnboundKeys, State). init_ve_solver(LVs, Vs0, VarsWithUnboundKeys, State).
clpbn_init_solver(bp, LVs, Vs0, VarsWithUnboundKeys, State) :- clpbn_init_solver(bp, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_bp_solver(LVs, Vs0, VarsWithUnboundKeys, State). init_horus_ground_solver(LVs, Vs0, VarsWithUnboundKeys, State).
clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :- clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State). init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State).
clpbn_init_solver(bdd, LVs, Vs0, VarsWithUnboundKeys, State) :- clpbn_init_solver(bdd, LVs, Vs0, VarsWithUnboundKeys, State) :-
@ -501,7 +499,7 @@ clpbn_run_solver(ve, LVs, LPs, State) :-
run_ve_solver(LVs, LPs, State). run_ve_solver(LVs, LPs, State).
clpbn_run_solver(bp, LVs, LPs, State) :- clpbn_run_solver(bp, LVs, LPs, State) :-
run_bp_solver(LVs, LPs, State). run_horus_ground_solver(LVs, LPs, State).
clpbn_run_solver(jt, LVs, LPs, State) :- clpbn_run_solver(jt, LVs, LPs, State) :-
run_jt_solver(LVs, LPs, State). run_jt_solver(LVs, LPs, State).
@ -522,7 +520,7 @@ clpbn_finalize_solver(State) :-
solver(bp), !, solver(bp), !,
functor(State, _, Last), functor(State, _, Last),
arg(Last, State, Info), arg(Last, State, Info),
finalize_bp_solver(Info). finalize_horus_ground_solver(Info).
clpbn_finalize_solver(_State). clpbn_finalize_solver(_State).
probability(Goal, Prob) :- probability(Goal, Prob) :-

View File

@ -1,493 +0,0 @@
#include <cassert>
#include <limits>
#include <algorithm>
#include <iostream>
#include "BpSolver.h"
#include "FactorGraph.h"
#include "Factor.h"
#include "Indexer.h"
#include "Horus.h"
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{
fg_ = &fg;
runned_ = false;
}
BpSolver::~BpSolver (void)
{
for (unsigned i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
}
for (unsigned i = 0; i < facsI_.size(); i++) {
delete facsI_[i];
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
}
Params
BpSolver::solveQuery (VarIds queryVids)
{
assert (queryVids.empty() == false);
if (queryVids.size() == 1) {
return getPosterioriOf (queryVids[0]);
} else {
return getJointDistributionOf (queryVids);
}
}
Params
BpSolver::getPosterioriOf (VarId vid)
{
if (runned_ == false) {
runSolver();
}
assert (fg_->getVarNode (vid));
VarNode* var = fg_->getVarNode (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
Util::add (probs, links[i]->getMessage());
}
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
Util::multiply (probs, links[i]->getMessage());
}
LogAware::normalize (probs);
}
}
return probs;
}
Params
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
int idx = -1;
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors();
for (unsigned i = 0; i < facNodes.size(); i++) {
if (facNodes[i]->factor().contains (jointVarIds)) {
idx = i;
break;
}
}
if (idx == -1) {
return getJointByConditioning (jointVarIds);
} else {
Factor res (facNodes[idx]->factor());
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
Factor msg ({links[i]->getVariable()->varId()},
{links[i]->getVariable()->range()},
getVar2FactorMsg (links[i]));
res.multiply (msg);
}
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::fromLog (jointDist);
}
return jointDist;
}
}
void
BpSolver::runSolver (void)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Constants::DEBUG >= 2) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
// cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = fg_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = fg_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
runned_ = true;
}
void
BpSolver::createLinks (void)
{
const FacNodes& facNodes = fg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
const VarNodes& neighbors = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighbors.size(); j++) {
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
}
}
}
void
BpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
}
return;
}
for (unsigned c = 0; c < links_.size(); c++) {
if (Constants::DEBUG >= 2) {
cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->getResidual() << endl;
}
}
SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it;
if (link->getResidual() < BpOptions::accuracy) {
return;
}
updateMessage (link);
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->getFactor()) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) {
calculateMessage (links[j]);
SpLinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]);
}
}
}
}
if (Constants::DEBUG >= 2) {
Util::printDashedLine();
}
}
}
void
BpSolver::calculateFactor2VariableMsg (SpLink* link)
{
FacNode* src = link->getFactor();
const VarNode* dst = link->getVariable();
const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned msgSize = 1;
for (unsigned i = 0; i < links.size(); i++) {
msgSize *= links[i]->getVariable()->range();
}
unsigned repetitions = 1;
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) {
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
} else {
unsigned ds = links[i]->getVariable()->range();
Util::add (msgProduct, Params (ds, 1.0), repetitions);
repetitions *= ds;
}
}
} else {
for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) {
if (Constants::DEBUG >= 5) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " << endl;
}
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->range();
} else {
unsigned ds = links[i]->getVariable()->range();
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
repetitions *= ds;
}
}
}
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
result.multiply (src->factor());
if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExcept (dst->varId());
if (Constants::DEBUG >= 5) {
cout << " marginalized: " ;
cout << result.params() << endl;
}
const Params& resultParams = result.params();
Params& message = link->getNextMessage();
for (unsigned i = 0; i < resultParams.size(); i++) {
message[i] = resultParams[i];
}
LogAware::normalize (message);
if (Constants::DEBUG >= 5) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << message << endl;
}
}
Params
BpSolver::getVar2FactorMsg (const SpLink* link) const
{
const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor();
Params msg;
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
msg[src->getEvidence()] = LogAware::withEvidence();
if (Constants::DEBUG >= 5) {
cout << msg;
}
} else {
msg.resize (src->range(), LogAware::one());
}
if (Constants::DEBUG >= 5) {
cout << msg;
}
const SpLinkSet& links = ninf (src)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
Util::add (msg, links[i]->getMessage());
}
}
} else {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
Util::multiply (msg, links[i]->getMessage());
if (Constants::DEBUG >= 5) {
cout << " x " << links[i]->getMessage();
}
}
}
}
if (Constants::DEBUG >= 5) {
cout << " = " << msg;
}
return msg;
}
Params
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
assert (fg_->getVarNode (jointVarIds[i]));
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
}
FactorGraph* fg = new FactorGraph (*fg_);
BpSolver solver (*fg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
VarIds observedVids = {jointVars[0]->varId()};
for (unsigned i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
for (unsigned j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg->getVarNode (observedVids[j]));
}
StatesIndexer idx (observedVars, false);
while (idx.valid()) {
for (unsigned j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (idx[j]);
}
++ idx;
BpSolver solver (*fg);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (unsigned k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
}
int count = -1;
for (unsigned j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
return prevBeliefs;
}
void
BpSolver::initializeSolver (void)
{
const VarNodes& varNodes = fg_->varNodes();
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FacNodes& facNodes = fg_->facNodes();
facsI_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->getFactor();
VarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
}
}
bool
BpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ <= 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Constants::DEBUG == 0) break;
}
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
return converged;
}
void
BpSolver::printLinkInformation (void) const
{
for (unsigned i = 0; i < links_.size(); i++) {
SpLink* l = links_[i];
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << l->getMessage() << endl;
cout << " next msg = " ;
cout << l->getNextMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -1,339 +0,0 @@
#include "CFactorGraph.h"
#include "Factor.h"
bool CFactorGraph::checkForIdenticalFactors = true;
CFactorGraph::CFactorGraph (const FactorGraph& fg)
{
groundFg_ = &fg;
freeColor_ = 0;
const VarNodes& varNodes = fg.varNodes();
varSignatures_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
varSignatures_.push_back (Signature (c));
}
const FacNodes& facNodes = fg.facNodes();
facSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1;
facSignatures_.push_back (Signature (c));
}
varColors_.resize (varNodes.size());
facColors_.resize (facNodes.size());
setInitialColors();
createGroups();
}
CFactorGraph::~CFactorGraph (void)
{
for (unsigned i = 0; i < varClusters_.size(); i++) {
delete varClusters_[i];
}
for (unsigned i = 0; i < facClusters_.size(); i++) {
delete facClusters_[i];
}
}
void
CFactorGraph::setInitialColors (void)
{
// create the initial variable colors
VarColorMap colorMap;
const VarNodes& varNodes = groundFg_->varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned dsize = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (dsize);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
dsize, vector<Color> (dsize+1,-1))).first;
}
unsigned idx;
if (varNodes[i]->hasEvidence()) {
idx = varNodes[i]->getEvidence();
} else {
idx = dsize;
}
vector<Color>& stateColors = it->second;
if (stateColors[idx] == -1) {
stateColors[idx] = getFreeColor();
}
setColor (varNodes[i], stateColors[idx]);
}
const FacNodes& facNodes = groundFg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
// FIXME FIXME FIXME : pfl should give correct dist ids.
if (checkForIdenticalFactors || true) {
unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
}
}
// create the initial factor colors
DistColorMap distColors;
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getFreeColor())).first;
}
setColor (facNodes[i], it->second);
}
}
void
CFactorGraph::createGroups (void)
{
VarSignMap varGroups;
FacSignMap facGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const VarNodes& varNodes = groundFg_->varNodes();
const FacNodes& facNodes = groundFg_->facNodes();
while (groupsHaveChanged || nIters == 1) {
nIters ++;
unsigned prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) {
const Signature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); it++) {
Color newColor = getFreeColor();
FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
// set a new color to the variables with the same signature
unsigned prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (unsigned i = 0; i < varNodes.size(); i++) {
const Signature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); it++) {
Color newColor = getFreeColor();
VarNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != facGroups.size();
}
printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
}
void
CFactorGraph::createClusters (
const VarSignMap& varGroups,
const FacSignMap& facGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (unsigned i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
}
varClusters_.push_back (vc);
}
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (unsigned i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
varClusters.push_back (vid2VarCluster_.find (vid)->second);
}
facClusters_.push_back (new FacCluster (it->second, varClusters));
}
}
const Signature&
CFactorGraph::getSignature (const VarNode* varNode)
{
Signature& sign = varSignatures_[varNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin();
const FacNodes& neighs = varNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]);
it ++;
*it = neighs[i]->factor().indexOf (varNode->varId());
it ++;
}
*it = getColor (varNode);
return sign;
}
const Signature&
CFactorGraph::getSignature (const FacNode* facNode)
{
Signature& sign = facSignatures_[facNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin();
const VarNodes& neighs = facNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]);
it ++;
}
*it = getColor (facNode);
return sign;
}
FactorGraph*
CFactorGraph::getGroundFactorGraph (void) const
{
FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) {
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
VarNode* newVar = new VarNode (var);
varClusters_[i]->setRepresentativeVariable (newVar);
fg->addVarNode (newVar);
}
for (unsigned i = 0; i < facClusters_.size(); i++) {
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
Vars myGroundVars;
myGroundVars.reserve (myVarClusters.size());
for (unsigned j = 0; j < myVarClusters.size(); j++) {
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
myGroundVars.push_back (v);
}
FacNode* fn = new FacNode (Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->factor().params()));
facClusters_[i]->setRepresentativeFactor (fn);
fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
}
}
return fg;
}
unsigned
CFactorGraph::getEdgeCount (
const FacCluster* fc,
const VarCluster* vc) const
{
unsigned count = 0;
VarId vid = vc->getGroundVarNodes().front()->varId();
const FacNodes& clusterGroundFactors = fc->getGroundFactors();
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
if (clusterGroundFactors[i]->factor().contains (vid)) {
count ++;
}
}
// CVarNodes vars = vc->getGroundVarNodes();
// for (unsigned i = 1; i < vars.size(); i++) {
// VarNode* var = vc->getGroundVarNodes()[i];
// unsigned count2 = 0;
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
// count2 ++;
// }
// }
// if (count != count2) { cout << "oops!" << endl; abort(); }
// }
return count;
}
void
CFactorGraph::printGroups (
const VarSignMap& varGroups,
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->label() << " " ;
}
count ++;
cout << endl;
}
}
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
}
count ++;
cout << endl;
}
}
}

View File

@ -1,247 +0,0 @@
#ifndef HORUS_CFACTORGRAPH_H
#define HORUS_CFACTORGRAPH_H
#include <unordered_map>
#include "FactorGraph.h"
#include "Factor.h"
#include "Horus.h"
class VarCluster;
class FacCluster;
class Distribution;
class Signature;
class SignatureHash;
typedef long Color;
typedef unordered_map<unsigned, vector<Color>> VarColorMap;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
struct Signature
{
Signature (unsigned size) : colors(size) { }
bool operator< (const Signature& sig) const
{
if (colors.size() < sig.colors.size()) {
return true;
} else if (colors.size() > sig.colors.size()) {
return false;
} else {
for (unsigned i = 0; i < colors.size(); i++) {
if (colors[i] < sig.colors[i]) {
return true;
} else if (colors[i] > sig.colors[i]) {
return false;
}
}
}
return false;
}
bool operator== (const Signature& sig) const
{
if (colors.size() != sig.colors.size()) {
return false;
}
for (unsigned i = 0; i < colors.size(); i++) {
if (colors[i] != sig.colors[i]) {
return false;
}
}
return true;
}
vector<Color> colors;
};
struct SignatureHash
{
size_t operator() (const Signature &sig) const
{
size_t val = hash<size_t>()(sig.colors.size());
for (unsigned i = 0; i < sig.colors.size(); i++) {
val ^= hash<size_t>()(sig.colors[i]);
}
return val;
}
};
class VarCluster
{
public:
VarCluster (const VarNodes& vs)
{
for (unsigned i = 0; i < vs.size(); i++) {
groundVars_.push_back (vs[i]);
}
}
void addFacCluster (FacCluster* fc)
{
facClusters_.push_back (fc);
}
const FacClusters& getFacClusters (void) const
{
return facClusters_;
}
VarNode* getRepresentativeVariable (void) const { return representVar_; }
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
private:
VarNodes groundVars_;
FacClusters facClusters_;
VarNode* representVar_;
};
class FacCluster
{
public:
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
{
groundFactors_ = groundFactors;
varClusters_ = vcs;
for (unsigned i = 0; i < varClusters_.size(); i++) {
varClusters_[i]->addFacCluster (this);
}
}
const VarClusters& getVarClusters (void) const
{
return varClusters_;
}
bool containsGround (const FacNode* fn)
{
for (unsigned i = 0; i < groundFactors_.size(); i++) {
if (groundFactors_[i] == fn) {
return true;
}
}
return false;
}
FacNode* getRepresentativeFactor (void) const
{
return representFactor_;
}
void setRepresentativeFactor (FacNode* fn)
{
representFactor_ = fn;
}
const FacNodes& getGroundFactors (void) const
{
return groundFactors_;
}
private:
FacNodes groundFactors_;
VarClusters varClusters_;
FacNode* representFactor_;
};
class CFactorGraph
{
public:
CFactorGraph (const FactorGraph&);
~CFactorGraph (void);
const VarClusters& getVarClusters (void) { return varClusters_; }
const FacClusters& getFacClusters (void) { return facClusters_; }
VarNode* getEquivalentVariable (VarId vid)
{
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->getRepresentativeVariable();
}
FactorGraph* getGroundFactorGraph (void) const;
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
static bool checkForIdenticalFactors;
private:
Color getFreeColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
Color getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
Color getColor (const FacNode* fn) const {
return facColors_[fn->getIndex()];
}
void setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
void setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
VarCluster* getVariableCluster (VarId vid) const
{
return vid2VarCluster_.find (vid)->second;
}
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const VarNode*);
const Signature& getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_;
vector<Color> varColors_;
vector<Color> facColors_;
vector<Signature> varSignatures_;
vector<Signature> facSignatures_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
};
#endif // HORUS_CFACTORGRAPH_H

View File

@ -1,244 +0,0 @@
#include "CbpSolver.h"
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) {
nGroundVars = fg_->varNodes().size();
nGroundFacs = fg_->facNodes().size();
const VarNodes& vars = fg_->varNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FacNodes& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
cfg_ = new CFactorGraph (fg);
fg_ = cfg_->getGroundFactorGraph();
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = fg_->varNodes().size();
unsigned nClusterFacs = fg_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars,
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
}
Util::printHeader ("Uncompressed Factor Graph");
fg.print();
Util::printHeader ("Compressed Factor Graph");
fg_->print();
}
CbpSolver::~CbpSolver (void)
{
delete cfg_;
delete fg_;
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
links_.clear();
}
Params
CbpSolver::getPosterioriOf (VarId vid)
{
if (runned_ == false) {
runSolver();
}
assert (cfg_->getEquivalentVariable (vid));
VarNode* var = cfg_->getEquivalentVariable (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->poweredMessage());
}
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (probs, l->poweredMessage());
}
LogAware::normalize (probs);
}
}
return probs;
}
Params
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{
VarIds eqVarIds;
for (unsigned i = 0; i < jointVids.size(); i++) {
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
eqVarIds.push_back (vn->varId());
}
return BpSolver::getJointDistributionOf (eqVarIds);
}
void
CbpSolver::createLinks (void)
{
const FacClusters& fcs = cfg_->getFacClusters();
for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusters& vcs = fcs[i]->getVarClusters();
for (unsigned j = 0; j < vcs.size(); j++) {
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
links_.push_back (new CbpSolverLink (
fcs[i]->getRepresentativeFactor(),
vcs[j]->getRepresentativeVariable(), c));
}
}
}
void
CbpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << "calculating " << links_[i]->toString() << endl;
}
}
return;
}
for (unsigned c = 0; c < links_.size(); c++) {
if (Constants::DEBUG >= 2) {
cout << endl << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->getResidual() << endl;
}
}
SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it;
if (Constants::DEBUG >= 2) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->getResidual() < BpOptions::accuracy) {
return;
}
link->updateMessage();
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << " calculating " << links[j]->toString() << endl;
}
calculateMessage (links[j]);
SpLinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]);
}
}
}
// in counting bp, the message that a variable X sends to
// to a factor F depends on the message that F sent to the X
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getVariable() != link->getVariable()) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << " calculating " << links[i]->toString() << endl;
}
calculateMessage (links[i]);
SpLinkMap::iterator iter = linkMap_.find (links[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[i]);
}
}
}
}
Params
CbpSolver::getVar2FactorMsg (const SpLink* link) const
{
Params msg;
const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor();
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()];
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
} else {
msg = link->getMessage();
LogAware::pow (msg, l->nrEdges() - 1);
}
if (Constants::DEBUG >= 5) {
cout << " " << "init: " << msg << endl;
}
const SpLinkSet& links = ninf(src)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, l->poweredMessage());
}
}
} else {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (msg, l->poweredMessage());
if (Constants::DEBUG >= 5) {
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
cout << l->poweredMessage() << endl;
}
}
}
}
if (Constants::DEBUG >= 5) {
cout << " result = " << msg << endl;
}
return msg;
}
void
CbpSolver::printLinkInformation (void) const
{
for (unsigned i = 0; i < links_.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
cout << l->toString() << ":" << endl;
cout << " curr msg = " << l->getMessage() << endl;
cout << " next msg = " << l->getNextMessage() << endl;
cout << " powered = " << l->poweredMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -1,60 +0,0 @@
#ifndef HORUS_CBP_H
#define HORUS_CBP_H
#include "BpSolver.h"
#include "CFactorGraph.h"
class Factor;
class CbpSolverLink : public SpLink
{
public:
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
: SpLink (fn, vn), nrEdges_(c),
pwdMsg_(vn->range(), LogAware::one()) { }
unsigned nrEdges (void) const { return nrEdges_; }
const Params& poweredMessage (void) const { return pwdMsg_; }
void updateMessage (void)
{
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
msgSended_ = true;
LogAware::pow (pwdMsg_, nrEdges_);
}
private:
unsigned nrEdges_;
Params pwdMsg_;
};
class CbpSolver : public BpSolver
{
public:
CbpSolver (const FactorGraph& fg);
~CbpSolver (void);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private:
void createLinks (void);
void maxResidualSchedule (void);
Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;
CFactorGraph* cfg_;
};
#endif // HORUS_CBP_H

File diff suppressed because it is too large Load Diff

View File

@ -1,303 +0,0 @@
#include <limits>
#include <fstream>
#include "ElimGraph.h"
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
ElimGraph::ElimGraph (const vector<Factor*>& factors)
{
for (unsigned i = 0; i < factors.size(); i++) {
const VarIds& vids = factors[i]->arguments();
for (unsigned j = 0; j < vids.size() - 1; j++) {
EgNode* n1 = getEgNode (vids[j]);
if (n1 == 0) {
n1 = new EgNode (vids[j], factors[i]->range (j));
addNode (n1);
}
for (unsigned k = j + 1; k < vids.size(); k++) {
EgNode* n2 = getEgNode (vids[k]);
if (n2 == 0) {
n2 = new EgNode (vids[k], factors[i]->range (k));
addNode (n2);
}
if (neighbors (n1, n2) == false) {
addEdge (n1, n2);
}
}
}
if (vids.size() == 1) {
if (getEgNode (vids[0]) == 0) {
addNode (new EgNode (vids[0], factors[i]->range (0)));
}
}
}
}
ElimGraph::~ElimGraph (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
delete nodes_[i];
}
}
VarIds
ElimGraph::getEliminatingOrder (const VarIds& exclude)
{
VarIds elimOrder;
marked_.resize (nodes_.size(), false);
for (unsigned i = 0; i < exclude.size(); i++) {
assert (getEgNode (exclude[i]));
EgNode* node = getEgNode (exclude[i]);
marked_[*node] = true;
}
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
for (unsigned i = 0; i < nVarsToEliminate; i++) {
EgNode* node = getLowestCostNode();
marked_[*node] = true;
elimOrder.push_back (node->varId());
connectAllNeighbors (node);
}
return elimOrder;
}
void
ElimGraph::print (void) const
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << "node " << nodes_[i]->label() << " neighs:" ;
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
cout << " " << neighs[j]->label();
}
cout << endl;
}
}
void
ElimGraph::exportToGraphViz (
const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Markov::exportToDotFile()" << endl;
abort();
}
out << "strict graph {" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
out << '"' << nodes_[i]->label() << '"' << endl;
}
}
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
EgNode* node =getEgNode (highlightVarIds[i]);
if (node) {
out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
VarIds
ElimGraph::getEliminationOrder (
const vector<Factor*> factors,
VarIds excludedVids)
{
ElimGraph graph (factors);
// graph.print();
// graph.exportToGraphViz ("_egg.dot");
return graph.getEliminatingOrder (excludedVids);
}
void
ElimGraph::addNode (EgNode* n)
{
nodes_.push_back (n);
n->setIndex (nodes_.size() - 1);
varMap_.insert (make_pair (n->varId(), n));
}
EgNode*
ElimGraph::getEgNode (VarId vid) const
{
unordered_map<VarId, EgNode*>::const_iterator it;
it = varMap_.find (vid);
return (it != varMap_.end()) ? it->second : 0;
}
EgNode*
ElimGraph::getLowestCostNode (void) const
{
EgNode* bestNode = 0;
unsigned minCost = std::numeric_limits<unsigned>::max();
for (unsigned i = 0; i < nodes_.size(); i++) {
if (marked_[i]) continue;
unsigned cost = 0;
switch (elimHeuristic_) {
case MIN_NEIGHBORS:
cost = getNeighborsCost (nodes_[i]);
break;
case MIN_WEIGHT:
cost = getWeightCost (nodes_[i]);
break;
case MIN_FILL:
cost = getFillCost (nodes_[i]);
break;
case WEIGHTED_MIN_FILL:
cost = getWeightedFillCost (nodes_[i]);
break;
default:
assert (false);
}
if (cost < minCost) {
bestNode = nodes_[i];
minCost = cost;
}
}
assert (bestNode);
return bestNode;
}
unsigned
ElimGraph::getNeighborsCost (const EgNode* n) const
{
unsigned cost = 0;
const vector<EgNode*>& neighs = n->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
if (marked_[*neighs[i]] == false) {
cost ++;
}
}
return cost;
}
unsigned
ElimGraph::getWeightCost (const EgNode* n) const
{
unsigned cost = 1;
const vector<EgNode*>& neighs = n->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
if (marked_[*neighs[i]] == false) {
cost *= neighs[i]->range();
}
}
return cost;
}
unsigned
ElimGraph::getFillCost (const EgNode* n) const
{
unsigned cost = 0;
const vector<EgNode*>& neighs = n->neighbors();
if (neighs.size() > 0) {
for (unsigned i = 0; i < neighs.size() - 1; i++) {
if (marked_[*neighs[i]] == true) continue;
for (unsigned j = i+1; j < neighs.size(); j++) {
if (marked_[*neighs[j]] == true) continue;
if (!neighbors (neighs[i], neighs[j])) {
cost ++;
}
}
}
}
return cost;
}
unsigned
ElimGraph::getWeightedFillCost (const EgNode* n) const
{
unsigned cost = 0;
const vector<EgNode*>& neighs = n->neighbors();
if (neighs.size() > 0) {
for (unsigned i = 0; i < neighs.size() - 1; i++) {
if (marked_[*neighs[i]] == true) continue;
for (unsigned j = i+1; j < neighs.size(); j++) {
if (marked_[*neighs[j]] == true) continue;
if (!neighbors (neighs[i], neighs[j])) {
cost += neighs[i]->range() * neighs[j]->range();
}
}
}
}
return cost;
}
void
ElimGraph::connectAllNeighbors (const EgNode* n)
{
const vector<EgNode*>& neighs = n->neighbors();
if (neighs.size() > 0) {
for (unsigned i = 0; i < neighs.size() - 1; i++) {
if (marked_[*neighs[i]] == true) continue;
for (unsigned j = i+1; j < neighs.size(); j++) {
if (marked_[*neighs[j]] == true) continue;
if (!neighbors (neighs[i], neighs[j])) {
addEdge (neighs[i], neighs[j]);
}
}
}
}
}
bool
ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
{
const vector<EgNode*>& neighs = n1->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
if (neighs[i] == n2) {
return true;
}
}
return false;
}

View File

@ -1,88 +0,0 @@
#ifndef HORUS_ELIMGRAPH_H
#define HORUS_ELIMGRAPH_H
#include "unordered_map"
#include "FactorGraph.h"
#include "Horus.h"
using namespace std;
enum ElimHeuristic
{
MIN_NEIGHBORS,
MIN_WEIGHT,
MIN_FILL,
WEIGHTED_MIN_FILL
};
class EgNode : public Var
{
public:
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
const vector<EgNode*>& neighbors (void) const { return neighs_; }
private:
vector<EgNode*> neighs_;
};
class ElimGraph
{
public:
ElimGraph (const vector<Factor*>&); // TODO
~ElimGraph (void);
VarIds getEliminatingOrder (const VarIds&);
void print (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
static void setEliminationHeuristic (ElimHeuristic h)
{
elimHeuristic_ = h;
}
private:
void addEdge (EgNode* n1, EgNode* n2)
{
assert (n1 != n2);
n1->addNeighbor (n2);
n2->addNeighbor (n1);
}
void addNode (EgNode*);
EgNode* getEgNode (VarId) const;
EgNode* getLowestCostNode (void) const;
unsigned getNeighborsCost (const EgNode*) const;
unsigned getWeightCost (const EgNode*) const;
unsigned getFillCost (const EgNode*) const;
unsigned getWeightedFillCost (const EgNode*) const;
void connectAllNeighbors (const EgNode*);
bool neighbors (const EgNode*, const EgNode*) const;
vector<EgNode*> nodes_;
vector<bool> marked_;
unordered_map<VarId, EgNode*> varMap_;
static ElimHeuristic elimHeuristic_;
};
#endif // HORUS_ELIMGRAPH_H

View File

@ -1,265 +0,0 @@
#include <cstdlib>
#include <cassert>
#include <algorithm>
#include <iostream>
#include <sstream>
#include "Factor.h"
#include "Indexer.h"
Factor::Factor (const Factor& g)
{
copyFromFactor (g);
}
Factor::Factor (
const VarIds& vids,
const Ranges& ranges,
const Params& params,
unsigned distId)
{
args_ = vids;
ranges_ = ranges;
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::Factor (
const Vars& vars,
const Params& params,
unsigned distId)
{
for (unsigned i = 0; i < vars.size(); i++) {
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->range());
}
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
}
void
Factor::sumOutAllExcept (VarId vid)
{
assert (indexOf (vid) != -1);
while (args_.back() != vid) {
sumOutLastVariable();
}
while (args_.front() != vid) {
sumOutFirstVariable();
}
}
void
Factor::sumOutAllExcept (const VarIds& vids)
{
for (int i = 0; i < (int)args_.size(); i++) {
if (Util::contains (vids, args_[i]) == false) {
sumOut (args_[i]);
i --;
}
}
}
void
Factor::sumOut (VarId vid)
{
int idx = indexOf (vid);
assert (idx != -1);
if (vid == args_.back()) {
sumOutLastVariable(); // optimization
return;
}
if (vid == args_.front()) {
sumOutFirstVariable(); // optimization
return;
}
// number of parameters separating a different state of `var',
// with the states of the remaining variables fixed
unsigned varOffset = 1;
// number of parameters separating a different state of the variable
// on the left of `var', with the states of the remaining vars fixed
unsigned leftVarOffset = 1;
for (int i = args_.size() - 1; i > idx; i--) {
varOffset *= ranges_[i];
leftVarOffset *= ranges_[i];
}
leftVarOffset *= ranges_[idx];
unsigned offset = 0;
unsigned count1 = 0;
unsigned count2 = 0;
unsigned newpsSize = params_.size() / ranges_[idx];
Params newps;
newps.reserve (newpsSize);
while (newps.size() < newpsSize) {
double sum = LogAware::addIdenty();
for (unsigned i = 0; i < ranges_[idx]; i++) {
if (Globals::logDomain) {
sum = Util::logSum (sum, params_[offset]);
} else {
sum += params_[offset];
}
offset += varOffset;
}
newps.push_back (sum);
count1 ++;
if (idx == (int)args_.size() - 1) {
offset = count1 * ranges_[idx];
} else {
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
count1 = 0;
count2 ++;
}
offset = (leftVarOffset * count2) + count1;
}
}
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
params_ = newps;
}
void
Factor::sumOutFirstVariable (void)
{
unsigned range = ranges_.front();
unsigned sep = params_.size() / range;
if (Globals::logDomain) {
for (unsigned i = sep; i < params_.size(); i++) {
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
}
} else {
for (unsigned i = sep; i < params_.size(); i++) {
params_[i % sep] += params_[i];
}
}
params_.resize (sep);
args_.erase (args_.begin());
ranges_.erase (ranges_.begin());
}
void
Factor::sumOutLastVariable (void)
{
unsigned range = ranges_.back();
unsigned idx1 = 0;
unsigned idx2 = 0;
if (Globals::logDomain) {
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
idx1 ++;
for (unsigned j = 1; j < range; j++) {
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
idx1 ++;
}
idx2 ++;
}
} else {
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
idx1 ++;
for (unsigned j = 1; j < range; j++) {
params_[idx2] += params_[idx1];
idx1 ++;
}
idx2 ++;
}
}
params_.resize (idx2);
args_.pop_back();
ranges_.pop_back();
}
void
Factor::multiply (Factor& g)
{
if (args_.size() == 0) {
copyFromFactor (g);
return;
}
TFactor<VarId>::multiply (g);
}
void
Factor::reorderAccordingVarIds (void)
{
VarIds sortedVarIds = args_;
sort (sortedVarIds.begin(), sortedVarIds.end());
reorderArguments (sortedVarIds);
}
string
Factor::getLabel (void) const
{
stringstream ss;
ss << "f(" ;
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
ss << Var (args_[i], ranges_[i]).label();
}
ss << ")" ;
return ss.str();
}
void
Factor::print (void) const
{
Vars vars;
for (unsigned i = 0; i < args_.size(); i++) {
vars.push_back (new Var (args_[i], ranges_[i]));
}
vector<string> jointStrings = Util::getStateLines (vars);
for (unsigned i = 0; i < params_.size(); i++) {
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl;
}
cout << endl;
for (unsigned i = 0; i < vars.size(); i++) {
delete vars[i];
}
}
void
Factor::copyFromFactor (const Factor& g)
{
args_ = g.arguments();
ranges_ = g.ranges();
params_ = g.params();
distId_ = g.distId();
}

View File

@ -1,288 +0,0 @@
#ifndef HORUS_FACTOR_H
#define HORUS_FACTOR_H
#include <vector>
#include "Var.h"
#include "Indexer.h"
#include "Util.h"
using namespace std;
template <typename T>
class TFactor
{
public:
const vector<T>& arguments (void) const { return args_; }
vector<T>& arguments (void) { return args_; }
const Ranges& ranges (void) const { return ranges_; }
const Params& params (void) const { return params_; }
Params& params (void) { return params_; }
unsigned nrArguments (void) const { return args_.size(); }
unsigned size (void) const { return params_.size(); }
unsigned distId (void) const { return distId_; }
void setDistId (unsigned id) { distId_ = id; }
void normalize (void) { LogAware::normalize (params_); }
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_));
}
int indexOf (const T& t) const
{
int idx = -1;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i] == t) {
idx = i;
break;
}
}
return idx;
}
const T& argument (unsigned idx) const
{
assert (idx < args_.size());
return args_[idx];
}
T& argument (unsigned idx)
{
assert (idx < args_.size());
return args_[idx];
}
unsigned range (unsigned idx) const
{
assert (idx < ranges_.size());
return ranges_[idx];
}
void multiply (TFactor<T>& g)
{
const vector<T>& g_args = g.arguments();
const Ranges& g_ranges = g.ranges();
const Params& g_params = g.params();
if (args_ == g_args) {
// optimization: if the factors contain the same set of args,
// we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) {
Util::add (params_, g_params);
} else {
Util::multiply (params_, g_params);
}
} else {
bool sharedArgs = false;
vector<unsigned> gvarpos;
for (unsigned i = 0; i < g_args.size(); i++) {
int idx = indexOf (g_args[i]);
if (idx == -1) {
insertArgument (g_args[i], g_ranges[i]);
gvarpos.push_back (args_.size() - 1);
} else {
sharedArgs = true;
gvarpos.push_back (idx);
}
}
if (sharedArgs == false) {
// optimization: if the original factors doesn't have common args,
// we don't need to marry the states of the common args
unsigned count = 0;
for (unsigned i = 0; i < params_.size(); i++) {
if (Globals::logDomain) {
params_[i] += g_params[count];
} else {
params_[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = gvarpos.size() - 1; j >= 0; j--) {
g_li += indexer[gvarpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
params_[indexer] += g_params[g_li];
} else {
params_[indexer] *= g_params[g_li];
}
++ indexer;
}
}
}
}
void absorveEvidence (const T& arg, unsigned evidence)
{
int idx = indexOf (arg);
assert (idx != -1);
assert (evidence < ranges_[idx]);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() / ranges_[idx]);
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx);
}
while (indexer.valid()) {
params_.push_back (copy[indexer]);
indexer.incrementExcluding (idx);
}
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void reorderArguments (const vector<T> newArgs)
{
assert (newArgs.size() == args_.size());
if (newArgs == args_) {
return; // already in the wanted order
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newArgs.size(); i++) {
unsigned idx = indexOf (newArgs[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (params_.size());
for (unsigned i = 0; i < params_.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N - 1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = params_[i];
}
args_ = newArgs;
ranges_ = newRanges;
params_ = newParams;
}
bool contains (const T& arg) const
{
return Util::contains (args_, arg);
}
bool contains (const vector<T>& args) const
{
for (unsigned i = 0; i < args_.size(); i++) {
if (contains (args[i]) == false) {
return false;
}
}
return true;
}
protected:
vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void insertArgument (const T& arg, unsigned range)
{
assert (indexOf (arg) == -1);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() * range);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < range; reps++) {
params_.push_back (copy[i]);
}
}
args_.push_back (arg);
ranges_.push_back (range);
}
void insertArguments (const vector<T>& args, const Ranges& ranges)
{
Params copy = params_;
unsigned nrStates = 1;
for (unsigned i = 0; i < args.size(); i++) {
assert (indexOf (args[i]) == -1);
args_.push_back (args[i]);
ranges_.push_back (ranges[i]);
nrStates *= ranges[i];
}
params_.clear();
params_.reserve (copy.size() * nrStates);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
params_.push_back (copy[i]);
}
}
}
};
class Factor : public TFactor<VarId>
{
public:
Factor (void) { }
Factor (const Factor&);
Factor (const VarIds&, const Ranges&, const Params&,
unsigned = Util::maxUnsigned());
Factor (const Vars&, const Params&,
unsigned = Util::maxUnsigned());
void sumOutAllExcept (VarId);
void sumOutAllExcept (const VarIds&);
void sumOut (VarId);
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void multiply (Factor&);
void reorderAccordingVarIds (void);
string getLabel (void) const;
void print (void) const;
private:
void copyFromFactor (const Factor& f);
};
#endif // HORUS_FACTOR_H

View File

@ -1,711 +0,0 @@
#include <algorithm>
#include <set>
#include "FoveSolver.h"
#include "Histogram.h"
#include "Util.h"
vector<LiftedOperator*>
LiftedOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<LiftedOperator*> validOps;
vector<SumOutOperator*> sumOutOps;
vector<CountingOperator*> countOps;
vector<GroundOperator*> groundOps;
sumOutOps = SumOutOperator::getValidOps (pfList, query);
countOps = CountingOperator::getValidOps (pfList);
groundOps = GroundOperator::getValidOps (pfList);
validOps.insert (validOps.end(), sumOutOps.begin(), sumOutOps.end());
validOps.insert (validOps.end(), countOps.begin(), countOps.end());
validOps.insert (validOps.end(), groundOps.begin(), groundOps.end());
return validOps;
}
void
LiftedOperator::printValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList, query);
for (unsigned i = 0; i < validOps.size(); i++) {
cout << "-> " << validOps[i]->toString() << endl;
delete validOps[i];
}
}
unsigned
SumOutOperator::getCost (void)
{
TinySet<unsigned> groupSet;
ParfactorList::const_iterator pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (group_)) {
vector<unsigned> groups = (*pfIter)->getAllGroups();
groupSet |= TinySet<unsigned> (groups);
}
++ pfIter;
}
unsigned cost = 1;
for (unsigned i = 0; i < groupSet.size(); i++) {
pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (groupSet[i])) {
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
cost *= (*pfIter)->range (idx);
break;
}
++ pfIter;
}
}
return cost;
}
void
SumOutOperator::apply (void)
{
vector<ParfactorList::iterator> iters
= parfactorsWithGroup (pfList_, group_);
Parfactor* product = *(iters[0]);
pfList_.remove (iters[0]);
for (unsigned i = 1; i < iters.size(); i++) {
product->multiply (**(iters[i]));
pfList_.removeAndDelete (iters[i]);
}
if (product->nrArguments() == 1) {
delete product;
return;
}
int fIdx = product->indexOfGroup (group_);
LogVarSet excl = product->exclusiveLogVars (fIdx);
if (product->constr()->isCountNormalized (excl)) {
product->sumOut (fIdx);
pfList_.addShattered (product);
} else {
Parfactors pfs = FoveSolver::countNormalize (product, excl);
for (unsigned i = 0; i < pfs.size(); i++) {
pfs[i]->sumOut (fIdx);
pfList_.add (pfs[i]);
}
delete product;
}
}
vector<SumOutOperator*>
SumOutOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<SumOutOperator*> validOps;
set<unsigned> allGroups;
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
const ProbFormulas& formulas = (*it)->arguments();
for (unsigned i = 0; i < formulas.size(); i++) {
allGroups.insert (formulas[i].group());
}
++ it;
}
set<unsigned>::const_iterator groupIt = allGroups.begin();
while (groupIt != allGroups.end()) {
if (validOp (*groupIt, pfList, query)) {
validOps.push_back (new SumOutOperator (*groupIt, pfList));
}
++ groupIt;
}
return validOps;
}
string
SumOutOperator::toString (void)
{
stringstream ss;
vector<ParfactorList::iterator> pfIters;
pfIters = parfactorsWithGroup (pfList_, group_);
int idx = (*pfIters[0])->indexOfGroup (group_);
ProbFormula f = (*pfIters[0])->argument (idx);
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
ss << " [cost=" << getCost() << "]" << endl;
return ss.str();
}
bool
SumOutOperator::validOp (
unsigned group,
ParfactorList& pfList,
const Grounds& query)
{
vector<ParfactorList::iterator> pfIters;
pfIters = parfactorsWithGroup (pfList, group);
if (isToEliminate (*pfIters[0], group, query) == false) {
return false;
}
unordered_map<unsigned, unsigned> groupToRange;
for (unsigned i = 0; i < pfIters.size(); i++) {
int fIdx = (*pfIters[i])->indexOfGroup (group);
if ((*pfIters[i])->argument (fIdx).contains (
(*pfIters[i])->elimLogVars()) == false) {
return false;
}
vector<unsigned> ranges = (*pfIters[i])->ranges();
vector<unsigned> groups = (*pfIters[i])->getAllGroups();
for (unsigned i = 0; i < groups.size(); i++) {
unordered_map<unsigned, unsigned>::iterator it;
it = groupToRange.find (groups[i]);
if (it == groupToRange.end()) {
groupToRange.insert (make_pair (groups[i], ranges[i]));
} else {
if (it->second != ranges[i]) {
return false;
}
}
}
}
return true;
}
vector<ParfactorList::iterator>
SumOutOperator::parfactorsWithGroup (
ParfactorList& pfList,
unsigned group)
{
vector<ParfactorList::iterator> iters;
ParfactorList::iterator pflIt = pfList.begin();
while (pflIt != pfList.end()) {
if ((*pflIt)->containsGroup (group)) {
iters.push_back (pflIt);
}
++ pflIt;
}
return iters;
}
bool
SumOutOperator::isToEliminate (
Parfactor* g,
unsigned group,
const Grounds& query)
{
int fIdx = g->indexOfGroup (group);
const ProbFormula& formula = g->argument (fIdx);
bool toElim = true;
for (unsigned i = 0; i < query.size(); i++) {
if (formula.functor() == query[i].functor() &&
formula.arity() == query[i].arity()) {
g->constr()->moveToTop (formula.logVars());
if (g->constr()->containsTuple (query[i].args())) {
toElim = false;
break;
}
}
}
return toElim;
}
unsigned
CountingOperator::getCost (void)
{
unsigned cost = 0;
int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned range = (*pfIter_)->range (fIdx);
unsigned size = (*pfIter_)->size() / range;
TinySet<unsigned> counts;
counts = (*pfIter_)->constr()->getConditionalCounts (X_);
for (unsigned i = 0; i < counts.size(); i++) {
cost += size * HistogramSet::nrHistograms (counts[i], range);
}
return cost;
}
void
CountingOperator::apply (void)
{
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
(*pfIter_)->countConvert (X_);
} else {
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
for (unsigned i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
pfs[i]->countedLogVars() | X_);
if (condCount > 1 && cartProduct) {
pfs[i]->countConvert (X_);
}
pfList_.add (pfs[i]);
}
delete pf;
}
}
vector<CountingOperator*>
CountingOperator::getValidOps (ParfactorList& pfList)
{
vector<CountingOperator*> validOps;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
LogVarSet candidates = (*it)->uncountedLogVars();
for (unsigned i = 0; i < candidates.size(); i++) {
if (validOp (*it, candidates[i])) {
validOps.push_back (new CountingOperator (
it, candidates[i], pfList));
}
}
++ it;
}
return validOps;
}
string
CountingOperator::toString (void)
{
stringstream ss;
ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (unsigned i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getLabel() << endl;
}
}
for (unsigned i = 0; i < pfs.size(); i++) {
delete pfs[i];
}
return ss.str();
}
bool
CountingOperator::validOp (Parfactor* g, LogVar X)
{
if (g->nrFormulas (X) != 1) {
return false;
}
int fIdx = g->indexOfLogVar (X);
if (g->argument (fIdx).isCounting()) {
return false;
}
bool countNormalized = g->constr()->isCountNormalized (X);
if (countNormalized) {
unsigned condCount = g->constr()->getConditionalCount (X);
bool cartProduct = g->constr()->isCarteesianProduct (
g->countedLogVars() | X);
if (condCount == 1 || cartProduct == false) {
return false;
}
}
return true;
}
unsigned
GroundOperator::getCost (void)
{
unsigned cost = 0;
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
if (isCountingLv) {
int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned currSize = (*pfIter_)->size();
unsigned nrHists = (*pfIter_)->range (fIdx);
unsigned range = (*pfIter_)->argument (fIdx).range();
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
} else {
cost = (*pfIter_)->constr()->nrSymbols (X_) * (*pfIter_)->size();
}
return cost;
}
void
GroundOperator::apply (void)
{
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
if (countedLv) {
pf->fullExpand (X_);
pfList_.add (pf);
} else {
ConstraintTrees cts = pf->constr()->ground (X_);
for (unsigned i = 0; i < cts.size(); i++) {
pfList_.add (new Parfactor (pf, cts[i]));
}
delete pf;
}
}
vector<GroundOperator*>
GroundOperator::getValidOps (ParfactorList& pfList)
{
vector<GroundOperator*> validOps;
ParfactorList::iterator pfIter = pfList.begin();
while (pfIter != pfList.end()) {
LogVarSet set = (*pfIter)->logVarSet();
for (unsigned i = 0; i < set.size(); i++) {
if ((*pfIter)->constr()->isSingleton (set[i]) == false) {
validOps.push_back (new GroundOperator (pfIter, set[i], pfList));
}
}
++ pfIter;
}
return validOps;
}
string
GroundOperator::toString (void)
{
stringstream ss;
((*pfIter_)->countedLogVars().contains (X_))
? ss << "full expanding "
: ss << "grounding " ;
ss << X_ << " in " << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl;
return ss.str();
}
Params
FoveSolver::getPosterioriOf (const Ground& query)
{
return getJointDistributionOf ({query});
}
Params
FoveSolver::getJointDistributionOf (const Grounds& query)
{
runSolver (query);
(*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params();
if (Globals::logDomain) {
Util::fromLog (params);
}
return params;
}
void
FoveSolver::absorveEvidence (
ParfactorList& pfList,
ObservedFormulas& obsFormulas)
{
for (unsigned i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
Parfactor* pf = *it;
it = pfList.remove (it);
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
if (absorvedPfs.empty() == false) {
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
// just remove pf;
} else {
Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
}
pfList.add (newPfs);
}
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << obsFormulas[i] << endl;
}
Util::printAsteriskLine();
pfList.print();
}
}
Parfactors
FoveSolver::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
Parfactors normPfs;
if (set.empty()) {
normPfs.push_back (new Parfactor (*g));
} else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (unsigned i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
}
return normPfs;
}
void
FoveSolver::runSolver (const Grounds& query)
{
shatterAgainstQuery (query);
runWeakBayesBall (query);
while (true) {
if (Constants::DEBUG >= 2) {
Util::printDashedLine();
pfList_.print();
LiftedOperator::printValidOps (pfList_, query);
}
LiftedOperator* op = getBestOperation (query);
if (op == 0) {
break;
}
if (Constants::DEBUG >= 2) {
cout << "best operation: " << op->toString() << endl;
}
op->apply();
delete op;
}
assert (pfList_.size() > 0);
if (pfList_.size() > 1) {
ParfactorList::iterator pfIter = pfList_.begin();
pfIter ++;
while (pfIter != pfList_.end()) {
(*pfList_.begin())->multiply (**pfIter);
++ pfIter;
}
}
(*pfList_.begin())->reorderAccordingGrounds (query);
}
LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query)
{
unsigned bestCost;
LiftedOperator* bestOp = 0;
vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList_, query);
for (unsigned i = 0; i < validOps.size(); i++) {
unsigned cost = validOps[i]->getCost();
if ((bestOp == 0) || (cost < bestCost)) {
bestOp = validOps[i];
bestCost = cost;
}
}
for (unsigned i = 0; i < validOps.size(); i++) {
if (validOps[i] != bestOp) {
delete validOps[i];
}
}
return bestOp;
}
void
FoveSolver::runWeakBayesBall (const Grounds& query)
{
queue<unsigned> todo; // groups to process
set<unsigned> done; // processed or in queue
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
int group = (*it)->findGroup (query[i]);
if (group != -1) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
unsigned group = todo.front();
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<unsigned> groups = (*it)->getAllGroups();
for (unsigned i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false) {
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
if (Constants::DEBUG >= 2) {
Util::printHeader ("REQUIRED PARFACTORS");
pfList_.print();
}
}
void
FoveSolver::shatterAgainstQuery (const Grounds& query)
{
for (unsigned i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
bool found = false;
Parfactors newPfs;
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (query[i])) {
found = true;
std::pair<ConstraintTree*, ConstraintTree*> split =
(*it)->constr()->split (query[i].args(), query[i].arity());
ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second;
newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) {
newPfs.push_back (new Parfactor (*it, exclCt));
} else {
delete exclCt;
}
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList_.add (newPfs);
}
if (Constants::DEBUG >= 2) {
cout << endl;
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (unsigned i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList_.print();
}
}
Parfactors
FoveSolver::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{
Parfactors absorvedPfs;
const ProbFormulas& formulas = g->arguments();
for (unsigned i = 0; i < formulas.size(); i++) {
if (obsFormula.functor() == formulas[i].functor() &&
obsFormula.arity() == formulas[i].arity()) {
if (obsFormula.isAtom()) {
if (formulas.size() > 1) {
g->absorveEvidence (formulas[i], obsFormula.evidence());
} else {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
}
g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second;
if (commCt->empty() == false) {
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactors countNormPfs = countNormalize (g, excl);
for (unsigned j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence());
absorvedPfs.push_back (countNormPfs[j]);
}
} else {
delete commCt;
}
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else {
delete commCt;
delete exclCt;
}
}
}
return absorvedPfs;
}

View File

@ -1,76 +0,0 @@
#ifndef HORUS_HORUS_H
#define HORUS_HORUS_H
#include <limits>
#include <vector>
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&); \
void operator=(const TypeName&)
using namespace std;
class Var;
class Factor;
class VarNode;
class FacNode;
typedef vector<double> Params;
typedef unsigned VarId;
typedef vector<VarId> VarIds;
typedef vector<Var*> Vars;
typedef vector<VarNode*> VarNodes;
typedef vector<FacNode*> FacNodes;
typedef vector<Factor*> Factors;
typedef vector<string> States;
typedef vector<unsigned> Ranges;
enum InfAlgorithms
{
VE, // variable elimination
BP, // belief propagation
CBP // counting belief propagation
};
namespace Globals {
extern bool logDomain;
extern InfAlgorithms infAlgorithm;
};
namespace Constants {
// level of debug information
const unsigned DEBUG = 0;
const int NO_EVIDENCE = -1;
// number of digits to show when printing a parameter
const unsigned PRECISION = 5;
const bool COLLECT_STATS = false;
};
namespace BpOptions
{
enum Schedule {
SEQ_FIXED,
SEQ_RANDOM,
PARALLEL,
MAX_RESIDUAL
};
extern Schedule schedule;
extern double accuracy;
extern unsigned maxIter;
}
#endif // HORUS_HORUS_H

View File

@ -1,171 +0,0 @@
#include <cstdlib>
#include <iostream>
#include <sstream>
#include "FactorGraph.h"
#include "VarElimSolver.h"
#include "BpSolver.h"
#include "CbpSolver.h"
using namespace std;
void processArguments (FactorGraph&, int, const char* []);
void runSolver (const FactorGraph&, const VarIds&);
const string USAGE = "usage: \
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
int
main (int argc, const char* argv[])
{
if (argc <= 1) {
cerr << "error: no solver specified" << endl;
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
if (argc <= 2) {
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
string solver (argv[1]);
if (solver == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (solver == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (solver == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "error: unknow solver `" << solver << "'" << endl ;
cerr << USAGE << endl;
exit(0);
}
string fileName (argv[2]);
string extension = fileName.substr (
fileName.find_last_of ('.') + 1);
FactorGraph fg;
if (extension == "uai") {
fg.readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") {
fg.readFromLibDaiFormat (fileName.c_str());
} else {
cerr << "error: the graphical model must be defined either " ;
cerr << "in a UAI or libDAI file" << endl;
exit (0);
}
processArguments (fg, argc, argv);
return 0;
}
void
processArguments (FactorGraph& fg, int argc, const char* argv[])
{
VarIds queryIds;
for (int i = 3; i < argc; i++) {
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
if (!Util::isInteger (arg)) {
cerr << "error: `" << arg << "' " ;
cerr << "is not a valid variable id" ;
cerr << endl;
exit (0);
}
VarId vid;
stringstream ss;
ss << arg;
ss >> vid;
VarNode* queryVar = fg.getVarNode (vid);
if (queryVar) {
queryIds.push_back (vid);
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
exit (0);
}
} else {
size_t pos = arg.find ('=');
if (arg.substr (0, pos).empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
exit (0);
}
if (arg.substr (pos + 1).empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
exit (0);
}
if (!Util::isInteger (arg.substr (0, pos))) {
cerr << "error: `" << arg.substr (0, pos) << "' " ;
cerr << "is not a variable id" ;
cerr << endl;
exit (0);
}
VarId vid;
stringstream ss;
ss << arg.substr (0, pos);
ss >> vid;
VarNode* var = fg.getVarNode (vid);
if (var) {
if (!Util::isInteger (arg.substr (pos + 1))) {
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
cerr << "is not a state index" ;
cerr << endl;
exit (0);
}
int stateIndex;
stringstream ss;
ss << arg.substr (pos + 1);
ss >> stateIndex;
if (var->isValidState (stateIndex)) {
var->setEvidence (stateIndex);
} else {
cerr << "error: `" << stateIndex << "' " ;
cerr << "is not a valid state index for variable " ;
cerr << "`" << var->varId() << "'" ;
cerr << endl;
exit (0);
}
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
exit (0);
}
}
}
runSolver (fg, queryIds);
}
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
Solver* solver = 0;
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE:
solver = new VarElimSolver (fg);
break;
case InfAlgorithms::BP:
solver = new BpSolver (fg);
break;
case InfAlgorithms::CBP:
solver = new CbpSolver (fg);
break;
default:
assert (false);
}
if (queryIds.size() == 0) {
solver->printAllPosterioris();
} else {
solver->printAnswer (queryIds);
}
delete solver;
}

View File

@ -1,296 +0,0 @@
#ifndef HORUS_STATESINDEXER_H
#define HORUS_STATESINDEXER_H
#include <algorithm>
#include <numeric>
#include <functional>
#include <sstream>
#include <iomanip>
#include "Var.h"
#include "Util.h"
class StatesIndexer
{
public:
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
{
size_ = 1;
indices_.resize (ranges.size(), 0);
ranges_ = ranges;
for (unsigned i = 0; i < ranges.size(); i++) {
size_ *= ranges[i];
}
li_ = 0;
if (calcOffsets) {
calculateOffsets();
}
}
StatesIndexer (const Vars& vars, bool calcOffsets = true)
{
size_ = 1;
indices_.resize (vars.size(), 0);
ranges_.reserve (vars.size());
for (unsigned i = 0; i < vars.size(); i++) {
ranges_.push_back (vars[i]->range());
size_ *= vars[i]->range();
}
li_ = 0;
if (calcOffsets) {
calculateOffsets();
}
}
void increment (void)
{
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices_[i] ++;
if (indices_[i] != ranges_[i]) {
break;
} else {
indices_[i] = 0;
}
}
li_ ++;
}
void increment (unsigned dim)
{
assert (dim < ranges_.size());
assert (ranges_.size() == offsets_.size());
assert (indices_[dim] < ranges_[dim]);
indices_[dim] ++;
li_ += offsets_[dim];
}
void incrementExcluding (unsigned skipDim)
{
assert (ranges_.size() == offsets_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
if (i != (int)skipDim) {
indices_[i] ++;
li_ += offsets_[i];
if (indices_[i] != ranges_[i]) {
return;
} else {
indices_[i] = 0;
li_ -= offsets_[i] * ranges_[i];
}
}
}
li_ = size_;
}
unsigned linearIndex (void) const
{
return li_;
}
const vector<unsigned>& indices (void) const
{
return indices_;
}
StatesIndexer& operator ++ (void)
{
increment();
return *this;
}
operator unsigned (void) const
{
return li_;
}
unsigned operator[] (unsigned dim) const
{
assert (valid());
assert (dim < ranges_.size());
return indices_[dim];
}
bool valid (void) const
{
return li_ < size_;
}
void reset (void)
{
std::fill (indices_.begin(), indices_.end(), 0);
li_ = 0;
}
void reset (unsigned dim)
{
indices_[dim] = 0;
li_ -= offsets_[dim] * ranges_[dim];
}
unsigned size (void) const
{
return size_ ;
}
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
{
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
os << idx.indices_;
return os;
}
private:
void calculateOffsets (void)
{
unsigned prod = 1;
offsets_.resize (ranges_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
}
unsigned li_;
unsigned size_;
vector<unsigned> indices_;
vector<unsigned> ranges_;
vector<unsigned> offsets_;
};
class MapIndexer
{
public:
MapIndexer (const Ranges& ranges, const vector<bool>& mapDims)
{
assert (ranges.size() == mapDims.size());
unsigned prod = 1;
offsets_.resize (ranges.size());
for (int i = ranges.size() - 1; i >= 0; i--) {
if (mapDims[i]) {
offsets_[i] = prod;
prod *= ranges[i];
}
}
indices_.resize (ranges.size(), 0);
ranges_ = ranges;
index_ = 0;
valid_ = true;
}
MapIndexer (const Ranges& ranges, unsigned ignoreDim)
{
unsigned prod = 1;
offsets_.resize (ranges.size());
for (int i = ranges.size() - 1; i >= 0; i--) {
if (i != (int)ignoreDim) {
offsets_[i] = prod;
prod *= ranges[i];
}
}
indices_.resize (ranges.size(), 0);
ranges_ = ranges;
index_ = 0;
valid_ = true;
}
/*
MapIndexer (
const VarIds& loopVids,
const Ranges& loopRanges,
const VarIds& mapVids,
const Ranges& mapRanges)
{
unsigned prod = 1;
vector<unsigned> offsets (mapRanges.size());
for (int i = mapRanges.size() - 1; i >= 0; i--) {
offsets[i] = prod;
prod *= mapRanges[i];
}
offsets_.reserve (loopVids.size());
for (unsigned i = 0; i < loopVids.size(); i++) {
VarIds::const_iterator it =
std::find (mapVids.begin(), mapVids.end(), loopVids[i]);
if (it != mapVids.end()) {
offsets_.push_back (offsets[it - mapVids.begin()]);
} else {
offsets_.push_back (0);
}
}
indices_.resize (loopVids.size(), 0);
ranges_ = loopRanges;
index_ = 0;
size_ = prod;
}
*/
MapIndexer& operator ++ (void)
{
assert (valid_);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices_[i] ++;
index_ += offsets_[i];
if (indices_[i] != ranges_[i]) {
return *this;
} else {
indices_[i] = 0;
index_ -= offsets_[i] * ranges_[i];
}
}
valid_ = false;
return *this;
}
unsigned mappedIndex (void) const
{
return index_;
}
operator unsigned (void) const
{
return index_;
}
unsigned operator[] (unsigned dim) const
{
assert (valid());
assert (dim < ranges_.size());
return indices_[dim];
}
bool valid (void) const
{
return valid_;
}
void reset (void)
{
std::fill (indices_.begin(), indices_.end(), 0);
index_ = 0;
}
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
{
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
os << idx.indices_;
return os;
}
private:
unsigned index_;
bool valid_;
vector<unsigned> ranges_;
vector<unsigned> indices_;
vector<unsigned> offsets_;
};
#endif // HORUS_STATESINDEXER_H

View File

@ -1,685 +0,0 @@
#include "Parfactor.h"
#include "Histogram.h"
#include "Indexer.h"
#include "Util.h"
#include "Horus.h"
Parfactor::Parfactor (
const ProbFormulas& formulas,
const Params& params,
const Tuples& tuples,
unsigned distId)
{
args_ = formulas;
params_ = params;
distId_ = distId;
LogVars logVars;
for (unsigned i = 0; i < args_.size(); i++) {
ranges_.push_back (args_[i].range());
const LogVars& lvs = args_[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) {
if (Util::contains (logVars, lvs[j]) == false) {
logVars.push_back (lvs[j]);
}
}
}
constr_ = new ConstraintTree (logVars, tuples);
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
{
args_ = g->arguments();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = new ConstraintTree (g->logVars(), {tuple});
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
{
args_ = g->arguments();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = constr;
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::Parfactor (const Parfactor& g)
{
args_ = g.arguments();
params_ = g.params();
ranges_ = g.ranges();
distId_ = g.distId();
constr_ = new ConstraintTree (*g.constr());
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::~Parfactor (void)
{
delete constr_;
}
LogVarSet
Parfactor::countedLogVars (void) const
{
LogVarSet set;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].isCounting()) {
set.insert (args_[i].countedLogVar());
}
}
return set;
}
LogVarSet
Parfactor::uncountedLogVars (void) const
{
return constr_->logVarSet() - countedLogVars();
}
LogVarSet
Parfactor::elimLogVars (void) const
{
LogVarSet requiredToElim = constr_->logVarSet();
requiredToElim -= constr_->singletons();
requiredToElim -= countedLogVars();
return requiredToElim;
}
LogVarSet
Parfactor::exclusiveLogVars (unsigned fIdx) const
{
assert (fIdx < args_.size());
LogVarSet remaining;
for (unsigned i = 0; i < args_.size(); i++) {
if (i != fIdx) {
remaining |= args_[i].logVarSet();
}
}
return args_[fIdx].logVarSet() - remaining;
}
void
Parfactor::setConstraintTree (ConstraintTree* newTree)
{
delete constr_;
constr_ = newTree;
}
void
Parfactor::sumOut (unsigned fIdx)
{
assert (fIdx < args_.size());
assert (args_[fIdx].contains (elimLogVars()));
LogVarSet excl = exclusiveLogVars (fIdx);
if (args_[fIdx].isCounting()) {
LogAware::pow (params_, constr_->getConditionalCount (
excl - args_[fIdx].countedLogVar()));
} else {
LogAware::pow (params_, constr_->getConditionalCount (excl));
}
if (args_[fIdx].isCounting()) {
unsigned N = constr_->getConditionalCount (
args_[fIdx].countedLogVar());
unsigned R = args_[fIdx].range();
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
StatesIndexer sindexer (ranges_, fIdx);
while (sindexer.valid()) {
unsigned h = sindexer[fIdx];
if (Globals::logDomain) {
params_[sindexer] += numAssigns[h];
} else {
params_[sindexer] *= numAssigns[h];
}
++ sindexer;
}
}
Params copy = params_;
params_.clear();
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
MapIndexer indexer (ranges_, fIdx);
if (Globals::logDomain) {
for (unsigned i = 0; i < copy.size(); i++) {
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
++ indexer;
}
} else {
for (unsigned i = 0; i < copy.size(); i++) {
params_[indexer] += copy[i];
++ indexer;
}
}
args_.erase (args_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
constr_->remove (excl);
}
void
Parfactor::multiply (Parfactor& g)
{
alignAndExponentiate (this, &g);
TFactor<ProbFormula>::multiply (g);
constr_->join (g.constr(), true);
}
void
Parfactor::countConvert (LogVar X)
{
int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (constr_->isCountNormalized (X));
assert (constr_->getConditionalCount (X) > 1);
assert (constr_->isCarteesianProduct (countedLogVars() | X));
unsigned N = constr_->getConditionalCount (X);
unsigned R = ranges_[fIdx];
unsigned H = HistogramSet::nrHistograms (N, R);
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
StatesIndexer indexer (ranges_);
vector<Params> sumout (params_.size() / R);
unsigned count = 0;
while (indexer.valid()) {
sumout[count].reserve (R);
for (unsigned r = 0; r < R; r++) {
sumout[count].push_back (params_[indexer]);
indexer.increment (fIdx);
}
count ++;
indexer.reset (fIdx);
indexer.incrementExcluding (fIdx);
}
params_.clear();
params_.reserve (sumout.size() * H);
ranges_[fIdx] = H;
MapIndexer mapIndexer (ranges_, fIdx);
while (mapIndexer.valid()) {
double prod = LogAware::multIdenty();
unsigned i = mapIndexer.mappedIndex();
unsigned h = mapIndexer[fIdx];
for (unsigned r = 0; r < R; r++) {
if (Globals::logDomain) {
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
} else {
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
}
}
params_.push_back (prod);
++ mapIndexer;
}
args_[fIdx].setCountedLogVar (X);
}
void
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
{
int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (args_[fIdx].isCounting());
unsigned N1 = constr_->getConditionalCount (X_new1);
unsigned N2 = constr_->getConditionalCount (X_new2);
unsigned N = N1 + N2;
unsigned R = args_[fIdx].range();
unsigned H1 = HistogramSet::nrHistograms (N1, R);
unsigned H2 = HistogramSet::nrHistograms (N2, R);
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
vector<Histogram> histograms2 = HistogramSet::getHistograms (N2, R);
vector<unsigned> sumIndexes;
sumIndexes.reserve (H1 * H2);
for (unsigned i = 0; i < H1; i++) {
for (unsigned j = 0; j < H2; j++) {
Histogram hist = histograms1[i];
std::transform (
hist.begin(), hist.end(),
histograms2[j].begin(),
hist.begin(),
plus<int>());
sumIndexes.push_back (HistogramSet::findIndex (hist, histograms));
}
}
expandPotential (fIdx, H1 * H2, sumIndexes);
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
args_[fIdx].rename (X, X_new1);
args_[fIdx + 1].rename (X, X_new2);
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
ranges_[fIdx] = H1;
}
void
Parfactor::fullExpand (LogVar X)
{
int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (args_[fIdx].isCounting());
unsigned N = constr_->getConditionalCount (X);
unsigned R = args_[fIdx].range();
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
vector<unsigned> sumIndexes;
sumIndexes.reserve (N * R);
Ranges expandRanges (N, R);
StatesIndexer indexer (expandRanges);
while (indexer.valid()) {
vector<unsigned> hist (R, 0);
for (unsigned n = 0; n < N; n++) {
std::transform (
hist.begin(), hist.end(),
expandHists[indexer[n]].begin(),
hist.begin(),
plus<int>());
}
sumIndexes.push_back (HistogramSet::findIndex (hist, originHists));
++ indexer;
}
expandPotential (fIdx, std::pow (R, N), sumIndexes);
ProbFormula f = args_[fIdx];
args_.erase (args_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
LogVars newLvs = constr_->expand (X);
assert (newLvs.size() == N);
for (unsigned i = 0 ; i < N; i++) {
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
newFormula.rename (X, newLvs[i]);
args_.insert (args_.begin() + fIdx + i, newFormula);
ranges_.insert (ranges_.begin() + fIdx + i, R);
}
}
void
Parfactor::reorderAccordingGrounds (const Grounds& grounds)
{
ProbFormulas newFormulas;
for (unsigned i = 0; i < grounds.size(); i++) {
for (unsigned j = 0; j < args_.size(); j++) {
if (grounds[i].functor() == args_[j].functor() &&
grounds[i].arity() == args_[j].arity()) {
constr_->moveToTop (args_[j].logVars());
if (constr_->containsTuple (grounds[i].args())) {
newFormulas.push_back (args_[j]);
break;
}
}
}
assert (newFormulas.size() == i + 1);
}
reorderArguments (newFormulas);
}
void
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
{
int fIdx = indexOf (formula);
assert (fIdx != -1);
LogVarSet excl = exclusiveLogVars (fIdx);
assert (args_[fIdx].isCounting() == false);
assert (constr_->isCountNormalized (excl));
LogAware::pow (params_, constr_->getConditionalCount (excl));
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
constr_->remove (excl);
}
void
Parfactor::setNewGroups (void)
{
for (unsigned i = 0; i < args_.size(); i++) {
args_[i].setGroup (ProbFormula::getNewGroup());
}
}
void
Parfactor::applySubstitution (const Substitution& theta)
{
for (unsigned i = 0; i < args_.size(); i++) {
LogVars& lvs = args_[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) {
lvs[j] = theta.newNameFor (lvs[j]);
}
if (args_[i].isCounting()) {
LogVar clv = args_[i].countedLogVar();
args_[i].setCountedLogVar (theta.newNameFor (clv));
}
}
constr_->applySubstitution (theta);
}
int
Parfactor::findGroup (const Ground& ground) const
{
int group = -1;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].functor() == ground.functor() &&
args_[i].arity() == ground.arity()) {
constr_->moveToTop (args_[i].logVars());
if (constr_->containsTuple (ground.args())) {
group = args_[i].group();
break;
}
}
}
return group;
}
bool
Parfactor::containsGround (const Ground& ground) const
{
return findGroup (ground) != -1;
}
bool
Parfactor::containsGroup (unsigned group) const
{
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].group() == group) {
return true;
}
}
return false;
}
unsigned
Parfactor::nrFormulas (LogVar X) const
{
unsigned count = 0;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].contains (X)) {
count ++;
}
}
return count;
}
int
Parfactor::indexOfLogVar (LogVar X) const
{
int idx = -1;
assert (nrFormulas (X) == 1);
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].contains (X)) {
idx = i;
break;
}
}
return idx;
}
int
Parfactor::indexOfGroup (unsigned group) const
{
int pos = -1;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].group() == group) {
pos = i;
break;
}
}
return pos;
}
vector<unsigned>
Parfactor::getAllGroups (void) const
{
vector<unsigned> groups (args_.size());
for (unsigned i = 0; i < args_.size(); i++) {
groups[i] = args_[i].group();
}
return groups;
}
string
Parfactor::getLabel (void) const
{
stringstream ss;
ss << "phi(" ;
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
ss << args_[i];
}
ss << ")" ;
ConstraintTree copy (*constr_);
copy.moveToTop (copy.logVarSet().elements());
ss << "|" << copy.tupleSet();
return ss.str();
}
void
Parfactor::print (bool printParams) const
{
cout << "Formulas: " ;
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) cout << ", " ;
cout << args_[i];
}
cout << endl;
if (args_[0].group() != Util::maxUnsigned()) {
vector<string> groups;
for (unsigned i = 0; i < args_.size(); i++) {
groups.push_back (string ("g") + Util::toString (args_[i].group()));
}
cout << "Groups: " << groups << endl;
}
cout << "LogVars: " << constr_->logVarSet() << endl;
cout << "Ranges: " << ranges_ << endl;
if (printParams == false) {
cout << "Params: " << params_ << endl;
}
ConstraintTree copy (*constr_);
copy.moveToTop (copy.logVarSet().elements());
cout << "Tuples: " << copy.tupleSet() << endl;
if (printParams) {
vector<string> jointStrings;
StatesIndexer indexer (ranges_);
while (indexer.valid()) {
stringstream ss;
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << ", " ;
if (args_[i].isCounting()) {
unsigned N = constr_->getConditionalCount (
args_[i].countedLogVar());
HistogramSet hs (N, args_[i].range());
unsigned c = 0;
while (c < indexer[i]) {
hs.nextHistogram();
c ++;
}
ss << hs;
} else {
ss << indexer[i];
}
}
jointStrings.push_back (ss.str());
++ indexer;
}
for (unsigned i = 0; i < params_.size(); i++) {
cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl;
}
}
cout << endl;
}
void
Parfactor::expandPotential (
int fIdx,
unsigned newRange,
const vector<unsigned>& sumIndexes)
{
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
Params copy = params_;
params_.clear();
params_.reserve (size);
unsigned prod = 1;
vector<unsigned> offsets_ (ranges_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = newRange;
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
assert (indices[i] - 1 < sumIndexes.size());
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
}
}
}
void
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
{
LogVars X_1, X_2;
const ProbFormulas& formulas1 = g1->arguments();
const ProbFormulas& formulas2 = g2->arguments();
for (unsigned i = 0; i < formulas1.size(); i++) {
for (unsigned j = 0; j < formulas2.size(); j++) {
if (formulas1[i].group() == formulas2[j].group()) {
Util::addToVector (X_1, formulas1[i].logVars());
Util::addToVector (X_2, formulas2[j].logVars());
}
}
}
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
assert (g1->constr()->isCountNormalized (Y_1));
assert (g2->constr()->isCountNormalized (Y_2));
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
LogAware::pow (g1->params(), 1.0 / condCount2);
LogAware::pow (g2->params(), 1.0 / condCount1);
// this must be done in the end or else X_1 and X_2
// will refer the old log var names in the code above
align (g1, X_1, g2, X_2);
}
void
Parfactor::align (
Parfactor* g1, const LogVars& alignLvs1,
Parfactor* g2, const LogVars& alignLvs2)
{
LogVar freeLogVar = 0;
Substitution theta1;
Substitution theta2;
const LogVarSet& allLvs1 = g1->logVarSet();
for (unsigned i = 0; i < allLvs1.size(); i++) {
theta1.add (allLvs1[i], freeLogVar);
++ freeLogVar;
}
const LogVarSet& allLvs2 = g2->logVarSet();
for (unsigned i = 0; i < allLvs2.size(); i++) {
theta2.add (allLvs2[i], freeLogVar);
++ freeLogVar;
}
assert (alignLvs1.size() == alignLvs2.size());
for (unsigned i = 0; i < alignLvs1.size(); i++) {
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
}
g1->applySubstitution (theta1);
g2->applySubstitution (theta2);
}

View File

@ -1,37 +0,0 @@
#include "Solver.h"
#include "Util.h"
void
Solver::printAnswer (const VarIds& vids)
{
Vars unobservedVars;
VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) {
VarNode* vn = fg.getVarNode (vids[i]);
if (vn->hasEvidence() == false) {
unobservedVars.push_back (vn);
unobservedVids.push_back (vids[i]);
}
}
Params res = solveQuery (unobservedVids);
vector<string> stateLines = Util::getStateLines (unobservedVars);
for (unsigned i = 0; i < res.size(); i++) {
cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl;
}
cout << endl;
}
void
Solver::printAllPosterioris (void)
{
const VarNodes& vars = fg.varNodes();
for (unsigned i = 0; i < vars.size(); i++) {
printAnswer ({vars[i]->varId()});
}
}

View File

@ -1,4 +0,0 @@
TODO
- add a way to calculate combinations and factorials with large numbers
- refactor sumOut in parfactor -> is really ugly code
- Indexer: start receiving ranges as constant reference

View File

@ -1,200 +0,0 @@
#ifndef HORUS_TINYSET_H
#define HORUS_TINYSET_H
#include <vector>
#include <algorithm>
using namespace std;
template <typename T>
class TinySet
{
public:
TinySet (void) {}
TinySet (const T& t)
{
elements_.push_back (t);
}
TinySet (const vector<T>& elements)
{
elements_.reserve (elements.size());
for (unsigned i = 0; i < elements.size(); i++) {
insert (elements[i]);
}
}
TinySet (const TinySet<T>& s) : elements_(s.elements_) { }
void insert (const T& t)
{
typename vector<T>::iterator it =
std::lower_bound (elements_.begin(), elements_.end(), t);
if (it == elements_.end() || *it != t) {
elements_.insert (it, t);
}
}
void remove (const T& t)
{
typename vector<T>::iterator it =
std::lower_bound (elements_.begin(), elements_.end(), t);
if (it != elements_.end()) {
elements_.erase (it);
}
}
/* set union */
TinySet operator| (const TinySet& s) const
{
TinySet res;
std::set_union (
elements_.begin(),
elements_.end(),
s.elements_.begin(),
s.elements_.end(),
std::back_inserter (res.elements_));
return res;
}
/* set intersection */
TinySet operator& (const TinySet& s) const
{
TinySet res;
std::set_intersection (
elements_.begin(),
elements_.end(),
s.elements_.begin(),
s.elements_.end(),
std::back_inserter (res.elements_));
return res;
}
/* set difference */
TinySet operator- (const TinySet& s) const
{
TinySet res;
std::set_difference (
elements_.begin(),
elements_.end(),
s.elements_.begin(),
s.elements_.end(),
std::back_inserter (res.elements_));
return res;
}
TinySet& operator|= (const TinySet& s)
{
return *this = (*this | s);
}
TinySet& operator&= (const TinySet& s)
{
return *this = (*this & s);
}
TinySet& operator-= (const TinySet& s)
{
return *this = (*this - s);
}
bool contains (const T& t) const
{
return std::binary_search (
elements_.begin(), elements_.end(), t);
}
bool contains (const TinySet& s) const
{
return std::includes (
elements_.begin(),
elements_.end(),
s.elements_.begin(),
s.elements_.end());
}
bool in (const TinySet& s) const
{
return std::includes (
s.elements_.begin(),
s.elements_.end(),
elements_.begin(),
elements_.end());
}
bool intersects (const TinySet& s) const
{
return (*this & s).size() > 0;
}
T operator[] (unsigned i) const
{
return elements_[i];
}
const vector<T>& elements (void) const
{
return elements_;
}
T front (void) const
{
return elements_.front();
}
T back (void) const
{
return elements_.back();
}
unsigned size (void) const
{
return elements_.size();
}
bool empty (void) const
{
return elements_.size() == 0;
}
typedef typename std::vector<T>::const_iterator const_iterator;
const_iterator begin (void) const
{
return elements_.begin();
}
const_iterator end (void) const
{
return elements_.end();
}
friend bool operator== (const TinySet& s1, const TinySet& s2)
{
return s1.elements_ == s2.elements_;
}
friend bool operator!= (const TinySet& s1, const TinySet& s2)
{
return s1.elements_ != s2.elements_;
}
friend std::ostream& operator << (std::ostream& out, const TinySet<T>& s)
{
out << "{" ;
for (unsigned i = 0; i < s.size(); i++) {
out << ((i != 0) ? "," : "") << s.elements()[i];
}
out << "}" ;
return out;
}
protected:
vector<T> elements_;
};
#endif // HORUS_TINYSET_H

View File

@ -1,502 +0,0 @@
#include <limits>
#include <sstream>
#include <fstream>
#include "Util.h"
#include "Indexer.h"
namespace Globals {
bool logDomain = false;
//InfAlgs infAlgorithm = InfAlgorithms::VE;
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
};
namespace BpOptions {
Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
//Schedule schedule = BpOptions::Schedule::PARALLEL;
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
double accuracy = 0.0001;
unsigned maxIter = 1000;
}
vector<NetInfo> Statistics::netInfo_;
vector<CompressInfo> Statistics::compressInfo_;
unsigned Statistics::primaryNetCount_;
namespace Util {
void
toLog (Params& v)
{
for (unsigned i = 0; i < v.size(); i++) {
v[i] = log (v[i]);
}
}
void
fromLog (Params& v)
{
for (unsigned i = 0; i < v.size(); i++) {
v[i] = exp (v[i]);
}
}
double
factorial (double num)
{
double result = 1.0;
for (int i = 1; i <= num; i++) {
result *= i;
}
return result;
}
unsigned
nrCombinations (unsigned n, unsigned r)
{
assert (n >= r);
unsigned prod = 1;
for (int i = (int)n; i > (int)(n - r); i--) {
prod *= i;
}
return (prod / factorial (r));
}
unsigned
expectedSize (const Ranges& ranges)
{
unsigned prod = 1;
for (unsigned i = 0; i < ranges.size(); i++) {
prod *= ranges[i];
}
return prod;
}
unsigned
getNumberOfDigits (int number)
{
unsigned count = 1;
while (number >= 10) {
number /= 10;
count ++;
}
return count;
}
bool
isInteger (const string& s)
{
stringstream ss1 (s);
stringstream ss2;
int integer;
ss1 >> integer;
ss2 << integer;
return (ss1.str() == ss2.str());
}
string
parametersToString (const Params& v, unsigned precision)
{
stringstream ss;
ss.precision (precision);
ss << "[" ;
for (unsigned i = 0; i < v.size(); i++) {
if (i != 0) ss << ", " ;
ss << v[i];
}
ss << "]" ;
return ss.str();
}
vector<string>
getStateLines (const Vars& vars)
{
StatesIndexer idx (vars);
vector<string> jointStrings;
while (idx.valid()) {
stringstream ss;
for (unsigned i = 0; i < vars.size(); i++) {
if (i != 0) ss << ", " ;
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
}
jointStrings.push_back (ss.str());
++ idx;
}
return jointStrings;
}
void
printHeader (string header, std::ostream& os)
{
printAsteriskLine (os);
os << header << endl;
printAsteriskLine (os);
}
void
printSubHeader (string header, std::ostream& os)
{
printDashedLine (os);
os << header << endl;
printDashedLine (os);
}
void
printAsteriskLine (std::ostream& os)
{
os << "********************************" ;
os << "********************************" ;
os << endl;
}
void
printDashedLine (std::ostream& os)
{
os << "--------------------------------" ;
os << "--------------------------------" ;
os << endl;
}
}
namespace LogAware {
void
normalize (Params& v)
{
double sum;
if (Globals::logDomain) {
sum = LogAware::addIdenty();
for (unsigned i = 0; i < v.size(); i++) {
sum = Util::logSum (sum, v[i]);
}
assert (sum != -numeric_limits<double>::infinity());
for (unsigned i = 0; i < v.size(); i++) {
v[i] -= sum;
}
} else {
sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
}
}
double
getL1Distance (const Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
double dist = 0.0;
if (Globals::logDomain) {
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (exp(v1[i]) - exp(v2[i]));
}
} else {
for (unsigned i = 0; i < v1.size(); i++) {
dist += abs (v1[i] - v2[i]);
}
}
return dist;
}
double
getMaxNorm (const Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
double max = 0.0;
if (Globals::logDomain) {
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (exp(v1[i]) - exp(v2[i]));
if (diff > max) {
max = diff;
}
}
} else {
for (unsigned i = 0; i < v1.size(); i++) {
double diff = abs (v1[i] - v2[i]);
if (diff > max) {
max = diff;
}
}
}
return max;
}
double
pow (double p, unsigned expoent)
{
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
double
pow (double p, double expoent)
{
// assumes that `expoent' is never in log domain
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
void
pow (Params& v, unsigned expoent)
{
if (expoent == 1) {
return;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
void
pow (Params& v, double expoent)
{
// assumes that `expoent' is never in log domain
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
}
unsigned
Statistics::getSolvedNetworksCounting (void)
{
return netInfo_.size();
}
void
Statistics::incrementPrimaryNetworksCounting (void)
{
primaryNetCount_ ++;
}
unsigned
Statistics::getPrimaryNetworksCounting (void)
{
return primaryNetCount_;
}
void
Statistics::updateStatistics (
unsigned size,
bool loopy,
unsigned nIters,
double time)
{
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
}
void
Statistics::printStatistics (void)
{
cout << getStatisticString();
}
void
Statistics::writeStatistics (const char* fileName)
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Statistics::writeStats()" << endl;
abort();
}
out << getStatisticString();
out.close();
}
void
Statistics::updateCompressingStatistics (
unsigned nrGroundVars,
unsigned nrGroundFactors,
unsigned nrClusterVars,
unsigned nrClusterFactors,
unsigned nrNeighborless) {
compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
nrClusterVars, nrClusterFactors, nrNeighborless));
}
string
Statistics::getStatisticString (void)
{
stringstream ss2, ss3, ss4, ss1;
ss1 << "running mode: " ;
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
}
ss1 << "message schedule: " ;
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_FIXED:
ss1 << "sequential fixed" << endl;
break;
case BpOptions::Schedule::SEQ_RANDOM:
ss1 << "sequential random" << endl;
break;
case BpOptions::Schedule::PARALLEL:
ss1 << "parallel" << endl;
break;
case BpOptions::Schedule::MAX_RESIDUAL:
ss1 << "max residual" << endl;
break;
}
ss1 << "max iterations: " << BpOptions::maxIter << endl;
ss1 << "accuracy " << BpOptions::accuracy << endl;
ss1 << endl << endl;
Util::printSubHeader ("Network information", ss2);
ss2 << left;
ss2 << setw (15) << "Network Size" ;
ss2 << setw (9) << "Loopy" ;
ss2 << setw (15) << "Iterations" ;
ss2 << setw (15) << "Solving Time" ;
ss2 << endl;
unsigned nLoopyNets = 0;
unsigned nUnconvergedRuns = 0;
double totalSolvingTime = 0.0;
for (unsigned i = 0; i < netInfo_.size(); i++) {
ss2 << setw (15) << netInfo_[i].size;
if (netInfo_[i].loopy) {
ss2 << setw (9) << "yes";
nLoopyNets ++;
} else {
ss2 << setw (9) << "no";
}
if (netInfo_[i].nIters == 0) {
ss2 << setw (15) << "n/a" ;
} else {
ss2 << setw (15) << netInfo_[i].nIters;
if (netInfo_[i].nIters > BpOptions::maxIter) {
nUnconvergedRuns ++;
}
}
ss2 << setw (15) << netInfo_[i].time;
totalSolvingTime += netInfo_[i].time;
ss2 << endl;
}
ss2 << endl << endl;
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
if (compressInfo_.size() > 0) {
Util::printSubHeader ("Compress information", ss3);
ss3 << left;
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
ss3 << "Vars Vars Factors Factors Vars" << endl;
for (unsigned i = 0; i < compressInfo_.size(); i++) {
ss3 << setw (9) << compressInfo_[i].nrGroundVars;
ss3 << setw (10) << compressInfo_[i].nrClusterVars;
ss3 << setw (10) << compressInfo_[i].nrGroundFactors;
ss3 << setw (10) << compressInfo_[i].nrClusterFactors;
ss3 << setw (10) << compressInfo_[i].nrNeighborless;
ss3 << endl;
c1 += compressInfo_[i].nrGroundVars - compressInfo_[i].nrNeighborless;
c2 += compressInfo_[i].nrClusterVars;
c3 += compressInfo_[i].nrGroundFactors - compressInfo_[i].nrNeighborless;
c4 += compressInfo_[i].nrClusterFactors;
if (compressInfo_[i].nrNeighborless != 0) {
c2 --;
c4 --;
}
}
ss3 << endl << endl;
}
ss4 << "primary networks: " << primaryNetCount_ << endl;
ss4 << "solved networks: " << netInfo_.size() << endl;
ss4 << "loopy networks: " << nLoopyNets << endl;
ss4 << "unconverged runs: " << nUnconvergedRuns << endl;
ss4 << "total solving time: " << totalSolvingTime << endl;
if (compressInfo_.size() > 0) {
double pc1 = (1.0 - (c2 / (double)c1)) * 100.0;
double pc2 = (1.0 - (c4 / (double)c3)) * 100.0;
ss4 << setprecision (5);
ss4 << "variable compression: " << pc1 << "%" << endl;
ss4 << "factor compression: " << pc2 << "%" << endl;
}
ss4 << endl << endl;
ss1 << ss4.str() << ss2.str() << ss3.str();
return ss1.str();
}

View File

@ -1,376 +0,0 @@
#ifndef HORUS_UTIL_H
#define HORUS_UTIL_H
#include <cmath>
#include <cassert>
#include <limits>
#include <algorithm>
#include <vector>
#include <set>
#include <queue>
#include <unordered_map>
#include <sstream>
#include <iostream>
#include "Horus.h"
using namespace std;
namespace Util {
template <typename T> void addToVector (vector<T>&, const vector<T>&);
template <typename T> void addToSet (set<T>&, const vector<T>&);
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
template <typename T> bool contains (const vector<T>&, const T&);
template <typename T> bool contains (const set<T>&, const T&);
template <typename K, typename V> bool contains (
const unordered_map<K, V>&, const K&);
template <typename T> std::string toString (const T&);
void toLog (Params&);
void fromLog (Params&);
double logSum (double, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void add (Params&, const Params&, unsigned);
double factorial (double);
unsigned nrCombinations (unsigned, unsigned);
unsigned expectedSize (const Ranges&);
unsigned getNumberOfDigits (int);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION);
vector<string> getStateLines (const Vars&);
void printHeader (string, std::ostream& os = std::cout);
void printSubHeader (string, std::ostream& os = std::cout);
void printAsteriskLine (std::ostream& os = std::cout);
void printDashedLine (std::ostream& os = std::cout);
unsigned maxUnsigned (void);
};
template <typename T> void
Util::addToVector (vector<T>& v, const vector<T>& elements)
{
v.insert (v.end(), elements.begin(), elements.end());
}
template <typename T> void
Util::addToSet (set<T>& s, const vector<T>& elements)
{
s.insert (elements.begin(), elements.end());
}
template <typename T> void
Util::addToQueue (queue<T>& q, const vector<T>& elements)
{
for (unsigned i = 0; i < elements.size(); i++) {
q.push (elements[i]);
}
}
template <typename T> bool
Util::contains (const vector<T>& v, const T& e)
{
return std::find (v.begin(), v.end(), e) != v.end();
}
template <typename T> bool
Util::contains (const set<T>& s, const T& e)
{
return s.find (e) != s.end();
}
template <typename K, typename V> bool
Util::contains (const unordered_map<K, V>& m, const K& k)
{
return m.find (k) != m.end();
}
template <typename T> std::string
Util::toString (const T& t)
{
std::stringstream ss;
ss << t;
return ss.str();
}
template <typename T>
std::ostream& operator << (std::ostream& os, const vector<T>& v)
{
os << "[" ;
for (unsigned i = 0; i < v.size(); i++) {
os << ((i != 0) ? ", " : "") << v[i];
}
os << "]" ;
return os;
}
namespace {
const double INF = -numeric_limits<double>::infinity();
};
inline double
Util::logSum (double x, double y)
{
return log (exp (x) + exp (y));
assert (isfinite (x) && isfinite (y));
// If one value is much smaller than the other, keep the larger value.
if (x < (y - log (1e200))) {
return y;
}
if (y < (x - log (1e200))) {
return x;
}
double diff = x - y;
assert (isfinite (diff) && isfinite (x) && isfinite (y));
if (!isfinite (exp (diff))) {
// difference is too large
return x > y ? x : y;
}
// otherwise return the sum.
return y + log (static_cast<double>(1.0) + exp (diff));
}
inline void
Util::multiply (Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
for (unsigned i = 0; i < v1.size(); i++) {
v1[i] *= v2[i];
}
}
inline void
Util::multiply (Params& v1, const Params& v2, unsigned repetitions)
{
for (unsigned count = 0; count < v1.size(); ) {
for (unsigned i = 0; i < v2.size(); i++) {
for (unsigned r = 0; r < repetitions; r++) {
v1[count] *= v2[i];
count ++;
}
}
}
}
inline void
Util::add (Params& v1, const Params& v2)
{
assert (v1.size() == v2.size());
for (unsigned i = 0; i < v1.size(); i++) {
v1[i] += v2[i];
}
}
inline void
Util::add (Params& v1, const Params& v2, unsigned repetitions)
{
for (unsigned count = 0; count < v1.size(); ) {
for (unsigned i = 0; i < v2.size(); i++) {
for (unsigned r = 0; r < repetitions; r++) {
v1[count] += v2[i];
count ++;
}
}
}
}
inline unsigned
Util::maxUnsigned (void)
{
return numeric_limits<unsigned>::max();
}
namespace LogAware {
inline double
one()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
zero() {
return Globals::logDomain ? INF : 0.0 ;
}
inline double
addIdenty()
{
return Globals::logDomain ? INF : 0.0;
}
inline double
multIdenty()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
withEvidence()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
noEvidence() {
return Globals::logDomain ? INF : 0.0;
}
inline double
tl (double v)
{
return Globals::logDomain ? log (v) : v;
}
inline double
fl (double v)
{
return Globals::logDomain ? exp (v) : v;
}
void normalize (Params&);
double getL1Distance (const Params&, const Params&);
double getMaxNorm (const Params&, const Params&);
double pow (double, unsigned);
double pow (double, double);
void pow (Params&, unsigned);
void pow (Params&, double);
};
struct NetInfo
{
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
{
this->size = size;
this->loopy = loopy;
this->nIters = nIters;
this->time = time;
}
unsigned size;
bool loopy;
unsigned nIters;
double time;
};
struct CompressInfo
{
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
{
nrGroundVars = a;
nrGroundFactors = b;
nrClusterVars = c;
nrClusterFactors = d;
nrNeighborless = e;
}
unsigned nrGroundVars;
unsigned nrGroundFactors;
unsigned nrClusterVars;
unsigned nrClusterFactors;
unsigned nrNeighborless;
};
class Statistics
{
public:
static unsigned getSolvedNetworksCounting (void);
static void incrementPrimaryNetworksCounting (void);
static unsigned getPrimaryNetworksCounting (void);
static void updateStatistics (unsigned, bool, unsigned, double);
static void printStatistics (void);
static void writeStatistics (const char*);
static void updateCompressingStatistics (
unsigned, unsigned, unsigned, unsigned, unsigned);
private:
static string getStatisticString (void);
static vector<NetInfo> netInfo_;
static vector<CompressInfo> compressInfo_;
static unsigned primaryNetCount_;
};
#endif // HORUS_UTIL_H

View File

@ -1,35 +0,0 @@
if [ $1 ] && [ $1 == "clear" ]; then
rm *~
rm -f school/*.log school/*~
rm -f city/*.log city/*~
rm -f workshop_attrs/*.log workshop_attrs/*~
fi
function run_solver
{
constraint=$1
solver_flag=true
if [ -n "$2" ]; then
if [ $SOLVER = hve ]; then
extra_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
elif [ $SOLVER = bp ]; then
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
elif [ $SOLVER = cbp ]; then
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
else
echo "unknow flag $2"
fi
fi
/usr/bin/time -o $LOG_FILE -a -f "real:%E\tuser:%U\tsys:%S" \
$YAP << EOF >> $LOG_FILE 2>> ignore.$LOG_FILE
[$NETWORK].
[$constraint].
clpbn_horus:set_solver($SOLVER).
clpbn_horus:set_horus_flag(use_logarithms, true).
$solver_flag.
$QUERY.
open("$LOG_FILE", 'append', S), format(S, '$constraint: ~15+ ', []), close(S).
EOF
}

View File

@ -1,17 +0,0 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="bp"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -1,17 +0,0 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="cbp"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -1,25 +0,0 @@
#!/bin/bash
NETWORK="'../../examples/city'"
SHORTNAME="city"
QUERY="is_joe_guilty(X)"
function run_all_graphs
{
cp ~/bin/yap $YAP
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
echo "results for solver $1" >> $LOG_FILE
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
run_solver city_5 $2
#run_solver city_1000 $2
#run_solver city_5000 $2
#run_solver city_10000 $2
#run_solver city_50000 $2
#run_solver city_100000 $2
#run_solver city_500000 $2
#run_solver city_1000000 $2
}

View File

@ -1,17 +0,0 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="fove"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILEE
run_all_graphs "fove "

View File

@ -1,17 +0,0 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="hve"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors

View File

@ -1,11 +0,0 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="bp"
YAP=~/bin/$SHORTNAME-$SOLVER
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -1,11 +0,0 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="cbp"
YAP=~/bin/$SHORTNAME-$SOLVER
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -1,12 +0,0 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="fove"
YAP=~/bin/$SHORTNAME-$SOLVER
run_all_graphs "fove "

View File

@ -1,11 +0,0 @@
#!/bin/bash
source wa.sh
source ../benchs.sh
SOLVER="hve"
YAP=~/bin/$SHORTNAME-$SOLVER
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors

View File

@ -1,33 +0,0 @@
#!/bin/bash
NETWORK="'../../examples/workshop_attrs'"
SHORTNAME="wa"
QUERY="series(X)"
function run_all_graphs
{
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
cp ~/bin/yap $YAP
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
echo "results for solver $1" >> $LOG_FILE
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
run_solver pop_10 $2
#run_solver pop_1000 $2
#run_solver pop_5000 $2
#run_solver pop_10000 $2
#run_solver pop_50000 $2
#run_solver pop_100000 $2
#run_solver pop_500000 $2
#run_solver pop_1000000 $2
}

View File

@ -1,41 +0,0 @@
:- use_module(library(pfl)).
%:- set_pfl_flag(solver,ve).
:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,ve).
%:- set_pfl_flag(solver,bp), clpbn_horus:set_horus_flag(inf_alg,bp).
%:- set_pfl_flag(solver,fove).
% :- yap_flag(write_strings, off).
bayes burglary::[b1,b3] ; [0.001, 0.999] ; [].
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
bayes alarm::[a1,a2] , burglary, earthquake ; [0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ; [].
bayes john_calls::[j1,j2] , alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
bayes mary_calls::[m1,m2] , alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
b_cpt([0.001, 0.999]).
e_cpt([0.002, 0.998]).
a_cpt([0.95, 0.94, 0.29, 0.001,
0.05, 0.06, 0.71, 0.999]).
jc_cpt([0.9, 0.05,
0.1, 0.95]).
mc_cpt([0.7, 0.01,
0.3, 0.99]).
% ?- alarm(A).
?- john_calls(J), mary_calls(m1).
%?- john_calls(J), mary_calls(m1), alarm(a1).
%?- john_calls(J), alarm(a1).

View File

@ -1,15 +0,0 @@
# example in counting belief propagation paper
MARKOV
3
2 2 2
2
2 0 1
2 2 1
4
1.2 1.4 2.0 0.4
4
1.2 1.4 2.0 0.4

View File

@ -1,31 +0,0 @@
:- use_module(library(pfl)).
:- clpbn_horus:set_solver(fove).
%:- clpbn_horus:set_solver(hve).
%:- clpbn_horus:set_solver(bp).
%:- clpbn_horus:set_solver(cbp).
:- yap_flag(write_strings, off).
c(p1,w1).
c(p1,w2).
c(p1,w3).
c(p2,w1).
c(p2,w2).
c(p2,w3).
c(p3,w1).
c(p3,w2).
c(p3,w3).
c(p4,w1).
c(p4,w2).
c(p4,w3).
c(p5,w1).
c(p5,w2).
c(p5,w3).
markov attends(P)::[t,f] , hot(W)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [c(P,W)].
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.8] ; [c(P,_)].
% ?- series(X).

View File

@ -1,24 +0,0 @@
:- use_module(library(pfl)).
:- clpbn_horus:set_solver(fove).
%:- clpbn_horus:set_solver(hve).
%:- clpbn_horus:set_solver(bp).
%:- clpbn_horus:set_solver(cbp).
:- yap_flag(write_strings, off).
friends(P1, P2) :-
people(P1),
people(P2),
P1 \= P2.
people @ 3.
markov smokes(P)::[t,f], cancer(P)::[t,f] ; [0.1, 0.2, 0.3, 0.4] ; [people(P)].
markov friend(P1,P2)::[t,f], smokes(P1)::[t,f], smokes(P2)::[t,f] ;
[0.5, 0.6, 0.7, 0.8, 0.5, 0.6, 0.7, 0.8] ; [friends(P1, P2)].
% ?- smokes(p1, t), smokes(p2, f), friend(p1, p2, X).

View File

@ -1,27 +0,0 @@
:- use_module(library(pfl)).
%:- clpbn_horus:set_solver(fove).
%:- clpbn_horus:set_solver(hve).
:- clpbn_horus:set_solver(bp).
%:- clpbn_horus:set_solver(cbp).
:- yap_flag(write_strings, off).
people @ 3.
markov attends(P)::[t,f], attr1::[t,f] ; [0.11, 0.2, 0.3, 0.4] ; [people(P)].
markov attends(P)::[t,f], attr2::[t,f] ; [0.1, 0.22, 0.3, 0.4] ; [people(P)].
markov attends(P)::[t,f], attr3::[t,f] ; [0.1, 0.2, 0.33, 0.4] ; [people(P)].
markov attends(P)::[t,f], attr4::[t,f] ; [0.1, 0.2, 0.3, 0.44] ; [people(P)].
markov attends(P)::[t,f], attr5::[t,f] ; [0.1, 0.2, 0.3, 0.45] ; [people(P)].
markov attends(P)::[t,f], attr6::[t,f] ; [0.1, 0.2, 0.3, 0.46] ; [people(P)].
markov attends(P)::[t,f], series::[t,f] ; [0.5, 0.6, 0.7, 0.87] ; [people(P)].
% ?- series(X).

View File

@ -41,20 +41,30 @@ do_network([], _, _, _) :- !.
do_network(QueryVars, EVars, Keys, Factors) :- do_network(QueryVars, EVars, Keys, Factors) :-
retractall(currently_defined(_)), retractall(currently_defined(_)),
retractall(f(_,_,_,_)), retractall(f(_,_,_,_)),
writeln(keys:Keys),
run_through_factors(QueryVars), run_through_factors(QueryVars),
run_through_factors(EVars), run_through_factors(EVars),
findall(K, currently_defined(K), Keys), findall(K, currently_defined(K), Keys),
writeln(keys2:Keys), ground_all_keys(QueryVars, Keys),
ground_all_keys(EVars, Keys),
findall(f(FType,FId,FKeys,FCPT), f(FType,FId,FKeys,FCPT), Factors). findall(f(FType,FId,FKeys,FCPT), f(FType,FId,FKeys,FCPT), Factors).
match([], _Keys). run_through_factors([]).
match([V|GVars], Keys) :- run_through_factors([Var|_QueryVars]) :-
clpbn:get_atts(V,[key(GKey)]), !, clpbn:get_atts(Var,[key(K)]),
member(GKey, Keys), ground(GKey), find_factors(K),
match(GVars, Keys). fail.
match([_V|GVars], Keys) :- run_through_factors([_|QueryVars]) :-
match(GVars, Keys). run_through_factors(QueryVars).
ground_all_keys([], _).
ground_all_keys([V|GVars], AllKeys) :-
clpbn:get_atts(V,[key(Key)]),
\+ ground(Key), !,
member(Key, AllKeys),
ground_all_keys(GVars, AllKeys).
ground_all_keys([_V|GVars], AllKeys) :-
ground_all_keys(GVars, AllKeys).
% %
@ -99,6 +109,7 @@ keys([Var|QueryVars], [Key|QueryKeys]) :-
initialize_evidence([]). initialize_evidence([]).
initialize_evidence([V|EVars]) :- initialize_evidence([V|EVars]) :-
clpbn:get_atts(V, [key(K)]), clpbn:get_atts(V, [key(K)]),
ground(K),
assert(currently_defined(K)), assert(currently_defined(K)),
initialize_evidence(EVars). initialize_evidence(EVars).
@ -106,7 +117,7 @@ initialize_evidence([V|EVars]) :-
% gets key K, and collects factors that define it % gets key K, and collects factors that define it
find_factors(K) :- find_factors(K) :-
\+ currently_defined(K), \+ currently_defined(K),
assert(currently_defined(K)), ( ground(K) -> assert(currently_defined(K)) ; true),
defined_in_factor(K, ParFactor), defined_in_factor(K, ParFactor),
add_factor(ParFactor, Ks), add_factor(ParFactor, Ks),
member(K1, Ks), member(K1, Ks),

View File

@ -1,61 +1,63 @@
/******************************************************* /*******************************************************
Interface with C++ Horus Interface
********************************************************/ ********************************************************/
:- module(clpbn_horus, :- module(clpbn_horus,
[set_solver/1, [set_solver/1,
create_lifted_network/3, set_horus_flag/1,
create_ground_network/4, cpp_create_lifted_network/3,
set_parfactors_params/2, cpp_create_ground_network/4,
set_factors_params/2, cpp_set_parfactors_params/2,
run_lifted_solver/3, cpp_set_factors_params/2,
run_ground_solver/3, cpp_run_lifted_solver/3,
set_vars_information/2, cpp_run_ground_solver/3,
set_horus_flag/2, cpp_set_vars_information/2,
free_parfactors/1, cpp_set_horus_flag/2,
free_ground_network/1 cpp_free_parfactors/1,
cpp_free_ground_network/1
]). ]).
:- use_module(library(clpbn),
[set_clpbn_flag/2]).
patch_things_up :- patch_things_up :-
assert_static(clpbn_horus:set_horus_flag(_,_)). assert_static(clpbn_horus:cpp_set_horus_flag(_,_)).
warning :- warning :-
format(user_error,"Horus library not installed: cannot use bp, fove~n.",[]). format(user_error,"Horus library not installed: cannot use bp, fove~n.",[]).
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up) -> true ; warning.
:- catch(load_foreign_files([horus], [], init_predicates), _, patch_things_up)
-> true ; warning.
set_solver(ve) :- pfl:set_pfl_flag(solver,ve). set_solver(ve) :- set_clpbn_flag(solver,ve).
set_solver(jt) :- pfl:set_pfl_flag(solver,jt). set_solver(jt) :- set_clpbn_flag(solver,jt).
set_solver(gibbs) :- pfl:set_pfl_flag(solver,gibbs). set_solver(gibbs) :- set_clpbn_flag(solver,gibbs).
set_solver(fove) :- pfl:set_pfl_flag(solver,fove). set_solver(fove) :- set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, fove).
set_solver(hve) :- pfl:set_pfl_flag(solver,bp), set_horus_flag(inf_alg, ve). set_solver(lbp) :- set_clpbn_flag(solver,fove), set_horus_flag(lifted_solver, lbp).
set_solver(bp) :- pfl:set_pfl_flag(solver,bp), set_horus_flag(inf_alg, bp). set_solver(hve) :- set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, ve).
set_solver(cbp) :- pfl:set_pfl_flag(solver,bp), set_horus_flag(inf_alg, cbp). set_solver(bp) :- set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, bp).
set_solver(cbp) :- set_clpbn_flag(solver,bp), set_horus_flag(ground_solver, cbp).
set_solver(S) :- throw(error('unknow solver ', S)). set_solver(S) :- throw(error('unknow solver ', S)).
%:- set_horus_flag(inf_alg, ve). set_horus_flag(K,V) :- cpp_set_horus_flag(K,V).
%:- set_horus_flag(inf_alg, bp).
%: -set_horus_flag(inf_alg, cbp).
:- set_horus_flag(schedule, seq_fixed).
%:- set_horus_flag(schedule, seq_random).
%:- set_horus_flag(schedule, parallel).
%:- set_horus_flag(schedule, max_residual).
:- set_horus_flag(accuracy, 0.0001). :- cpp_set_horus_flag(schedule, seq_fixed).
%:- cpp_set_horus_flag(schedule, seq_random).
%:- cpp_set_horus_flag(schedule, parallel).
%:- cpp_set_horus_flag(schedule, max_residual).
:- set_horus_flag(max_iter, 1000). :- cpp_set_horus_flag(accuracy, 0.0001).
:- set_horus_flag(order_factor_variables, false). :- cpp_set_horus_flag(max_iter, 1000).
%:- set_horus_flag(order_factor_variables, true).
:- set_horus_flag(use_logarithms, false). :- cpp_set_horus_flag(use_logarithms, false).
% :- set_horus_flag(use_logarithms, true). % :- cpp_set_horus_flag(use_logarithms, true).

View File

@ -1,19 +1,27 @@
/******************************************************* /*******************************************************
Belief Propagation and Variable Elimination Interface Interface to Horus Ground Solvers. Used by:
- Variable Elimination
- Belief Propagation
- Counting Belief Propagation
********************************************************/ ********************************************************/
:- module(clpbn_bp, :- module(clpbn_horus_ground,
[bp/3, [call_horus_ground_solver/6,
check_if_bp_done/1, check_if_horus_ground_solver_done/1,
init_bp_solver/4, init_horus_ground_solver/4,
run_bp_solver/3, run_horus_ground_solver/3,
call_bp_ground/6, finalize_horus_ground_solver/1
finalize_bp_solver/1
]). ]).
:- use_module(horus,
[cpp_create_ground_network/4,
cpp_set_factors_params/2,
cpp_run_ground_solver/3,
cpp_set_vars_information/2,
cpp_free_ground_network/1
]).
:- use_module(library('clpbn/dists'), :- use_module(library('clpbn/dists'),
[dist/4, [dist/4,
@ -22,25 +30,20 @@
get_dist_params/2 get_dist_params/2
]). ]).
:- use_module(library('clpbn/display'), :- use_module(library('clpbn/display'),
[clpbn_bind_vals/3]). [clpbn_bind_vals/3]).
:- use_module(library('clpbn/aggregates'), :- use_module(library('clpbn/aggregates'),
[check_for_agg_vars/2]). [check_for_agg_vars/2]).
:- use_module(library(charsio), :- use_module(library(charsio),
[term_to_atom/2]). [term_to_atom/2]).
:- use_module(library(pfl), :- use_module(library(pfl),
[skolem/2, [skolem/2,
get_pfl_parameters/2 get_pfl_parameters/2
]). ]).
:- use_module(library(lists)). :- use_module(library(lists)).
:- use_module(library(atts)). :- use_module(library(atts)).
@ -48,45 +51,36 @@
:- use_module(library(bhash)). :- use_module(library(bhash)).
:- use_module(horus, call_horus_ground_solver(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
[create_ground_network/4,
set_factors_params/2,
run_ground_solver/3,
set_vars_information/2,
free_ground_network/1
]).
call_bp_ground(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
writeln(here:Factors),
b_hash_new(Hash0), b_hash_new(Hash0),
keys_to_ids(AllKeys, 0, Hash0, Hash), keys_to_ids(AllKeys, 0, Hash0, Hash),
get_factors_type(Factors, Type), get_factors_type(Factors, Type),
evidence_to_ids(Evidence, Hash, EvidenceIds), evidence_to_ids(Evidence, Hash, EvidenceIds),
factors_to_ids(Factors, Hash, FactorIds), factors_to_ids(Factors, Hash, FactorIds),
writeln(type:Type), writeln(''), %writeln(type:Type), writeln(''),
writeln(allKeys:AllKeys), writeln(''), %writeln(allKeys:AllKeys), writeln(''),
writeln(factors:Factors), writeln(''), %sort(AllKeys,SKeys),writeln(allKeys:SKeys), writeln(''),
writeln(factorIds:FactorIds), writeln(''), %writeln(factors:Factors), writeln(''),
writeln(evidence:Evidence), writeln(''), %writeln(factorIds:FactorIds), writeln(''),
writeln(evidenceIds:EvidenceIds), writeln(''), %writeln(evidence:Evidence), writeln(''),
create_ground_network(Type, FactorIds, EvidenceIds, Network), %writeln(evidenceIds:EvidenceIds), writeln(''),
cpp_create_ground_network(Type, FactorIds, EvidenceIds, Network),
%get_vars_information(AllKeys, StatesNames), %get_vars_information(AllKeys, StatesNames),
%set_vars_information(AllKeys, StatesNames), %terms_to_atoms(AllKeys, KeysAtoms),
%cpp_set_vars_information(KeysAtoms, StatesNames),
run_solver(ground(Network,Hash), QueryKeys, Solutions), run_solver(ground(Network,Hash), QueryKeys, Solutions),
writeln(answer:Solutions),
clpbn_bind_vals([QueryVars], Solutions, Output), clpbn_bind_vals([QueryVars], Solutions, Output),
free_ground_network(Network). cpp_free_ground_network(Network).
run_solver(ground(Network,Hash), QueryKeys, Solutions) :- run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
%get_dists_parameters(DistIds, DistsParams), %get_dists_parameters(DistIds, DistsParams),
%set_factors_params(Network, DistsParams), %cpp_set_factors_params(Network, DistsParams),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds), list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
writeln(queryKeys:QueryKeys), writeln(''), %writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''), %writeln(queryIds:QueryIds), writeln(''),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds), list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
run_ground_solver(Network, [QueryIds], Solutions). cpp_run_ground_solver(Network, [QueryIds], Solutions).
keys_to_ids([], _, Hash, Hash). keys_to_ids([], _, Hash, Hash).
@ -132,31 +126,40 @@ get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
get_vars_information(QueryKeys, StatesNames). get_vars_information(QueryKeys, StatesNames).
finalize_bp_solver(bp(Network, _)) :- terms_to_atoms([], []).
free_ground_network(Network). terms_to_atoms(K.Ks, Atom.As) :-
term_to_atom(K,Atom),
terms_to_atoms(Ks,As).
bp([[]],_,_) :- !. finalize_horus_ground_solver(bp(Network, _)) :-
bp([QueryVars], AllVars, Output) :- cpp_free_ground_network(Network).
init_bp_solver(_, AllVars, _, Network),
run_bp_solver([QueryVars], LPs, Network),
finalize_bp_solver(Network),
clpbn_bind_vals([QueryVars], LPs, Output).
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :- init_horus_ground_solver(_, _AllVars0, _, bp(_BayesNet, _DistIds)) :- !.
%check_for_agg_vars(AllVars0, AllVars),
get_vars_info(AllVars0, VarsInfo, DistIds0),
sort(DistIds0, DistIds),
create_ground_network(VarsInfo, BayesNet),
true.
run_horus_ground_solver(_QueryVars, _Solutions, bp(_Network, _DistIds)) :- !.
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :- %bp([[]],_,_) :- !.
get_dists_parameters(DistIds, DistsParams), %bp([QueryVars], AllVars, Output) :-
set_factors_params(Network, DistsParams), % init_horus_ground_solver(_, AllVars, _, Network),
vars_to_ids(QueryVars, QueryVarsIds), % run_horus_ground_solver([QueryVars], LPs, Network),
run_ground_solver(Network, QueryVarsIds, Solutions). % finalize_horus_ground_solver(Network),
% clpbn_bind_vals([QueryVars], LPs, Output).
%
%init_horus_ground_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
% %check_for_agg_vars(AllVars0, AllVars),
% get_vars_info(AllVars0, VarsInfo, DistIds0),
% sort(DistIds0, DistIds),
% cpp_create_ground_network(VarsInfo, BayesNet),
% true.
%
%
%run_horus_ground_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
% get_dists_parameters(DistIds, DistsParams),
% cpp_set_factors_params(Network, DistsParams),
% vars_to_ids(QueryVars, QueryVarsIds),
% cpp_run_ground_solver(Network, QueryVarsIds, Solutions).
get_dists_parameters([],[]). get_dists_parameters([],[]).

View File

@ -1,27 +1,31 @@
/******************************************************* /*******************************************************
First Order Variable Elimination Interface Interface to Horus Lifted Solvers. Used by:
- Lifted Variable Elimination
********************************************************/ ********************************************************/
:- module(clpbn_fove, :- module(clpbn_horus_lifted,
[fove/3, [call_horus_lifted_solver/3,
check_if_fove_done/1, check_if_horus_lifted_solver_done/1,
init_fove_solver/4, init_horus_lifted_solver/4,
run_fove_solver/3, run_horus_lifted_solver/3,
finalize_fove_solver/1 finalize_horus_lifted_solver/1
]). ]).
:- use_module(horus,
[cpp_create_lifted_network/3,
cpp_set_parfactors_params/2,
cpp_run_lifted_solver/3,
cpp_free_parfactors/1
]).
:- use_module(library('clpbn/display'), :- use_module(library('clpbn/display'),
[clpbn_bind_vals/3]). [clpbn_bind_vals/3]).
:- use_module(library('clpbn/dists'), :- use_module(library('clpbn/dists'),
[get_dist_params/2]). [get_dist_params/2]).
:- use_module(library(pfl), :- use_module(library(pfl),
[factor/6, [factor/6,
skolem/2, skolem/2,
@ -29,30 +33,22 @@
]). ]).
:- use_module(horus, call_horus_lifted_solver([[]], _, _) :- !.
[create_lifted_network/3, call_horus_lifted_solver([QueryVars], AllVars, Output) :-
set_parfactors_params/2, init_horus_lifted_solver(_, AllVars, _, ParfactorList),
run_lifted_solver/3, run_horus_lifted_solver([QueryVars], LPs, ParfactorList),
free_parfactors/1 finalize_horus_lifted_solver(ParfactorList),
]).
fove([[]], _, _) :- !.
fove([QueryVars], AllVars, Output) :-
init_fove_solver(_, AllVars, _, ParfactorList),
run_fove_solver([QueryVars], LPs, ParfactorList),
finalize_fove_solver(ParfactorList),
clpbn_bind_vals([QueryVars], LPs, Output). clpbn_bind_vals([QueryVars], LPs, Output).
init_fove_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :- init_horus_lifted_solver(_, AllAttVars, _, fove(ParfactorList, DistIds)) :-
get_parfactors(Parfactors), get_parfactors(Parfactors),
get_dist_ids(Parfactors, DistIds0), get_dist_ids(Parfactors, DistIds0),
sort(DistIds0, DistIds), sort(DistIds0, DistIds),
get_observed_vars(AllAttVars, ObservedVars), get_observed_vars(AllAttVars, ObservedVars),
writeln(parfactors:Parfactors:'\n'), %writeln(parfactors:Parfactors:'\n'),
writeln(evidence:ObservedVars:'\n'), %writeln(evidence:ObservedVars:'\n'),
create_lifted_network(Parfactors,ObservedVars,ParfactorList). cpp_create_lifted_network(Parfactors,ObservedVars,ParfactorList).
:- table get_parfactors/1. :- table get_parfactors/1.
@ -138,15 +134,15 @@ get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
get_dists_parameters(Ids, DistsInfo). get_dists_parameters(Ids, DistsInfo).
run_fove_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :- run_horus_lifted_solver(QueryVarsAtts, Solutions, fove(ParfactorList, DistIds)) :-
get_query_vars(QueryVarsAtts, QueryVars), get_query_vars(QueryVarsAtts, QueryVars),
writeln(queryVars:QueryVars), writeln(''), %writeln(queryVars:QueryVars), writeln(''),
get_dists_parameters(DistIds, DistsParams), get_dists_parameters(DistIds, DistsParams),
writeln(dists:DistsParams), writeln(''), %writeln(dists:DistsParams), writeln(''),
set_parfactors_params(ParfactorList, DistsParams), cpp_set_parfactors_params(ParfactorList, DistsParams),
run_lifted_solver(ParfactorList, QueryVars, Solutions). cpp_run_lifted_solver(ParfactorList, QueryVars, Solutions).
finalize_fove_solver(fove(ParfactorList, _)) :- finalize_horus_lifted_solver(fove(ParfactorList, _)) :-
free_parfactors(ParfactorList). cpp_free_parfactors(ParfactorList).

View File

@ -1,25 +0,0 @@
%conservative_city(nyc, t).
hair_color(joe, t).
car_color(joe, t).
shoe_size(joe, f).
/* Steps:
1. generate N facts lives(I, nyc), 0 <= I < N.
2. generate evidence on descn for N people, *** except for 1 ***
3. Run query ?- guilty(joe, Guilty), witness(joe, t), descn(2,t), descn(3, f), descn(4, f).
query(Guilty) :-
guilty(joe, Guilty),
witness(joe, t),
descn(2,t),
descn(3,f),
descn(4,f), ....
*/

View File

@ -1,60 +0,0 @@
/* base file for school database. Supposed to be called from school_*.yap */
conservative_city(City, Cons) :-
cons_table(City, ConsDist),
{ Cons = cons(City) with p([y,n], ConsDist) }.
gender(X, Gender) :-
gender_table(City, GenderDist),
{ Gender = gender(City) with p([m,f], GenderDist) }.
hair_color(X, Color) :-
lives(X, City),
conservative_city(City, Cons),
gender(X, Gender),
color_table(X,ColorTable),
{ Color = color(X) with
p([t,f], ColorTable,[Gender,Cons]) }.
car_color(X, Color) :-
hair_color(City, HColor),
ccolor_table(X,CColorTable),
{ Color = ccolor(X) with
p([t,f], CColorTable,[HColor]) }.
height(X, Height) :-
gender(X, Gender),
height_table(X,HeightTable),
{ Height = height(X) with
p([t,f], HeightTable,[Gender]) }.
shoe_size(X, Shoesize) :-
height(X, Height),
shoe_size_table(X,ShoesizeTable),
{ Shoesize = shoe_size(X) with
p([t,f], ShoesizeTable,[Height]) }.
guilty(X, Guilt) :-
guilt_table(X, GuiltDist),
{ Guilt = guilt(X) with p([y,n], GuiltDist) }.
descn(X, Descn) :-
car_color(X, Car),
hair_color(X, Hair),
height(X, Height),
guilty(X, Guilt),
descn_table(X, DescTable),
{ Descn = descn(X) with
p([t,f], DescTable,[Car,Hair,Height,Guilt]) }.
witness(City, Witness) :-
descn(joe, DescnJ),
descn(1, Descn1),
wit_table(WitTable),
{ Witness = wit(City) with
p([t,f], WitTable,[DescnJ, Descn1]) }.

View File

@ -1,35 +0,0 @@
cons_table(amsterdam,[0.2,
0.8]) :- !.
cons_table(_, [0.8,
0.2]).
color_table(_,
/* tm tf fm ff */
[ 0.05, 0.1, 0.3, 0.5 ,
0.95, 0.9, 0.7, 0.5 ]).
ccolor_table(_,
/* t f */
[ 0.9, 0.2 ,
0.1, 0.8 ]).
height_table(_,
/* m f */
[ 0.6, 0.4 ,
0.4, 0.6 ]).
shoe_size_table(_,
/* t f */
[ 0.9, 0.1 ,
0.1, 0.9 ]).
A: professor's ability;
B: student's grade (for course registration).
*/
descn_table(_,
/* color, hair, height, guilt */
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
/*t*/ [0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99,
/*f*/ 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01 ]).

View File

@ -0,0 +1,23 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).
:- yap_flag(write_strings, off).
bayes burglary::[b1,b2] ; [0.001, 0.999] ; [].
bayes earthquake::[e1,e2] ; [0.002, 0.998]; [].
bayes alarm::[a1,a2], burglary, earthquake ;
[0.95, 0.94, 0.29, 0.001, 0.05, 0.06, 0.71, 0.999] ;
[].
bayes john_calls::[j1,j2], alarm ; [0.9, 0.05, 0.1, 0.95] ; [].
bayes mary_calls::[m1,m2], alarm ; [0.7, 0.01, 0.3, 0.99] ; [].
% ?- john_calls(J), mary_calls(m1).

View File

@ -1,20 +1,23 @@
:- use_module(library(pfl)). :- use_module(library(pfl)).
:- clpbn_horus:set_solver(fove). %:- set_solver(fove).
%:- clpbn_horus:set_solver(hve). %:- set_solver(hve).
:- clpbn_horus:set_solver(bp). %:- set_solver(bp).
%:- clpbn_horus:set_solver(cbp). %:- set_solver(cbp).
:- multifile people/2.
:- multifile ev/1.
people(joe,nyc). people(joe,nyc).
people(p2, nyc). people(p2, nyc).
people(p3, nyc). people(p3, nyc).
people(p4, nyc).
people(p5, nyc).
ev(descn(p2, t)). ev(descn(p2, t)).
ev(descn(p3, t)). ev(descn(p3, t)).
ev(descn(p4, t)).
% :- [city_7]. ev(descn(p5, t)).
bayes city_conservativeness(C)::[y,n] ; cons_table(C) ; [people(_,C)]. bayes city_conservativeness(C)::[y,n] ; cons_table(C) ; [people(_,C)].
@ -34,12 +37,12 @@ bayes descn(P)::[t,f], car_color(P), hair_color(P), height(P), guilty(P) ; descn
bayes witness(C)::[t,f], descn(Joe), descn(P2) ; wit_table ; [people(_,C), Joe=joe, P2=p2]. bayes witness(C)::[t,f], descn(Joe), descn(P2) ; wit_table ; [people(_,C), Joe=joe, P2=p2].
% FIXME
cons_table(amsterdam, [0.2, 0.8]) :- !. %cons_table(amsterdam, [0.2, 0.8]) :- !.
cons_table(_, [0.8, 0.2]). cons_table(_, [0.8, 0.2]).
gender_table(_, [0.55, 0.44]). gender_table(_, [0.55, 0.45]).
hair_color_table(_, hair_color_table(_,
@ -73,8 +76,8 @@ guilty_table(_, [0.23, 0.77]).
descn_table(_, descn_table(_,
/* color, hair, height, guilt */ /* color, hair, height, guilt */
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */ /* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92, [ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.23, 0.91, 0.95, 0.92,
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]). 0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.71, 0.28, 0.77, 0.09, 0.05, 0.08]).
wit_table([0.2, 0.45, 0.24, 0.34, wit_table([0.2, 0.45, 0.24, 0.34,

View File

@ -0,0 +1,33 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).
:- yap_flag(write_strings, off).
:- multifile c/2.
c(p1,w1).
c(p1,w2).
c(p1,w3).
c(p2,w1).
c(p2,w2).
c(p2,w3).
c(p3,w1).
c(p3,w2).
c(p3,w3).
c(p4,w1).
c(p4,w2).
c(p4,w3).
c(p5,w1).
c(p5,w2).
c(p5,w3).
markov attends(P)::[t,f], hot(W)::[t,f] ; [0.2, 0.8, 0.8, 0.8] ; [c(P,W)].
markov attends(P)::[t,f], series::[t,f] ; [0.501, 0.499, 0.499, 0.499] ; [c(P,_)].
% ?- series(X).

View File

@ -0,0 +1,21 @@
:- use_module(library(pfl)).
:- set_solver(fove).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).
:- yap_flag(write_strings, off).
:- clpbn_horus:set_horus_flag(verbosity,5).
people(p1,p1).
people(p1,p2).
people(p2,p1).
people(p2,p2).
markov p(A,A)::[t,f] ; [1.0,4.5] ; [people(A,_)].
markov p(A,B)::[t,f] ; [1.0,4.5] ; [people(A,B)].
?- p(p1,p1,X).

View File

@ -0,0 +1,31 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).
:- yap_flag(write_strings, off).
:- multifile people/1.
people @ 5.
people(X,Y) :-
people(X),
people(Y),
X \== Y.
markov smokes(X)::[t,f]; [1.0, 4.0552]; [people(X)].
markov cancer(X)::[t,f]; [1.0, 9.9742]; [people(X)].
markov friends(X,Y)::[t,f] ; [1.0, 99.48432] ; [people(X,Y)].
markov smokes(X)::[t,f], cancer(X)::[t,f] ; [4.48169, 4.48169, 1.0, 4.48169] ; [people(X)].
markov friends(X,Y)::[t,f], smokes(X)::[t,f], smokes(Y)::[t,f] ;
[3.004166, 3.004166, 3.004166, 3.004166, 3.004166, 1.0, 1.0, 3.004166] ; [people(X,Y)].
% ?- friends(p1,p2,X).

View File

@ -0,0 +1,31 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).
:- yap_flag(write_strings, off).
:- multifile people/1.
people @ 5.
people(X,Y) :-
people(X),
people(Y).
% X \== Y.
markov smokes(X)::[t,f]; [1.0, 4.0552]; [people(X)].
markov asthma(X)::[t,f]; [1.0, 9.9742] ; [people(X)].
markov friends(X,Y)::[t,f]; [1.0, 99.48432] ; [people(X,Y)].
markov asthma(X)::[t,f], smokes(X)::[t,f]; [4.48169, 4.48169, 1.0, 4.48169] ; [people(X)].
markov asthma(X)::[t,f], friends(X,Y)::[t,f], smokes(Y)::[t,f];
[3.004166, 3.004166, 3.004166, 3.004166, 3.004166, 1.0, 1.0, 3.004166] ; [people(X,Y)].
% ?- smokes(p1,t), smokes(p2,t), friends(p1,p2,X)

View File

@ -0,0 +1,29 @@
:- use_module(library(pfl)).
%:- set_solver(fove).
%:- set_solver(hve).
%:- set_solver(bp).
%:- set_solver(cbp).
:- yap_flag(write_strings, off).
:- multifile people/1.
people @ 5.
markov attends(P)::[t,f], attr1::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
markov attends(P)::[t,f], attr2::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
markov attends(P)::[t,f], attr3::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
markov attends(P)::[t,f], attr4::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
markov attends(P)::[t,f], attr5::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
markov attends(P)::[t,f], attr6::[t,f] ; [0.7, 0.3, 0.3, 0.3] ; [people(P)].
markov attends(P)::[t,f], series::[t,f] ; [0.501, 0.499, 0.499, 0.499] ; [people(P)].
% ?- series(X).

View File

@ -9,14 +9,12 @@
#include "Util.h" #include "Util.h"
FactorGraph* FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds) BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{ {
assert (fg_.isFromBayesNetwork()); assert (fg_.bayesianFactors());
Scheduling scheduling; Scheduling scheduling;
for (unsigned i = 0; i < queryIds.size(); i++) { for (size_t i = 0; i < queryIds.size(); i++) {
assert (dag_.getNode (queryIds[i])); assert (dag_.getNode (queryIds[i]));
DAGraphNode* n = dag_.getNode (queryIds[i]); DAGraphNode* n = dag_.getNode (queryIds[i]);
scheduling.push (ScheduleInfo (n, false, true)); scheduling.push (ScheduleInfo (n, false, true));
@ -60,11 +58,11 @@ void
BayesBall::constructGraph (FactorGraph* fg) const BayesBall::constructGraph (FactorGraph* fg) const
{ {
const FacNodes& facNodes = fg_.facNodes(); const FacNodes& facNodes = fg_.facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) { for (size_t i = 0; i < facNodes.size(); i++) {
const DAGraphNode* n = dag_.getNode ( const DAGraphNode* n = dag_.getNode (
facNodes[i]->factor().argument (0)); facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) { if (n->isMarkedOnTop()) {
fg->addFactor (Factor (facNodes[i]->factor())); fg->addFactor (facNodes[i]->factor());
} else if (n->hasEvidence() && n->isVisited()) { } else if (n->hasEvidence() && n->isVisited()) {
VarIds varIds = { facNodes[i]->factor().argument (0) }; VarIds varIds = { facNodes[i]->factor().argument (0) };
Ranges ranges = { facNodes[i]->factor().range (0) }; Ranges ranges = { facNodes[i]->factor().range (0) };
@ -73,5 +71,14 @@ BayesBall::constructGraph (FactorGraph* fg) const
fg->addFactor (Factor (varIds, ranges, params)); fg->addFactor (Factor (varIds, ranges, params));
} }
} }
const VarNodes& varNodes = fg_.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) {
VarNode* vn = fg->getVarNode (varNodes[i]->varId());
if (vn) {
vn->setEvidence (varNodes[i]->getEvidence());
}
}
}
} }

View File

@ -64,7 +64,7 @@ BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
{ {
const vector<DAGraphNode*>& ps = n->parents(); const vector<DAGraphNode*>& ps = n->parents();
for (vector<DAGraphNode*>::const_iterator it = ps.begin(); for (vector<DAGraphNode*>::const_iterator it = ps.begin();
it != ps.end(); it++) { it != ps.end(); ++it) {
sch.push (ScheduleInfo (*it, false, true)); sch.push (ScheduleInfo (*it, false, true));
} }
} }
@ -76,7 +76,7 @@ BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
{ {
const vector<DAGraphNode*>& cs = n->childs(); const vector<DAGraphNode*>& cs = n->childs();
for (vector<DAGraphNode*>::const_iterator it = cs.begin(); for (vector<DAGraphNode*>::const_iterator it = cs.begin();
it != cs.end(); it++) { it != cs.end(); ++it) {
sch.push (ScheduleInfo (*it, true, false)); sch.push (ScheduleInfo (*it, true, false));
} }
} }

View File

@ -57,7 +57,7 @@ DAGraph::getNode (VarId vid)
void void
DAGraph::setIndexes (void) DAGraph::setIndexes (void)
{ {
for (unsigned i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i); nodes_[i]->setIndex (i);
} }
} }
@ -67,7 +67,7 @@ DAGraph::setIndexes (void)
void void
DAGraph::clear (void) DAGraph::clear (void)
{ {
for (unsigned i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
nodes_[i]->clear(); nodes_[i]->clear();
} }
} }
@ -85,7 +85,7 @@ DAGraph::exportToGraphViz (const char* fileName)
} }
out << "digraph {" << endl; out << "digraph {" << endl;
out << "ranksep=1" << endl; out << "ranksep=1" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
out << nodes_[i]->varId() ; out << nodes_[i]->varId() ;
out << " [" ; out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"" ; out << "label=\"" << nodes_[i]->label() << "\"" ;
@ -94,9 +94,9 @@ DAGraph::exportToGraphViz (const char* fileName)
} }
out << "]" << endl; out << "]" << endl;
} }
for (unsigned i = 0; i < nodes_.size(); i++) { for (size_t i = 0; i < nodes_.size(); i++) {
const vector<DAGraphNode*>& childs = nodes_[i]->childs(); const vector<DAGraphNode*>& childs = nodes_[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) { for (size_t j = 0; j < childs.size(); j++) {
out << nodes_[i]->varId() << " -> " << childs[j]->varId(); out << nodes_[i]->varId() << " -> " << childs[j]->varId();
out << " [style=bold]" << endl ; out << " [style=bold]" << endl ;
} }

View File

@ -0,0 +1,502 @@
#include <cassert>
#include <limits>
#include <algorithm>
#include <iostream>
#include "BpSolver.h"
#include "FactorGraph.h"
#include "Factor.h"
#include "Indexer.h"
#include "Horus.h"
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{
runned_ = false;
}
BpSolver::~BpSolver (void)
{
for (size_t i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
}
for (size_t i = 0; i < facsI_.size(); i++) {
delete facsI_[i];
}
for (size_t i = 0; i < links_.size(); i++) {
delete links_[i];
}
}
Params
BpSolver::solveQuery (VarIds queryVids)
{
assert (queryVids.empty() == false);
return queryVids.size() == 1
? getPosterioriOf (queryVids[0])
: getJointDistributionOf (queryVids);
}
void
BpSolver::printSolverFlags (void) const
{
stringstream ss;
ss << "belief propagation [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << Util::toString (BpOptions::maxIter);
ss << ",accuracy=" << Util::toString (BpOptions::accuracy);
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << "]" ;
cout << ss.str() << endl;
}
Params
BpSolver::getPosterioriOf (VarId vid)
{
if (runned_ == false) {
runSolver();
}
assert (fg.getVarNode (vid));
VarNode* var = fg.getVarNode (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->range(), LogAware::multIdenty());
const BpLinks& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (size_t i = 0; i < links.size(); i++) {
probs += links[i]->message();
}
LogAware::normalize (probs);
Util::exp (probs);
} else {
for (size_t i = 0; i < links.size(); i++) {
probs *= links[i]->message();
}
LogAware::normalize (probs);
}
}
return probs;
}
Params
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
if (runned_ == false) {
runSolver();
}
VarNode* vn = fg.getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors();
size_t idx = facNodes.size();
for (size_t i = 0; i < facNodes.size(); i++) {
if (facNodes[i]->factor().contains (jointVarIds)) {
idx = i;
break;
}
}
if (idx == facNodes.size()) {
return getJointByConditioning (jointVarIds);
} else {
Factor res (facNodes[idx]->factor());
const BpLinks& links = ninf(facNodes[idx])->getLinks();
for (size_t i = 0; i < links.size(); i++) {
Factor msg ({links[i]->varNode()->varId()},
{links[i]->varNode()->range()},
getVarToFactorMsg (links[i]));
res.multiply (msg);
}
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::exp (jointDist);
}
return jointDist;
}
}
void
BpSolver::runSolver (void)
{
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (Globals::verbosity > 1) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
std::random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (size_t i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (size_t i = 0; i < links_.size(); i++) {
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
}
if (Globals::verbosity > 0) {
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
cout << endl;
}
runned_ = true;
}
void
BpSolver::createLinks (void)
{
const FacNodes& facNodes = fg.facNodes();
for (size_t i = 0; i < facNodes.size(); i++) {
const VarNodes& neighbors = facNodes[i]->neighbors();
for (size_t j = 0; j < neighbors.size(); j++) {
links_.push_back (new BpLink (facNodes[i], neighbors[j]));
}
}
}
void
BpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (size_t i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
}
return;
}
for (size_t c = 0; c < links_.size(); c++) {
if (Globals::verbosity > 1) {
cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); ++it) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->residual() << endl;
}
}
SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it;
if (link->residual() < BpOptions::accuracy) {
return;
}
updateMessage (link);
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FacNodes& factorNeighbors = link->varNode()->neighbors();
for (size_t i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->facNode()) {
const BpLinks& links = ninf(factorNeighbors[i])->getLinks();
for (size_t j = 0; j < links.size(); j++) {
if (links[j]->varNode() != link->varNode()) {
calculateMessage (links[j]);
BpLinkMap::iterator iter = linkMap_.find (links[j]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (links[j]);
}
}
}
}
if (Globals::verbosity > 1) {
Util::printDashedLine();
}
}
}
void
BpSolver::calcFactorToVarMsg (BpLink* link)
{
FacNode* src = link->facNode();
const VarNode* dst = link->varNode();
const BpLinks& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned reps = 1;
unsigned msgSize = Util::sizeExpected (src->factor().ranges());
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::plus<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
}
reps *= links[i]->varNode()->range();
}
} else {
for (size_t i = links.size(); i-- > 0; ) {
if (links[i]->varNode() != dst) {
if (Constants::SHOW_BP_CALCS) {
cout << " message from " << links[i]->varNode()->label();
cout << ": " ;
}
Util::apply_n_times (msgProduct, getVarToFactorMsg (links[i]),
reps, std::multiplies<double>());
if (Constants::SHOW_BP_CALCS) {
cout << endl;
}
}
reps *= links[i]->varNode()->range();
}
}
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
result.multiply (src->factor());
if (Constants::SHOW_BP_CALCS) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExcept (dst->varId());
if (Constants::SHOW_BP_CALCS) {
cout << " marginalized: " << result.params() << endl;
}
link->nextMessage() = result.params();
LogAware::normalize (link->nextMessage());
if (Constants::SHOW_BP_CALCS) {
cout << " curr msg: " << link->message() << endl;
cout << " next msg: " << link->nextMessage() << endl;
}
}
Params
BpSolver::getVarToFactorMsg (const BpLink* link) const
{
const VarNode* src = link->varNode();
Params msg;
if (src->hasEvidence()) {
msg.resize (src->range(), LogAware::noEvidence());
msg[src->getEvidence()] = LogAware::withEvidence();
} else {
msg.resize (src->range(), LogAware::one());
}
if (Constants::SHOW_BP_CALCS) {
cout << msg;
}
BpLinks::const_iterator it;
const BpLinks& links = ninf (src)->getLinks();
if (Globals::logDomain) {
for (it = links.begin(); it != links.end(); ++it) {
msg += (*it)->message();
if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->message();
}
}
msg -= link->message();
} else {
for (it = links.begin(); it != links.end(); ++it) {
msg *= (*it)->message();
if (Constants::SHOW_BP_CALCS) {
cout << " x " << (*it)->message();
}
}
msg /= link->message();
}
if (Constants::SHOW_BP_CALCS) {
cout << " = " << msg;
}
return msg;
}
Params
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{
VarNodes jointVars;
for (size_t i = 0; i < jointVarIds.size(); i++) {
assert (fg.getVarNode (jointVarIds[i]));
jointVars.push_back (fg.getVarNode (jointVarIds[i]));
}
FactorGraph* tempFg = new FactorGraph (fg);
BpSolver solver (*tempFg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
VarIds observedVids = {jointVars[0]->varId()};
for (size_t i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
Vars observedVars;
Ranges observedRanges;
for (size_t j = 0; j < observedVids.size(); j++) {
observedVars.push_back (tempFg->getVarNode (observedVids[j]));
observedRanges.push_back (observedVars.back()->range());
}
Indexer indexer (observedRanges, false);
while (indexer.valid()) {
for (size_t j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (indexer[j]);
}
BpSolver solver (*tempFg);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (size_t k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ indexer;
}
int count = -1;
for (size_t j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
}
return prevBeliefs;
}
void
BpSolver::initializeSolver (void)
{
const VarNodes& varNodes = fg.varNodes();
varsI_.reserve (varNodes.size());
for (size_t i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FacNodes& facNodes = fg.facNodes();
facsI_.reserve (facNodes.size());
for (size_t i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
createLinks();
for (size_t i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->facNode();
VarNode* dst = links_[i]->varNode();
ninf (dst)->addBpLink (links_[i]);
ninf (src)->addBpLink (links_[i]);
}
}
bool
BpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0) {
return false;
}
if (Globals::verbosity > 2) {
cout << endl;
}
if (nIters_ == 1) {
if (Globals::verbosity > 1) {
cout << "no residuals" << endl << endl;
}
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->residual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (size_t i = 0; i < links_.size(); i++) {
double residual = links_[i]->residual();
if (Globals::verbosity > 1) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Globals::verbosity < 2) {
break;
}
}
}
if (Globals::verbosity > 1) {
cout << endl;
}
}
return converged;
}
void
BpSolver::printLinkInformation (void) const
{
for (size_t i = 0; i < links_.size(); i++) {
BpLink* l = links_[i];
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << l->message() << endl;
cout << " next msg = " ;
cout << l->nextMessage() << endl;
cout << " residual = " << l->residual() << endl;
}
}

View File

@ -13,34 +13,31 @@
using namespace std; using namespace std;
class SpLink class BpLink
{ {
public: public:
SpLink (FacNode* fn, VarNode* vn) BpLink (FacNode* fn, VarNode* vn)
{ {
fac_ = fn; fac_ = fn;
var_ = vn; var_ = vn;
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range())); v1_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range())); v2_.resize (vn->range(), LogAware::log (1.0 / vn->range()));
currMsg_ = &v1_; currMsg_ = &v1_;
nextMsg_ = &v2_; nextMsg_ = &v2_;
msgSended_ = false;
residual_ = 0.0; residual_ = 0.0;
} }
virtual ~SpLink (void) { }; virtual ~BpLink (void) { };
FacNode* getFactor (void) const { return fac_; } FacNode* facNode (void) const { return fac_; }
VarNode* getVariable (void) const { return var_; } VarNode* varNode (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; } const Params& message (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; } Params& nextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; } double residual (void) const { return residual_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; } void clearResidual (void) { residual_ = 0.0; }
@ -52,7 +49,6 @@ class SpLink
virtual void updateMessage (void) virtual void updateMessage (void)
{ {
swap (currMsg_, nextMsg_); swap (currMsg_, nextMsg_);
msgSended_ = true;
} }
string toString (void) const string toString (void) const
@ -71,20 +67,19 @@ class SpLink
Params v2_; Params v2_;
Params* currMsg_; Params* currMsg_;
Params* nextMsg_; Params* nextMsg_;
bool msgSended_;
double residual_; double residual_;
}; };
typedef vector<SpLink*> SpLinkSet; typedef vector<BpLink*> BpLinks;
class SPNodeInfo class SPNodeInfo
{ {
public: public:
void addSpLink (SpLink* link) { links_.push_back (link); } void addBpLink (BpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; } const BpLinks& getLinks (void) { return links_; }
private: private:
SpLinkSet links_; BpLinks links_;
}; };
@ -97,6 +92,8 @@ class BpSolver : public Solver
Params solveQuery (VarIds); Params solveQuery (VarIds);
virtual void printSolverFlags (void) const;
virtual Params getPosterioriOf (VarId); virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&); virtual Params getJointDistributionOf (const VarIds&);
@ -108,9 +105,9 @@ class BpSolver : public Solver
virtual void maxResidualSchedule (void); virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*); virtual void calcFactorToVarMsg (BpLink*);
virtual Params getVar2FactorMsg (const SpLink*) const; virtual Params getVarToFactorMsg (const BpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const; virtual Params getJointByConditioning (const VarIds&) const;
@ -124,64 +121,63 @@ class BpSolver : public Solver
return facsI_[fac->getIndex()]; return facsI_[fac->getIndex()];
} }
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true) void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
{ {
if (Constants::DEBUG >= 3) { if (Globals::verbosity > 2) {
cout << "calculating & updating " << link->toString() << endl; cout << "calculating & updating " << link->toString() << endl;
} }
calculateFactor2VariableMsg (link); calcFactorToVarMsg (link);
if (calcResidual) { if (calcResidual) {
link->updateResidual(); link->updateResidual();
} }
link->updateMessage(); link->updateMessage();
} }
void calculateMessage (SpLink* link, bool calcResidual = true) void calculateMessage (BpLink* link, bool calcResidual = true)
{ {
if (Constants::DEBUG >= 3) { if (Globals::verbosity > 2) {
cout << "calculating " << link->toString() << endl; cout << "calculating " << link->toString() << endl;
} }
calculateFactor2VariableMsg (link); calcFactorToVarMsg (link);
if (calcResidual) { if (calcResidual) {
link->updateResidual(); link->updateResidual();
} }
} }
void updateMessage (SpLink* link) void updateMessage (BpLink* link)
{ {
link->updateMessage(); link->updateMessage();
if (Constants::DEBUG >= 3) { if (Globals::verbosity > 2) {
cout << "updating " << link->toString() << endl; cout << "updating " << link->toString() << endl;
} }
} }
struct CompareResidual struct CompareResidual
{ {
inline bool operator() (const SpLink* link1, const SpLink* link2) inline bool operator() (const BpLink* link1, const BpLink* link2)
{ {
return link1->getResidual() > link2->getResidual(); return link1->residual() > link2->residual();
} }
}; };
SpLinkSet links_; BpLinks links_;
unsigned nIters_; unsigned nIters_;
vector<SPNodeInfo*> varsI_; vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_; vector<SPNodeInfo*> facsI_;
bool runned_; bool runned_;
const FactorGraph* fg_;
typedef multiset<SpLink*, CompareResidual> SortedOrder; typedef multiset<BpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_; SortedOrder sortedOrder_;
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap; typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
SpLinkMap linkMap_; BpLinkMap linkMap_;
private: private:
void initializeSolver (void); void initializeSolver (void);
bool converged (void); bool converged (void);
void printLinkInformation (void) const; virtual void printLinkInformation (void) const;
}; };
#endif // HORUS_BPSOLVER_H #endif // HORUS_BPSOLVER_H

View File

@ -0,0 +1,400 @@
#include "CbpSolver.h"
#include "WeightedBpSolver.h"
bool CbpSolver::checkForIdenticalFactors = true;
CbpSolver::CbpSolver (const FactorGraph& fg)
: Solver (fg), freeColor_(0)
{
findIdenticalFactors();
setInitialColors();
createGroups();
compressedFg_ = getCompressedFactorGraph();
solver_ = new WeightedBpSolver (*compressedFg_, getWeights());
}
CbpSolver::~CbpSolver (void)
{
delete solver_;
delete compressedFg_;
for (size_t i = 0; i < varClusters_.size(); i++) {
delete varClusters_[i];
}
for (size_t i = 0; i < facClusters_.size(); i++) {
delete facClusters_[i];
}
}
void
CbpSolver::printSolverFlags (void) const
{
stringstream ss;
ss << "counting bp [" ;
ss << "schedule=" ;
typedef BpOptions::Schedule Sch;
switch (BpOptions::schedule) {
case Sch::SEQ_FIXED: ss << "seq_fixed"; break;
case Sch::SEQ_RANDOM: ss << "seq_random"; break;
case Sch::PARALLEL: ss << "parallel"; break;
case Sch::MAX_RESIDUAL: ss << "max_residual"; break;
}
ss << ",max_iter=" << BpOptions::maxIter;
ss << ",accuracy=" << BpOptions::accuracy;
ss << ",log_domain=" << Util::toString (Globals::logDomain);
ss << ",chkif=" <<
Util::toString (CbpSolver::checkForIdenticalFactors);
ss << "]" ;
cout << ss.str() << endl;
}
Params
CbpSolver::solveQuery (VarIds queryVids)
{
assert (queryVids.empty() == false);
Params res;
if (queryVids.size() == 1) {
res = solver_->getPosterioriOf (getRepresentative (queryVids[0]));
} else {
VarNode* vn = fg.getVarNode (queryVids[0]);
const FacNodes& facNodes = vn->neighbors();
size_t idx = facNodes.size();
for (size_t i = 0; i < facNodes.size(); i++) {
if (facNodes[i]->factor().contains (queryVids)) {
idx = i;
break;
}
cout << endl;
}
if (idx == facNodes.size()) {
cerr << "error: only joint distributions on variables of some " ;
cerr << "clique are supported with the current solver" ;
cerr << endl;
exit (1);
}
VarIds representatives;
for (size_t i = 0; i < queryVids.size(); i++) {
representatives.push_back (getRepresentative (queryVids[i]));
}
res = solver_->getJointDistributionOf (representatives);
}
return res;
}
void
CbpSolver::findIdenticalFactors()
{
const FacNodes& facNodes = fg.facNodes();
if (checkForIdenticalFactors == false ||
facNodes.size() == 1) {
return;
}
for (size_t i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
unsigned groupCount = 1;
for (size_t i = 0; i < facNodes.size() - 1; i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (size_t j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
}
}
void
CbpSolver::setInitialColors (void)
{
varColors_.resize (fg.nrVarNodes());
facColors_.resize (fg.nrFacNodes());
// create the initial variable colors
VarColorMap colorMap;
const VarNodes& varNodes = fg.varNodes();
for (size_t i = 0; i < varNodes.size(); i++) {
unsigned range = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (range);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
range, Colors (range + 1, -1))).first;
}
unsigned idx = varNodes[i]->hasEvidence()
? varNodes[i]->getEvidence()
: range;
Colors& stateColors = it->second;
if (stateColors[idx] == -1) {
stateColors[idx] = getNewColor();
}
setColor (varNodes[i], stateColors[idx]);
}
const FacNodes& facNodes = fg.facNodes();
// create the initial factor colors
DistColorMap distColors;
for (size_t i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (distId, getNewColor())).first;
}
setColor (facNodes[i], it->second);
}
}
void
CbpSolver::createGroups (void)
{
VarSignMap varGroups;
FacSignMap facGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const VarNodes& varNodes = fg.varNodes();
const FacNodes& facNodes = fg.facNodes();
while (groupsHaveChanged || nIters == 1) {
nIters ++;
// set a new color to the variables with the same signature
size_t prevVarGroupsSize = varGroups.size();
varGroups.clear();
for (size_t i = 0; i < varNodes.size(); i++) {
const VarSignature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
Color newColor = getNewColor();
VarNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
size_t prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (size_t i = 0; i < facNodes.size(); i++) {
const FacSignature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
Color newColor = getNewColor();
FacNodes& groupMembers = it->second;
for (size_t i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != facGroups.size();
}
// printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
}
void
CbpSolver::createClusters (
const VarSignMap& varGroups,
const FacSignMap& facGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (size_t i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
}
varClusters_.push_back (vc);
}
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (size_t i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
varClusters.push_back (vid2VarCluster_.find (vid)->second);
}
facClusters_.push_back (new FacCluster (it->second, varClusters));
}
}
VarSignature
CbpSolver::getSignature (const VarNode* varNode)
{
const FacNodes& neighs = varNode->neighbors();
VarSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (make_pair (
getColor (neighs[i]),
neighs[i]->factor().indexOf (varNode->varId())));
}
std::sort (sign.begin(), sign.end());
sign.push_back (make_pair (getColor (varNode), 0));
return sign;
}
FacSignature
CbpSolver::getSignature (const FacNode* facNode)
{
const VarNodes& neighs = facNode->neighbors();
FacSignature sign;
sign.reserve (neighs.size() + 1);
for (size_t i = 0; i < neighs.size(); i++) {
sign.push_back (getColor (neighs[i]));
}
sign.push_back (getColor (facNode));
return sign;
}
FactorGraph*
CbpSolver::getCompressedFactorGraph (void)
{
FactorGraph* fg = new FactorGraph();
for (size_t i = 0; i < varClusters_.size(); i++) {
VarNode* newVar = new VarNode (varClusters_[i]->first());
varClusters_[i]->setRepresentative (newVar);
fg->addVarNode (newVar);
}
for (size_t i = 0; i < facClusters_.size(); i++) {
Vars vars;
const VarClusters& clusters = facClusters_[i]->varClusters();
for (size_t j = 0; j < clusters.size(); j++) {
vars.push_back (clusters[j]->representative());
}
const Factor& groundFac = facClusters_[i]->first()->factor();
FacNode* fn = new FacNode (Factor (
vars, groundFac.params(), groundFac.distId()));
facClusters_[i]->setRepresentative (fn);
fg->addFacNode (fn);
for (size_t j = 0; j < vars.size(); j++) {
fg->addEdge (static_cast<VarNode*> (vars[j]), fn);
}
}
return fg;
}
vector<vector<unsigned>>
CbpSolver::getWeights (void) const
{
vector<vector<unsigned>> weights;
weights.reserve (facClusters_.size());
for (size_t i = 0; i < facClusters_.size(); i++) {
const VarClusters& neighs = facClusters_[i]->varClusters();
weights.push_back ({ });
weights.back().reserve (neighs.size());
for (size_t j = 0; j < neighs.size(); j++) {
weights.back().push_back (getWeight (
facClusters_[i], neighs[j], j));
}
}
return weights;
}
unsigned
CbpSolver::getWeight (
const FacCluster* fc,
const VarCluster* vc,
size_t index) const
{
unsigned weight = 0;
VarId reprVid = vc->representative()->varId();
VarNode* groundVar = fg.getVarNode (reprVid);
const FacNodes& neighs = groundVar->neighbors();
for (size_t i = 0; i < neighs.size(); i++) {
FacNodes::const_iterator it;
it = std::find (fc->members().begin(), fc->members().end(), neighs[i]);
if (it != fc->members().end() &&
(*it)->factor().indexOf (reprVid) == index) {
weight ++;
}
}
return weight;
}
void
CbpSolver::printGroups (
const VarSignMap& varGroups,
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); ++it) {
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->label() << " " ;
}
count ++;
cout << endl;
}
}
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); ++it) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
for (size_t i = 0; i < groupMembers.size(); i++) {
cout << groupMembers[i]->getLabel() << " " ;
}
count ++;
cout << endl;
}
}
}

View File

@ -0,0 +1,183 @@
#ifndef HORUS_CBPSOLVER_H
#define HORUS_CBPSOLVER_H
#include <unordered_map>
#include "Solver.h"
#include "FactorGraph.h"
#include "Util.h"
#include "Horus.h"
class VarCluster;
class FacCluster;
class VarSignHash;
class FacSignHash;
class WeightedBpSolver;
typedef long Color;
typedef vector<Color> Colors;
typedef vector<std::pair<Color,unsigned>> VarSignature;
typedef vector<Color> FacSignature;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<unsigned, Colors> VarColorMap;
typedef unordered_map<VarSignature, VarNodes, VarSignHash> VarSignMap;
typedef unordered_map<FacSignature, FacNodes, FacSignHash> FacSignMap;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
struct VarSignHash
{
size_t operator() (const VarSignature &sig) const
{
size_t val = hash<size_t>()(sig.size());
for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i].first);
val ^= hash<size_t>()(sig[i].second);
}
return val;
}
};
struct FacSignHash
{
size_t operator() (const FacSignature &sig) const
{
size_t val = hash<size_t>()(sig.size());
for (size_t i = 0; i < sig.size(); i++) {
val ^= hash<size_t>()(sig[i]);
}
return val;
}
};
class VarCluster
{
public:
VarCluster (const VarNodes& vs) : members_(vs) { }
const VarNode* first (void) const { return members_.front(); }
const VarNodes& members (void) const { return members_; }
VarNode* representative (void) const { return repr_; }
void setRepresentative (VarNode* vn) { repr_ = vn; }
private:
VarNodes members_;
VarNode* repr_;
};
class FacCluster
{
public:
FacCluster (const FacNodes& fcs, const VarClusters& vcs)
: members_(fcs), varClusters_(vcs) { }
const FacNode* first (void) const { return members_.front(); }
const FacNodes& members (void) const { return members_; }
VarClusters& varClusters (void) { return varClusters_; }
FacNode* representative (void) const { return repr_; }
void setRepresentative (FacNode* fn) { repr_ = fn; }
private:
FacNodes members_;
VarClusters varClusters_;
FacNode* repr_;
};
class CbpSolver : public Solver
{
public:
CbpSolver (const FactorGraph& fg);
~CbpSolver (void);
void printSolverFlags (void) const;
Params solveQuery (VarIds);
static bool checkForIdenticalFactors;
private:
Color getNewColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
Color getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
Color getColor (const FacNode* fn) const
{
return facColors_[fn->getIndex()];
}
void setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
void setColor (const FacNode* fn, Color c)
{
facColors_[fn->getIndex()] = c;
}
void findIdenticalFactors (void);
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
VarSignature getSignature (const VarNode*);
FacSignature getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
VarId getRepresentative (VarId vid)
{
assert (Util::contains (vid2VarCluster_, vid));
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->representative()->varId();
}
FactorGraph* getCompressedFactorGraph (void);
vector<vector<unsigned>> getWeights (void) const;
unsigned getWeight (const FacCluster*,
const VarCluster*, size_t index) const;
Color freeColor_;
Colors varColors_;
Colors facColors_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* compressedFg_;
WeightedBpSolver* solver_;
};
#endif // HORUS_CBPSOLVER_H

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More