support mtbdds.

This commit is contained in:
Vitor Santos Costa 2012-04-20 16:42:59 -05:00
parent 178ad27db8
commit 220f7e6efc
2 changed files with 124 additions and 11 deletions

View File

@ -53,6 +53,10 @@ Va <- P*X1*Y1 + Q*X2*Y2 + ...
:- use_module(library(bdd)).
:- use_module(library(ddnnf)).
:- use_module(library(simpbool)).
:- use_module(library(rbtrees)).
:- use_module(library(bhash)).
@ -63,6 +67,9 @@ Va <- P*X1*Y1 + Q*X2*Y2 + ...
:- attribute order/1.
:- dynamic bdds/1.
bdds(ddnnf).
check_if_bdd_done(_Var).
bdd([[]],_,_) :- !.
@ -271,7 +278,7 @@ avg_tree( _PVars, P, _, Im, IM, _Size, O, H0, H0) :-
b_hash_lookup(k(P,Im,IM), O=_Exp, H0), !.
avg_tree([], _P, _Max, _Im, _IM, _Size, 1, H, H).
avg_tree([Vals|PVars], P, Max, Im, IM, Size, O, H0, HF) :-
b_hash_insert(H0, k(P,Im,IM), O=Simp, HI),
b_hash_insert(H0, k(P,Im,IM), O=Simp*1, HI),
MaxI is Max-(Size-1),
avg_exp(Vals, PVars, 0, P, MaxI, Size, Im, IM, HI, HF, Exp),
simplify_exp(Exp, Simp).
@ -732,7 +739,6 @@ run_bdd_solver([[V]], LPs, bdd(Term, _Leaves, Nodes)) :-
build_out_node(Nodes, Node),
findall(Prob, get_prob(Term, Node, V, Prob),TermProbs),
sumlist(TermProbs, Sum),
writeln(TermProbs:Sum),
normalise(TermProbs, Sum, LPs).
build_out_node([_Top], []).
@ -744,7 +750,12 @@ build_out_node2([T,T1|Tops], T*Top) :-
build_out_node2(T1.Tops, Top).
get_prob(Term, _Node, V, SP) :-
bdds(ddnnf), !,
all_cnfs(Term, CNF, IVs, Indics, V, AllParms, AllParmValues),
build_cnf(CNF, IVs, Indics, AllParms, AllParmValues, SP).
get_prob(Term, Node, V, SP) :-
bdds(bdd), !,
bind_all(Term, Node, Bindings, V, AllParms, AllParmValues),
% reverse(AllParms, RAllParms),
term_variables(AllParms, NVs),
@ -754,25 +765,25 @@ get_prob(Term, Node, V, SP) :-
build_bdd(Bindings, NVs, VTheta, Theta, Bdd) :-
bdd_from_list(Bindings, NVs, Bdd),
bdd_size(Bdd, Len),
number_codes(Len,Codes),
atom_codes(Name,Codes),
bdd_print(Bdd, Name),
writeln(length=Len),
% bdd_size(Bdd, Len),
% number_codes(Len,Codes),
% atom_codes(Name,Codes),
% bdd_print(Bdd, Name),
% writeln(length=Len),
VTheta = Theta.
bind_all([], End, End, _V, [], []).
bind_all(info(V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
bind_all([info(V, _Tree, Ev, _Values, Formula, ParmVars, Parms)|Term], End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
V0 == V, !,
set_to_one_zeros(Ev),
bind_formula(Formula, BindsF, BindsI),
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
bind_all(info(_V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
bind_all([info(_V, _Tree, Ev, _Values, Formula, ParmVars, Parms)|Term], End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
set_to_ones(Ev),!,
bind_formula(Formula, BindsF, BindsI),
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
% evidence: no need to add any stuff.
bind_all(info(_V, _Tree, _Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
bind_all([info(_V, _Tree, _Ev, _Values, Formula, ParmVars, Parms)|Term], End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
bind_formula(Formula, BindsF, BindsI),
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
@ -800,3 +811,103 @@ normalise(P.TermProbs, Sum, NP.LPs) :-
finalize_bdd_solver(_).
all_cnfs([], [], [], [], _V, [], []).
all_cnfs([info(V, Tree, Ev, Values, Formula, ParmVars, Parms)|Term], BindsF, IVars, Indics, V0, AllParmsF, AllThetaF) :-
%writeln(f:Formula),
V0 == V, !,
set_to_one_zeros(Ev),
all_indicators(Values, BindsF, Binds0),
indicators(Values, [], Ev, IVars, IVarsI, Indics, IndicsI, Binds0, Binds1),
parms( ParmVars, Parms, AllParmsF, AllThetaF, AllParms, AllTheta),
parameters(Formula, Tree, Binds1, BindsI),
all_cnfs(Term, BindsI, IVarsI, IndicsI, V0, AllParms, AllTheta).
all_cnfs([info(_V, Tree, Ev, Values, Formula, ParmVars, Parms)|Term], BindsF, IVars, Indics, V0, AllParmsF, AllThetaF) :-
set_to_ones(Ev),!,
all_indicators(Values, BindsF, Binds0),
indicators(Values, [], Ev, IVars, IVarsI, Indics, IndicsI, Binds0, Binds1),
parms( ParmVars, Parms, AllParmsF, AllThetaF, AllParms, AllTheta),
parameters(Formula, Tree, Binds1, BindsI),
all_cnfs(Term, BindsI, IVarsI, IndicsI, V0, AllParms, AllTheta).
% evidence: no need to add any stuff.
all_cnfs([info(_V, Tree, Ev, Values, Formula, ParmVars, Parms)|Term], BindsF, IVars, Indics, V0, AllParmsF, AllThetaF) :-
all_indicators(Values, BindsF, Binds0),
indicators(Values, [], Ev, IVars, IVarsI, Indics, IndicsI, Binds0, Binds1),
parms( ParmVars, Parms, AllParmsF, AllThetaF, AllParms, AllTheta),
parameters(Formula, Tree, Binds1, BindsI),
all_cnfs(Term, BindsI, IVarsI, IndicsI, V0, AllParms, AllTheta).
all_indicators(Values) -->
{ values_to_disj(Values, Disj) },
[Disj].
values_to_disj([V], V) :- !.
values_to_disj([V|Values], V+Disj) :-
values_to_disj(Values, Disj).
indicators([V|Vars], SeenVs, [E|Ev], [V|IsF], IsI, [E|Inds], Inds0) -->
generate_exclusions(SeenVs, V),
indicators(Vars, [V|SeenVs], Ev, IsF, IsI, Inds, Inds0).
indicators([], _SeenVs, [], IsF, IsF, Inds, Inds) --> [].
parms([], [], AllParms, AllTheta, AllParms, AllTheta).
parms([V|ParmVars], [P|Parms], [V|AllParmsF], [P|AllThetaF], AllParms, AllTheta) :-
parms( ParmVars, Parms, AllParmsF, AllThetaF, AllParms, AllTheta).
parameters([], _) --> [].
% ignore disj, only useful to BDDs
parameters([(T=_)|Formula], Tree) -->
{ Tree == T }, !,
parameters(Formula, Tree).
parameters([(V0=Disj*_I0)|Formula], Tree) -->
conj(Disj, V0),
parameters(Formula, Tree).
% transform V0<- A*B+C*(D+not(E))
% [V0+not(A)+not(B),V0+not(C)+not(D),V0+not(C)+E]
conj(Disj, V0) -->
{ conj2(Disj, [[V0]], LVs) },
to_disjs(LVs).
conj2(A, L0, LF) :- var(A), !,
add(not(A), L0, LF).
conj2((A*B), L0, LF) :-
conj2(A, L0, LI),
conj2(B, LI, LF).
conj2((A+B), L0, LF) :-
conj2(A, L0, L1),
conj2(B, L0, L2),
append(L1, L2, LF).
conj2(not(A), L0, LF) :-
add(A, L0, LF).
add(_, [], []).
add(Head, [H|L], [[Head|H]|NL]) :-
add(Head, L, NL).
to_disjs([]) --> [].
to_disjs([[H|L]|LVs]) -->
mkdisj(L, H),
to_disjs(LVs).
mkdisj([], Disj) --> [Disj].
mkdisj([H|L], Disj) -->
mkdisj(L, (H+Disj)).
%
% add formula for V \== V0 -> V or V0 and not(V) or not(V0)
%
generate_exclusions([], _V) --> [].
generate_exclusions([V0|SeenVs], V) -->
[(not(V0)+not(V))],
generate_exclusions(SeenVs, V).
build_cnf(CNF, IVs, Indics, AllParms, AllParmValues, Val) :-
%(numbervars(CNF,1,_), writeln(cnf_to_ddnnf(CNF, Vars, IVs, [], F)), fail ; true ),
cnf_to_ddnnf(CNF, AllParms, F),
AllParms = AllParmValues,
IVs = Indics,
term_variables(CNF, Extra),
set_to_ones(Extra),
ddnnf_is(F, Val).

View File

@ -244,7 +244,9 @@ get_dist_domain_size(Id, DSize) :-
recorded(clpbn_dist_db, db(Id, _, _, _, _, _, DSize), _).
get_dist_domain(Id, Domain) :-
recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _).
recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _), !.
get_dist_domain(avg(Domain), Domain) :-
recorded(clpbn_dist_db, db(Id, _, _, _, Domain, _, _), _), !.
get_dist_key(Id, Key) :-
use_parfactors(on), !,